tfa.seq2seq.gather_tree_from_array

Calculates the full beams for TensorArrays.

tfa.seq2seq.gather_tree_from_array(
    t,
    parent_ids,
    sequence_length
)

Args:

  • t: A stacked TensorArray of size max_time that contains Tensors of shape [batch_size, beam_width, s] or [batch_size * beam_width, s] where s is the depth shape.
  • parent_ids: The parent ids of shape [max_time, batch_size, beam_width].
  • sequence_length: The sequence length of shape [batch_size, beam_width].

Returns:

A Tensor holding the stacked TensorArray, with the same size and type as t, in which the beams at each time step are reordered according to parent_ids.
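
The reordering can be pictured as backtracking through parent_ids from each beam's last step. The following pure-Python sketch only illustrates that idea on nested lists; it is not the library's implementation. The helper name gather_beams is hypothetical, the entries stand in for depth-s slices, and leaving steps beyond a beam's sequence_length unchanged is an assumption of this sketch rather than confirmed library behavior.

```python
def gather_beams(t, parent_ids, sequence_length):
    """Reorder t[time][batch][beam] so each beam slot holds its full ancestry.

    Hypothetical helper for illustration; operates on nested Python lists.
    """
    max_time = len(t)
    batch_size = len(sequence_length)
    beam_width = len(sequence_length[0])
    # Start from a copy so steps past a beam's length stay as they were
    # (an assumption of this sketch).
    out = [[list(row) for row in step] for step in t]
    for b in range(batch_size):
        for k in range(beam_width):
            last = min(sequence_length[b][k], max_time) - 1
            cur = k
            for step in range(last, -1, -1):
                out[step][b][k] = t[step][b][cur]
                # parent_ids[step] names the beam slot this entry came from
                # at step - 1, so follow it before moving one step back.
                cur = parent_ids[step][b][cur]
    return out


# One batch entry, two beams, three time steps.
t = [[["a0", "b0"]], [["a1", "b1"]], [["a2", "b2"]]]
parent_ids = [[[0, 0]], [[0, 0]], [[1, 0]]]
sequence_length = [[3, 3]]
result = gather_beams(t, parent_ids, sequence_length)
```

In this example the final beam 0 descends from slot 1 at step 1 and slot 0 at step 0, so its reconstructed history is a0, b1, a2, while beam 1 becomes a0, a1, b2.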