Module: tfa.seq2seq.beam_search_decoder

A decoder that performs beam search.
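As a rough illustration of what beam search does, here is a minimal pure-Python sketch over a hypothetical toy next-token model. `beam_search` and `toy_model` are illustrative names, not part of tfa. Unlike the library's BeamSearchDecoder, this sketch decodes one example at a time, drops a hypothesis from the beam once it emits the end token, and applies no score normalization.

```python
import math
from typing import Callable, Dict, List, Tuple


def beam_search(
    step_log_probs: Callable[[Tuple[int, ...]], Dict[int, float]],
    start_token: int,
    end_token: int,
    beam_width: int,
    max_steps: int,
) -> List[Tuple[Tuple[int, ...], float]]:
    """Return up to beam_width (sequence, total log-prob) pairs, best first."""
    beams = [((start_token,), 0.0)]  # (tokens so far, cumulative log-prob)
    finished = []
    for _ in range(max_steps):
        # Extend every live hypothesis by every possible next token.
        candidates = []
        for tokens, score in beams:
            for token, logp in step_log_probs(tokens).items():
                candidates.append((tokens + (token,), score + logp))
        # Keep only the beam_width highest-scoring extensions.
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for tokens, score in candidates[:beam_width]:
            if tokens[-1] == end_token:
                finished.append((tokens, score))
            else:
                beams.append((tokens, score))
        if not beams:
            break
    finished.extend(beams)  # hypotheses still open when max_steps runs out
    finished.sort(key=lambda c: c[1], reverse=True)
    return finished[:beam_width]


def toy_model(prefix: Tuple[int, ...]) -> Dict[int, float]:
    """Hypothetical next-token log-probabilities; token 0 is the end token."""
    if prefix == (3,):
        return {1: math.log(0.6), 2: math.log(0.4)}
    if prefix[-1] == 1:
        return {0: math.log(0.5), 2: math.log(0.5)}
    return {0: math.log(1.0)}


results = beam_search(toy_model, start_token=3, end_token=0, beam_width=2, max_steps=3)
```

With this toy model a greedy decoder would commit to token 1 at the first step and finish with probability 0.3, whereas a beam of width 2 keeps token 2 alive and finds the sequence `(3, 2, 0)` with probability 0.4.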

Classes

class BeamSearchDecoder: BeamSearch sampling decoder.

class BeamSearchDecoderMixin: BeamSearchDecoderMixin contains the common methods for BeamSearchDecoder.

class BeamSearchDecoderOutput: BeamSearchDecoderOutput(scores, predicted_ids, parent_ids)

class BeamSearchDecoderState: BeamSearchDecoderState(cell_state, log_probs, finished, lengths, accumulated_attention_probs)

class FinalBeamSearchDecoderOutput: Final outputs returned by the beam search after all decoding is finished.
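At each step, BeamSearchDecoderOutput records the token each beam chose (predicted_ids) and which beam from the previous step it extended (parent_ids), so full sequences have to be recovered by following the parent pointers backwards from the last step. The library performs this backtracking when it finalizes decoding. The pure-Python sketch below shows the idea for a single batch entry; `gather_tree_sketch` is a hypothetical name, and unlike the library version it ignores end tokens and per-beam sequence lengths.

```python
from typing import List


def gather_tree_sketch(step_ids: List[List[int]], parent_ids: List[List[int]]) -> List[List[int]]:
    """Rebuild the full token sequence of every final beam for one example.

    step_ids[t][b]:   token emitted at step t by beam b.
    parent_ids[t][b]: index of the beam at step t - 1 that beam b extended.
    """
    max_time = len(step_ids)
    beam_width = len(step_ids[0])
    beams = []
    for b in range(beam_width):
        seq = [0] * max_time
        parent = b
        # Walk backwards: read this beam's token, then jump to its parent.
        for t in range(max_time - 1, -1, -1):
            seq[t] = step_ids[t][parent]
            parent = parent_ids[t][parent]
        beams.append(seq)
    return beams


step_ids = [[2, 5], [7, 8], [9, 4]]
parent_ids = [[0, 0], [1, 0], [1, 1]]
full_beams = gather_tree_sketch(step_ids, parent_ids)
```

Here both surviving beams descend from the same prefix `[2, 8]`, so `full_beams` is `[[2, 8, 9], [2, 8, 4]]`; the tokens 5 and 7 belonged to hypotheses that were pruned.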

Functions

attention_probs_from_attn_state(...): Calculates the average attention probabilities.

gather_tree_from_array(...): Calculates the full beams for TensorArrays.

get_attention_probs(...): Get attention probabilities from the cell state.

tile_batch(...): Tile the batch dimension of a (possibly nested structure of) tensor(s).
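The tiling performed by tile_batch can be pictured with the pure-Python sketch below. Each batch entry is repeated `multiplier` times with its copies kept adjacent, so a batch of size `batch_size` becomes `batch_size * multiplier` rows. `tile_batch_sketch` is a hypothetical stand-in that works on nested Python values rather than tensors. It assumes that tuples and dicts act as nesting containers and that every list is a leaf whose first axis is the batch dimension.

```python
from typing import Any


def tile_batch_sketch(structure: Any, multiplier: int) -> Any:
    """Repeat every batch entry `multiplier` times, keeping copies adjacent."""
    if isinstance(structure, dict):
        return {key: tile_batch_sketch(value, multiplier) for key, value in structure.items()}
    if isinstance(structure, tuple):
        return tuple(tile_batch_sketch(value, multiplier) for value in structure)
    # Leaf: [b0, b1, ...] -> [b0] * multiplier + [b1] * multiplier + ...
    # Repeated rows are shared references, which is fine for read-only use.
    return [row for row in structure for _ in range(multiplier)]


encoder_outputs = [[1, 1], [2, 2]]  # batch_size = 2
source_lengths = [5, 7]
tiled = tile_batch_sketch((encoder_outputs, {"lengths": source_lengths}), multiplier=3)
```

In a beam search setup, the multiplier is typically the beam width. Tiling the encoder outputs, encoder state, and source lengths this way gives one copy per beam, and each source example's copies sit next to each other.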