tfa.seq2seq.gather_tree_from_array

Calculates the full beams for a TensorArray.

Args
t: A stacked TensorArray of size max_time that contains Tensors of shape [batch_size, beam_width, s] or [batch_size * beam_width, s], where s is the depth shape.
parent_ids: The parent ids, of shape [max_time, batch_size, beam_width].
sequence_length: The sequence lengths, of shape [batch_size, beam_width].

Returns
A Tensor which is a stacked TensorArray of the same size and type as t, in which the beams of each Tensor are reordered according to parent_ids.
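
The reordering described above can be illustrated with a small pure-Python sketch. This is not the library's implementation. The function name gather_beams is hypothetical, nested lists stand in for tensors, and the depth dimension s is dropped so that each entry is a single value. It assumes that backtracking starts at the longest sequence in each batch entry and that steps beyond a beam's own sequence length are left unchanged.

```python
def gather_beams(t, parent_ids, sequence_length):
    """Reorder t along the beam axis by following parent_ids backwards.

    t, parent_ids: nested lists indexed as [time][batch][beam].
    sequence_length: nested list indexed as [batch][beam].
    (Hypothetical helper for illustration; not part of tfa.)
    """
    max_time = len(parent_ids)
    batch_size = len(parent_ids[0])
    beam_width = len(parent_ids[0][0])

    # Start from a copy of t: out-of-range steps keep their original beam.
    out = [[list(t[step][b]) for b in range(batch_size)]
           for step in range(max_time)]

    for b in range(batch_size):
        # Assumption: backtracking begins at the longest beam of this batch entry.
        max_len = min(max_time, max(sequence_length[b]))
        for k in range(beam_width):
            beam = k
            for step in reversed(range(max_len)):
                if step < sequence_length[b][k]:
                    out[step][b][k] = t[step][b][beam]
                # Move to the beam this one was extended from.
                beam = parent_ids[step][b][beam]
    return out


# One batch entry, two beams, three time steps.
t = [[["a0", "b0"]], [["a1", "b1"]], [["a2", "b2"]]]
parent_ids = [[[0, 0]], [[0, 0]], [[1, 0]]]
sequence_length = [[3, 3]]

full_beams = gather_beams(t, parent_ids, sequence_length)
# Final beam 0 was extended from beam 1 at the last step, so it reads a0, b1, a2.
```

Reading full_beams along the time axis for a fixed beam index then gives the complete history of that final beam, which is what "full beams" refers to here.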