Calculates the full beams for a TensorArray.
tfa.seq2seq.gather_tree_from_array(
    t: tfa.types.TensorLike,
    parent_ids: tfa.types.TensorLike,
    sequence_length: tfa.types.TensorLike
) -> tf.Tensor
| Args | |
| --- | --- |
| t | A stacked TensorArray of size max_time that contains Tensors of shape [batch_size, beam_width, s] or [batch_size * beam_width, s], where s is the depth shape. |
| parent_ids | The parent ids, of shape [max_time, batch_size, beam_width]. |
| sequence_length | The sequence length, of shape [batch_size, beam_width]. |
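The two accepted layouts of t differ only in whether the batch and beam axes are merged. The pure-Python sketch below shows how a [max_time, batch_size * beam_width, s] layout regroups into [max_time, batch_size, beam_width, s]. It assumes nested lists in place of tensors, row-major order (the convention tf.reshape uses), and treats each list element as one depth-s slice. The helper name split_batch_beam is hypothetical and is not part of TensorFlow Addons.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def split_batch_beam(
    t_flat: Sequence[Sequence[T]], beam_width: int
) -> List[List[List[T]]]:
    """Regroup [max_time][batch_size * beam_width] into
    [max_time][batch_size][beam_width] (hypothetical helper).

    Row-major order is assumed, so flat row i maps to batch
    i // beam_width and beam i % beam_width. Each element stands in
    for one depth-s slice and is copied unchanged.
    """
    out: List[List[List[T]]] = []
    for step in t_flat:
        if len(step) % beam_width != 0:
            raise ValueError("row count is not a multiple of beam_width")
        batch_size = len(step) // beam_width
        out.append(
            [list(step[b * beam_width:(b + 1) * beam_width]) for b in range(batch_size)]
        )
    return out


# Two time steps, batch_size = 2, beam_width = 2, one scalar per depth slice.
flat = [["a0", "a1", "b0", "b1"], ["c0", "c1", "d0", "d1"]]
nested = split_batch_beam(flat, beam_width=2)
print(nested[1][1])  # ['d0', 'd1']
```

With this grouping, flat row 3 at time step 1 ("d1") becomes batch 1, beam 1.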
| Returns |
| --- |
| A Tensor which is a stacked TensorArray of the same size and type as t, in which the beams of each Tensor are sorted according to parent_ids. |
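To make the reordering concrete, here is a pure-Python sketch of what "sorted according to parent_ids" means. It assumes t is given as nested lists indexed [time][batch][beam], with each entry standing in for a depth-s slice. It applies the parent back-tracking idea that tfa.seq2seq.gather_tree applies to token ids: for each final beam k, start at the last step, record the current beam index, then move to that beam's parent. In this sketch, steps at or beyond sequence_length[b][k] simply keep beam k. The function name is hypothetical and this is not the TensorFlow Addons implementation.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def gather_tree_from_array_sketch(
    t: Sequence[Sequence[Sequence[T]]],
    parent_ids: Sequence[Sequence[Sequence[int]]],
    sequence_length: Sequence[Sequence[int]],
) -> List[List[List[T]]]:
    """Reorder t along the beam axis by back-tracking parent_ids
    (hypothetical illustration, not the library code)."""
    max_time = len(parent_ids)
    batch_size = len(parent_ids[0]) if max_time else 0
    beam_width = len(parent_ids[0][0]) if batch_size else 0

    # sorted_ids[step][b][k] = beam at `step` that is an ancestor of final beam k.
    # Out-of-range steps keep this identity default.
    sorted_ids = [
        [list(range(beam_width)) for _ in range(batch_size)] for _ in range(max_time)
    ]
    for b in range(batch_size):
        # Back-track from the longest beam in this batch entry (clamped to max_time).
        last_step = min(max(sequence_length[b], default=0), max_time) - 1
        for k in range(beam_width):
            cur = k
            for step in range(last_step, -1, -1):
                if step < sequence_length[b][k]:
                    sorted_ids[step][b][k] = cur
                cur = parent_ids[step][b][cur]  # follow the parent pointer

    # Gather each entry of t from the ancestor beam found above.
    return [
        [[t[step][b][sorted_ids[step][b][k]] for k in range(beam_width)]
         for b in range(batch_size)]
        for step in range(max_time)
    ]


# max_time = 3, batch_size = 1, beam_width = 2; values label (step, beam).
t = [[[10, 11]], [[20, 21]], [[30, 31]]]
parent_ids = [[[0, 0]], [[1, 0]], [[1, 1]]]
result = gather_tree_from_array_sketch(t, parent_ids, [[3, 3]])
print(result)  # [[[10, 10]], [[21, 21]], [[30, 31]]]
```

Both final beams descend from beam 1 at step 1, which in turn descends from beam 0 at step 0, so after reordering each beam reads as a coherent path: beam 0 is 10, 21, 30 and beam 1 is 10, 21, 31.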