
Module: tfa.seq2seq.beam_search_decoder


A decoder that performs beam search.

Classes

class BeamSearchDecoder: BeamSearch sampling decoder.

class BeamSearchDecoderMixin: BeamSearchDecoderMixin contains the common methods for

class BeamSearchDecoderOutput: BeamSearchDecoderOutput(scores, predicted_ids, parent_ids)

class BeamSearchDecoderState: BeamSearchDecoderState(cell_state, log_probs, finished, lengths, accumulated_attention_probs)

class FinalBeamSearchDecoderOutput: Final outputs returned by the beam search after all decoding is
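
The `predicted_ids` and `parent_ids` fields of `BeamSearchDecoderOutput` only record, per step, which token each beam emitted and which beam of the previous step it extended. Recovering full sequences requires walking those parent pointers backwards. The following is a minimal pure-Python sketch of that idea for a single batch entry; the function name `backtrack_beams` and the list-of-lists layout (`[time][beam]`) are illustrative assumptions, not part of the library, and the sketch ignores end tokens and sequence lengths, which a real implementation would have to handle.

```python
def backtrack_beams(predicted_ids, parent_ids):
    """Rebuild full beams from per-step token ids and parent pointers.

    Hypothetical helper for illustration only.
    predicted_ids[t][b]: token emitted by beam b at step t.
    parent_ids[t][b]:    index of the beam at step t-1 that beam b extends.
    Returns one token list per final beam.
    """
    max_time = len(predicted_ids)
    beam_width = len(predicted_ids[0])
    beams = []
    for b in range(beam_width):
        tokens = []
        beam = b
        # Walk from the last step back to the first, following parents.
        for t in range(max_time - 1, -1, -1):
            tokens.append(predicted_ids[t][beam])
            beam = parent_ids[t][beam]
        beams.append(tokens[::-1])  # restore chronological order
    return beams


# Three decoding steps, beam width 2 (made-up values).
predicted_ids = [[1, 2], [3, 4], [5, 6]]
parent_ids = [[0, 0], [0, 0], [1, 0]]
beams = backtrack_beams(predicted_ids, parent_ids)
print(beams)  # [[1, 4, 5], [1, 3, 6]]
```

In this example the final beam 0 descends from beam 1 at the middle step, so its middle token is 4 rather than 3.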

Functions

attention_probs_from_attn_state(...): Calculates the average attention probabilities.

gather_tree_from_array(...): Calculates the full beams for TensorArrays.

get_attention_probs(...): Get attention probabilities from the cell state.

tile_batch(...): Tile the batch dimension of a (possibly nested structure of) tensor(s)
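
To illustrate what tiling the batch dimension means for beam search, here is a small NumPy sketch. It assumes that each batch entry is repeated `multiplier` times consecutively, so a tensor of shape `[batch_size, ...]` becomes `[batch_size * multiplier, ...]`. The name `tile_batch_sketch` is hypothetical and the code is not the library's implementation; it only mimics the behavior on NumPy arrays nested in dicts, lists, or tuples.

```python
import numpy as np


def tile_batch_sketch(t, multiplier):
    """Repeat each batch entry `multiplier` times along axis 0.

    Hypothetical NumPy stand-in for illustration; handles arrays nested
    inside dicts, lists, and plain tuples.
    """
    if isinstance(t, dict):
        return {k: tile_batch_sketch(v, multiplier) for k, v in t.items()}
    if isinstance(t, (list, tuple)):
        return type(t)(tile_batch_sketch(v, multiplier) for v in t)
    # Leaf: entries come out as t[0], t[0], ..., t[1], t[1], ...
    return np.repeat(np.asarray(t), multiplier, axis=0)


# A batch of 2 entries tiled for a beam width of 3.
x = np.array([[1, 2], [3, 4]])
tiled = tile_batch_sketch(x, 3)
print(tiled.shape)  # (6, 2)

# Nested structures are tiled leaf by leaf.
state = {"h": np.zeros((2, 4)), "c": (np.ones((2, 4)),)}
tiled_state = tile_batch_sketch(state, 3)
print(tiled_state["h"].shape)  # (6, 4)
```

With a decoder, the same tiling would typically be applied to encoder outputs and the initial cell state so that every beam of a batch entry sees that entry's inputs.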