
tfa.seq2seq.beam_search_decoder.get_attention_probs


Get attention probabilities from the cell state.

tfa.seq2seq.beam_search_decoder.get_attention_probs(
    next_cell_state,
    coverage_penalty_weight
)

Args:

  • next_cell_state: The next state from the cell, e.g. an instance of AttentionWrapperState if the cell is attentional.
  • coverage_penalty_weight: Float weight used to penalize the coverage of the source sentence. Set to 0.0 to disable the coverage penalty.

Returns:

The attention probabilities, with shape [batch_size, beam_width, max_time], if the coverage penalty is enabled; otherwise None.
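The result is indexed by source position (max_time) because a coverage penalty sums the attention each source token receives over all decoding steps. A common form is the GNMT-style term weight * sum_i log(min(sum_t p[t][i], 1.0)). The plain-Python sketch below computes that term for a single (batch, beam) hypothesis. It is an assumption that tfa applies exactly this formula; the function name coverage_penalty is hypothetical, and the 1e-12 floor against log(0) is this sketch's own choice.

```python
import math
from typing import List


def coverage_penalty(step_probs: List[List[float]], weight: float) -> float:
    """Hypothetical GNMT-style coverage penalty for one beam hypothesis.

    step_probs[t][i] is the attention probability on source position i at
    decoding step t, i.e. the [max_time] slice for one (batch, beam) entry
    of the per-step attention probabilities.
    """
    max_time = len(step_probs[0])
    # Total attention mass each source position received across all steps.
    coverage = [sum(step[i] for step in step_probs) for i in range(max_time)]
    # Capping at 1.0 means fully covered positions contribute log(1.0) = 0,
    # while under-attended positions contribute a negative term. The 1e-12
    # floor (an assumption of this sketch) avoids log(0).
    return weight * sum(math.log(max(min(c, 1.0), 1e-12)) for c in coverage)


# Two decoding steps over a 3-token source sentence.
steps = [[0.7, 0.2, 0.1], [0.1, 0.6, 0.3]]
print(coverage_penalty(steps, weight=0.2))  # ≈ -0.2725
```

With a weight of 0.0 the term is always zero, which matches the function returning None (nothing to compute) when the penalty is disabled.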

Raises:

  • ValueError: If the coverage penalty is enabled but no cell is attentional.
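The argument, return, and error behavior documented for get_attention_probs can be sketched in plain Python to make the control flow concrete. This is not the tfa implementation: nested lists stand in for tensors, AttentionState is a hypothetical stand-in for AttentionWrapperState that models only its alignments field, get_attention_probs_sketch is a hypothetical name, and averaging across several attention mechanisms or several attentional cells is an assumption about how multiple sources are combined.

```python
from typing import List, NamedTuple, Optional, Sequence, Tuple, Union

# A [batch_size, beam_width, max_time] tensor modeled as nested lists.
Probs = List[List[List[float]]]


class AttentionState(NamedTuple):
    """Hypothetical stand-in for AttentionWrapperState (only `alignments`)."""
    alignments: Union[Probs, Tuple[Probs, ...]]


def _mean(tensors: Sequence[Probs]) -> Probs:
    """Element-wise mean of equally shaped [batch, beam, time] nested lists."""
    n = len(tensors)
    return [
        [[sum(vals) / n for vals in zip(*rows)] for rows in zip(*beams)]
        for beams in zip(*tensors)
    ]


def _probs_from_state(state: AttentionState) -> Probs:
    # Assumption: several attention mechanisms in one wrapper are averaged.
    if isinstance(state.alignments, tuple):
        return _mean(state.alignments)
    return state.alignments


def get_attention_probs_sketch(next_cell_state, coverage_penalty_weight: float) -> Optional[Probs]:
    if coverage_penalty_weight == 0.0:
        return None  # coverage penalty disabled: nothing to compute
    if isinstance(next_cell_state, AttentionState):
        per_layer = [_probs_from_state(next_cell_state)]
    elif isinstance(next_cell_state, tuple):
        # Tuple of cell states (e.g. stacked cells): keep attentional ones only.
        per_layer = [_probs_from_state(s) for s in next_cell_state
                     if isinstance(s, AttentionState)]
    else:
        per_layer = []
    if not per_layer:
        raise ValueError("coverage_penalty_weight must be 0.0 if no cell is attentional.")
    # Assumption: probabilities from several attentional cells are averaged.
    return per_layer[0] if len(per_layer) == 1 else _mean(per_layer)


p1 = [[[0.7, 0.2, 0.1]]]  # batch_size=1, beam_width=1, max_time=3
p2 = [[[0.5, 0.3, 0.2]]]
print(get_attention_probs_sketch(AttentionState(p1), 0.0))  # None
print(get_attention_probs_sketch((AttentionState(p1), ("h", "c"), AttentionState(p2)), 1.0))
```

In this sketch, non-attentional sub-states such as an LSTM's ("h", "c") pair are skipped, and a state with no attentional part at all raises ValueError unless the weight is 0.0.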