tfa.seq2seq.gather_tree_from_array


Calculates the full beams for TensorArrays.

Args:
  t: A stacked TensorArray of size max_time that contains Tensors of shape [batch_size, beam_width, s] or [batch_size * beam_width, s], where s is the depth shape.
  parent_ids: The parent ids of shape [max_time, batch_size, beam_width], i.e. for each beam at each step, the index of the beam at the previous step that it extends.
  sequence_length: The sequence length of shape [batch_size, beam_width].
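Each time step of t may use either of the two layouts listed for t. The plain-Python sketch below shows how they relate. It assumes that the merged [batch_size * beam_width, s] layout is batch-major, so row b * beam_width + k holds beam k of batch entry b. The helper names merge_batch_beams and split_batch_beams are hypothetical and are not part of tfa.seq2seq.

```python
def merge_batch_beams(step):
    """Flatten one time step from [batch_size, beam_width, s] to [batch_size * beam_width, s]."""
    return [beam for batch_entry in step for beam in batch_entry]


def split_batch_beams(rows, beam_width):
    """Regroup one time step from [batch_size * beam_width, s] to [batch_size, beam_width, s].

    Assumption: rows are batch-major, so row b * beam_width + k is beam k of batch entry b.
    """
    if len(rows) % beam_width != 0:
        raise ValueError("the number of rows must be a multiple of beam_width")
    return [rows[i:i + beam_width] for i in range(0, len(rows), beam_width)]


# One time step with batch_size=2, beam_width=3 and s=(1,); value 10*b + k marks beam k of entry b.
step = [[[0], [1], [2]], [[10], [11], [12]]]
merged = merge_batch_beams(step)
print(merged)  # [[0], [1], [2], [10], [11], [12]]
assert split_batch_beams(merged, beam_width=3) == step
```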

Returns:
  A Tensor holding a stacked TensorArray of the same size and type as t, in which the beams at each step are reordered according to parent_ids, so that beam k at every step lies on the path that ends in beam k.
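The reordering can be sketched in plain Python on nested lists. This is a sketch of the documented behaviour under stated assumptions, not the TensorFlow Addons implementation, which operates on Tensors. The name gather_tree_from_array_sketch is hypothetical. The sketch assumes the [max_time, batch_size, beam_width] layout for t, with each per-beam value being any object, for example a list of depth s. For each batch entry, it traces every final beam backwards through parent_ids, starting from the longest sequence_length in that entry. Steps at or beyond a beam's own sequence_length keep that beam's original value.

```python
def gather_tree_from_array_sketch(t, parent_ids, sequence_length):
    """Reorder per-step beam values along the beam search paths (hypothetical sketch).

    t:               nested list [max_time][batch_size][beam_width] of per-beam values
    parent_ids:      nested list [max_time][batch_size][beam_width] of ints
    sequence_length: nested list [batch_size][beam_width] of ints
    Returns a nested list with the same layout as t.
    """
    max_time = len(parent_ids)
    batch_size = len(parent_ids[0])
    beam_width = len(parent_ids[0][0])

    # Start from the identity ordering; out-of-range steps keep their own beam.
    beam_ids = [[list(range(beam_width)) for _ in range(batch_size)] for _ in range(max_time)]

    for b in range(batch_size):
        max_len = min(max(sequence_length[b]), max_time)
        for k in range(beam_width):
            beam = k  # final beam k is traced back from the last valid step of entry b
            for step in range(max_len - 1, -1, -1):
                if step < sequence_length[b][k]:
                    beam_ids[step][b][k] = beam
                beam = parent_ids[step][b][beam]  # beam at step - 1 that this one extends

    # Gather: output beam k at each step takes the value of the beam on its path.
    return [
        [[t[step][b][beam_ids[step][b][k]] for k in range(beam_width)] for b in range(batch_size)]
        for step in range(max_time)
    ]


# max_time=3, batch_size=1, beam_width=2; values name the (step, beam) they came from.
t = [[["a0", "a1"]], [["b0", "b1"]], [["c0", "c1"]]]
parent_ids = [[[0, 0]], [[1, 0]], [[1, 1]]]
sequence_length = [[3, 3]]
print(gather_tree_from_array_sketch(t, parent_ids, sequence_length))
# [[['a0', 'a0']], [['b1', 'b1']], [['c0', 'c1']]]
```

Both final beams here descend from b1 at step 1, so after reordering, step 1 holds b1 in both positions and each column reads as one complete hypothesis.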