Module: tfa.seq2seq.beam_search_decoder

A decoder that performs beam search.


Classes

class BeamSearchDecoder: BeamSearch sampling decoder.

class BeamSearchDecoderMixin: BeamSearchDecoderMixin contains the common methods for BeamSearchDecoder.

class BeamSearchDecoderOutput: Outputs of a BeamSearchDecoder step.

class BeamSearchDecoderState: State of a BeamSearchDecoder.

class FinalBeamSearchDecoderOutput: Final outputs returned by the beam search after all decoding is finished.


Functions

attention_probs_from_attn_state(...): Calculates the average attention probabilities.

gather_tree(...): Calculates the full beams from the per-step ids and parent beam ids.

gather_tree_from_array(...): Calculates the full beams for TensorArrays.

get_attention_probs(...): Get attention probabilities from the cell state.

tile_batch(...): Tile the batch dimension of a (possibly nested structure of) tensor(s)
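
To make the one-line summaries of gather_tree and tile_batch more concrete, here is a minimal pure-Python sketch of the two ideas. It is an illustration under simplifying assumptions, not the library's implementation: the names gather_tree_sketch and tile_batch_sketch are hypothetical, the inputs are plain nested lists for a single batch entry, and end-token and sequence-length handling is left out.

```python
def gather_tree_sketch(step_ids, parent_ids):
    """Rebuild full beams by following parent pointers backwards in time.

    Hypothetical helper for illustration only (single batch entry).
    step_ids[t][k]   : token chosen at time t in beam slot k
    parent_ids[t][k] : beam slot at time t-1 that slot k was extended from
    Returns a [max_time][beam_width] list whose column k holds the whole
    sequence that ends in slot k at the last step.
    """
    max_time = len(step_ids)
    beam_width = len(step_ids[0])
    beams = [[None] * beam_width for _ in range(max_time)]
    for k in range(beam_width):
        slot = k
        # Walk from the last step to the first, hopping to the parent slot.
        for t in range(max_time - 1, -1, -1):
            beams[t][k] = step_ids[t][slot]
            slot = parent_ids[t][slot]
    return beams


def tile_batch_sketch(batch, multiplier):
    """Repeat every batch entry `multiplier` times, keeping copies adjacent.

    Hypothetical helper for illustration only; works on a flat list.
    """
    return [entry for entry in batch for _ in range(multiplier)]


# Three decoding steps, beam width 2.
step_ids = [[1, 2], [3, 4], [5, 6]]
parent_ids = [[0, 0], [0, 0], [1, 0]]

full_beams = gather_tree_sketch(step_ids, parent_ids)
# full_beams == [[1, 1], [4, 3], [5, 6]]
# Column 0 reads 1, 4, 5 and column 1 reads 1, 3, 6.

tiled = tile_batch_sketch(["a", "b"], 3)
# tiled == ["a", "a", "a", "b", "b", "b"]
```

In beam search, tiling of this kind is typically applied to encoder outputs and initial states so that each batch entry has one copy per beam, and the backtracking step is applied once decoding has finished to turn per-step ids and parent ids into complete sequences.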

Type Aliases

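FloatTensorLike: Type alias for float-valued inputs: a tf.Tensor or a Python or NumPy float.

Number: Type alias for scalar numeric values: Python or NumPy integers and floats.

TensorLike: Type alias for values that can be converted to a tensor, such as a tf.Tensor, a NumPy array, a number, or a nested list or tuple of numbers.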