tfa.seq2seq.beam_search_decoder.BeamSearchDecoderMixin


Class BeamSearchDecoderMixin

BeamSearchDecoderMixin contains the common methods for BeamSearchDecoder.

It is expected to be used as a base class for a concrete BeamSearchDecoder. Because it is a mixin, it is meant to be combined with another base class rather than used on its own.
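As a rough illustration of that pattern, the sketch below combines a mixin with another base class using Python's cooperative super().__init__ chaining, which is also how extra keyword arguments can reach the parent class. All class names here are hypothetical stand-ins, not TensorFlow Addons classes.

```python
# Hypothetical classes for illustration only; they are not part of TensorFlow Addons.


class BaseDecoderSketch:
    """Stand-in for a generic decoder base class with its own keyword arguments."""

    def __init__(self, maximum_iterations=None, **kwargs):
        self.maximum_iterations = maximum_iterations
        super().__init__(**kwargs)  # continue the cooperative chain (ends at object)


class BeamSearchMixinSketch:
    """Stand-in for a mixin holding beam-search-specific configuration."""

    def __init__(self, beam_width, length_penalty_weight=0.0, **kwargs):
        self.beam_width = beam_width
        self.length_penalty_weight = length_penalty_weight
        # Forward the remaining keyword arguments to the next class in the MRO.
        super().__init__(**kwargs)


class BeamSearchDecoderSketch(BeamSearchMixinSketch, BaseDecoderSketch):
    """Concrete decoder: the mixin is listed first, so its __init__ runs first."""


decoder = BeamSearchDecoderSketch(beam_width=4, maximum_iterations=20)
print([cls.__name__ for cls in BeamSearchDecoderSketch.__mro__])
print(decoder.beam_width, decoder.maximum_iterations)
```

Listing the mixin first places its __init__ earlier in the method resolution order, so it consumes its own arguments and passes the rest along to the other base class.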

__init__


__init__(
    cell,
    beam_width,
    output_layer=None,
    length_penalty_weight=0.0,
    coverage_penalty_weight=0.0,
    reorder_tensor_arrays=True,
    **kwargs
)

Initialize the BeamSearchDecoderMixin.

Args:

  • cell: An RNNCell instance.
  • beam_width: Python integer, the number of beams.
  • output_layer: (Optional) An instance of tf.keras.layers.Layer, e.g., tf.keras.layers.Dense. Layer applied to the RNN output before the result is stored or sampled.
  • length_penalty_weight: Float weight to penalize length. Disabled with 0.0.
  • coverage_penalty_weight: Float weight to penalize the coverage of source sentence. Disabled with 0.0.
  • reorder_tensor_arrays: If True, TensorArrays' elements within the cell state will be reordered according to the beam search path. If the TensorArray can be reordered, the stacked form will be returned. Otherwise, the TensorArray will be returned as is. Set this flag to False if the cell state contains TensorArrays that are not amenable to reordering.
  • **kwargs: Additional keyword arguments passed to the parent class.
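To make the two penalty weights concrete, the following sketch scores a single hypothesis in plain Python. It assumes a GNMT-style formulation: the log-probability is divided by a length penalty ((5 + length) / 6) ** weight, and a weighted sum of log(min(accumulated attention, 1.0)) over source positions is added. That formula and the helper names are assumptions for illustration, not a statement of exactly how the library computes scores. It also assumes every accumulated attention value is positive.

```python
import math

# Hypothetical helpers; the GNMT-style formulas are an assumption, not the library's code.


def length_penalty(length, weight):
    """((5 + length) / 6) ** weight; returns 1.0 when the penalty is disabled."""
    if weight == 0.0:
        return 1.0
    return ((5.0 + length) / 6.0) ** weight


def coverage_penalty(accumulated_attention, weight):
    """weight * sum over source positions of log(min(attention, 1.0)).

    Assumes every accumulated attention value is strictly positive.
    """
    if weight == 0.0:
        return 0.0
    return weight * sum(math.log(min(a, 1.0)) for a in accumulated_attention)


def beam_score(log_prob, length, accumulated_attention=(),
               length_penalty_weight=0.0, coverage_penalty_weight=0.0):
    """Normalize the log-probability by length, then add the coverage term."""
    score = log_prob / length_penalty(length, length_penalty_weight)
    return score + coverage_penalty(accumulated_attention, coverage_penalty_weight)


print(beam_score(-4.0, 4), beam_score(-6.0, 10))
print(beam_score(-4.0, 4, length_penalty_weight=1.0),
      beam_score(-6.0, 10, length_penalty_weight=1.0))
```

With both penalties disabled, a 4-token hypothesis at log-probability -4.0 outranks a 10-token one at -6.0. With length_penalty_weight=1.0 the scores become roughly -2.67 and -2.40, so the longer hypothesis is preferred.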

Raises:

  • TypeError: if cell is not an instance of RNNCell, or output_layer is not an instance of tf.keras.layers.Layer.

Properties

batch_size

output_size

tracks_own_finished

The BeamSearchDecoder shuffles its beams and their finished state.

For this reason, it conflicts with the dynamic_decode function's tracking of finished states. Setting this property to True avoids early stopping of decoding caused by dynamic_decode mismanaging the finished state.

Returns:

True.
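The following plain-Python sketch, a simplified assumption rather than the library's code, shows why position-based finished tracking breaks once beams are reshuffled. Each new beam should inherit the finished flag of the beam it extends (its parent), whereas a naive tracker keeps flags tied to beam positions. The helper names are hypothetical.

```python
# Hypothetical helpers illustrating finished-state tracking; not the library's code.


def reorder_finished(finished, parent_ids, word_ids, end_token):
    """Carry each beam's finished flag along with the beam it was expanded from.

    finished[b]   -- whether beam b was finished before this step
    parent_ids[b] -- which previous beam the new beam b extends
    word_ids[b]   -- the token appended to new beam b at this step
    """
    return [finished[p] or w == end_token for p, w in zip(parent_ids, word_ids)]


def naive_finished(finished, word_ids, end_token):
    """Position-based tracking that ignores the beam reshuffle."""
    return [f or w == end_token for f, w in zip(finished, word_ids)]


EOS = 0
finished = [True, True, False]  # beams 0 and 1 already emitted EOS
parent_ids = [2, 2, 2]          # after pruning, every new beam extends beam 2
word_ids = [5, 7, EOS]

print(reorder_finished(finished, parent_ids, word_ids, EOS))
print(naive_finished(finished, word_ids, EOS))
```

Here both previously finished beams are pruned and all surviving beams extend beam 2, so only the one that just emitted EOS is finished. The position-based tracker instead reports every slot as finished, which would end decoding prematurely.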

Methods

finalize


finalize(
    outputs,
    final_state,
    sequence_lengths
)

Finalize and return the predicted_ids.

Args:

  • outputs: An instance of BeamSearchDecoderOutput.
  • final_state: An instance of BeamSearchDecoderState. Passed through to the output.
  • sequence_lengths: An int64 tensor shaped [batch_size, beam_width]. The sequence lengths determined for each beam during decoding. Note: these are ignored; the updated sequence lengths are stored in final_state.lengths.

Returns:

  • outputs: An instance of FinalBeamSearchDecoderOutput where the predicted_ids are the result of calling _gather_tree.
  • final_state: The same input instance of BeamSearchDecoderState.

step


step(
    time,
    inputs,
    state,
    training=None,
    name=None
)

Perform a decoding step.

Args:

  • time: A scalar int32 tensor; the current decoding step.
  • inputs: A (structure of) input tensors.
  • state: A (structure of) state tensors and TensorArrays.
  • training: Python boolean. Indicates whether the layer should behave in training mode or in inference mode. Only relevant when dropout or recurrent_dropout is used.
  • name: Name scope for any created operations.

Returns:

(outputs, next_state, next_inputs, finished).
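The core selection inside a beam-search step can be sketched as follows: add each beam's accumulated score to its next-token log-probabilities, take the top beam_width entries of the flattened [beam, vocab] grid, and recover the parent beam and token from each flat index. beam_step is a hypothetical, simplified helper that omits batching, finished-beam handling and the length and coverage penalties; it is not the step method itself.

```python
import heapq

# Hypothetical simplified helper; it is not the library's step method.


def beam_step(beam_scores, next_log_probs, beam_width):
    """Select the top beam_width (beam, token) extensions.

    beam_scores[b]       -- accumulated log-probability of beam b
    next_log_probs[b][v] -- log p(token v | beam b) for this step
    Returns (next_scores, parent_ids, word_ids), each of length beam_width.
    """
    vocab_size = len(next_log_probs[0])
    # Score every candidate as if the [beam, vocab] grid were flattened row by row.
    flat = [beam_scores[b] + lp
            for b, row in enumerate(next_log_probs)
            for lp in row]
    top = heapq.nlargest(beam_width, range(len(flat)), key=flat.__getitem__)
    next_scores = [flat[i] for i in top]
    parent_ids = [i // vocab_size for i in top]  # which beam was extended
    word_ids = [i % vocab_size for i in top]     # which token was appended
    return next_scores, parent_ids, word_ids


scores, parents, words = beam_step(
    [-1.0, -2.0], [[-0.5, -1.5, -3.0], [-0.1, -0.2, -4.0]], beam_width=3)
print(scores, parents, words)
```

Two of the three surviving candidates extend beam 1, even though beam 1 started with the lower score. The parent_ids produced here are what later steps and the final backtracking follow.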