tfa.seq2seq.BeamSearchDecoderState


State of a BeamSearchDecoder.

Contains:

  • cell_state: The cell state returned at the previous time step.
  • log_probs: The accumulated log probabilities of each beam. A float32 Tensor of shape [batch_size, beam_width].
  • finished: The finished status of each beam. A bool Tensor of shape [batch_size, beam_width].
  • lengths: The accumulated length of each beam. An int64 Tensor of shape [batch_size, beam_width].
  • accumulated_attention_probs: Accumulation of the attention probabilities, used to compute the coverage penalty.
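To illustrate how log_probs, finished, and lengths accumulate from one decoding step to the next, here is a minimal pure-Python sketch. BeamState and advance are hypothetical names, plain lists stand in for Tensors, and the sketch covers a single batch entry (so each field has shape [beam_width] rather than [batch_size, beam_width]). It assumes that a beam which has already finished keeps its score and length frozen. It deliberately omits top-k selection, beam reordering, and the coverage penalty, so it is not the library's update rule.

```python
from collections import namedtuple

# Hypothetical stand-in mirroring the five fields of
# tfa.seq2seq.BeamSearchDecoderState; plain lists replace Tensors.
BeamState = namedtuple(
    "BeamState",
    ["cell_state", "log_probs", "finished", "lengths", "accumulated_attention_probs"],
)


def advance(state, step_log_probs, step_finished):
    """Accumulate one decoding step for a single batch entry (hypothetical helper).

    step_log_probs: log probability of the token chosen for each beam.
    step_finished: whether each beam emitted the end token at this step.
    Beams that were already finished keep their score and length unchanged.
    """
    new_log_probs, new_lengths, new_finished = [], [], []
    for lp, length, done, step_lp, step_done in zip(
        state.log_probs, state.lengths, state.finished, step_log_probs, step_finished
    ):
        if done:
            new_log_probs.append(lp)
            new_lengths.append(length)
            new_finished.append(True)
        else:
            new_log_probs.append(lp + step_lp)  # accumulate the log probability
            new_lengths.append(length + 1)      # count the token just emitted
            new_finished.append(step_done)
    # namedtuple instances are immutable; _replace returns an updated copy.
    return state._replace(
        log_probs=new_log_probs, lengths=new_lengths, finished=new_finished
    )


state = BeamState(None, [0.0, -1.0], [False, False], [0, 0], None)
state = advance(state, step_log_probs=[-0.5, -0.25], step_finished=[False, True])
state = advance(state, step_log_probs=[-0.5, -2.0], step_finished=[True, False])
print(state.log_probs, state.lengths, state.finished)  # [-1.0, -1.25] [2, 1] [True, True]
```

Beam 1 finishes at the first step, so its score and length stop changing even though the second step supplies a new log probability for it.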

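Attributes (namedtuple fields, in order):

  • cell_state: Alias for field number 0.
  • log_probs: Alias for field number 1.
  • finished: Alias for field number 2.
  • lengths: Alias for field number 3.
  • accumulated_attention_probs: Alias for field number 4.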
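Because the state is a namedtuple, each attribute can be read either by name or by its field position. The sketch below builds a hypothetical starting state with the same field order, using nested lists of shape [batch_size, beam_width] in place of Tensors. BeamState and initial_beam_state are illustrative names, not part of the library. The sketch assumes that only beam 0 starts live (log probability 0.0, the others -inf) so that the first expansion does not produce beam_width identical hypotheses; treat that initialization as an assumption rather than documented library behavior.

```python
import math
from collections import namedtuple

# Hypothetical stand-in with the same field order as
# tfa.seq2seq.BeamSearchDecoderState; nested lists replace Tensors.
BeamState = namedtuple(
    "BeamState",
    ["cell_state", "log_probs", "finished", "lengths", "accumulated_attention_probs"],
)


def initial_beam_state(batch_size, beam_width, cell_state=None):
    """Hypothetical helper: a starting state where only beam 0 of each entry is live."""
    return BeamState(
        cell_state=cell_state,
        # [batch_size, beam_width]; -inf effectively rules out beams 1.. at the first step
        log_probs=[[0.0] + [-math.inf] * (beam_width - 1) for _ in range(batch_size)],
        finished=[[False] * beam_width for _ in range(batch_size)],
        lengths=[[0] * beam_width for _ in range(batch_size)],
        accumulated_attention_probs=None,  # only needed for the coverage penalty
    )


state = initial_beam_state(batch_size=2, beam_width=3)
print(state.log_probs)  # [[0.0, -inf, -inf], [0.0, -inf, -inf]]
print(state[1] is state.log_probs)  # True: field number 1 is log_probs
print(state._fields[4])  # accumulated_attention_probs
```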