tfa.seq2seq.Decoder

An RNN Decoder abstract interface object.

Concepts used by this interface:

  • inputs: (structure of) tensors and TensorArrays that are passed as input to the RNN cell composing the decoder at each time step.
  • state: (structure of) tensors and TensorArrays that are passed to the RNN cell instance as its state.
  • finished: boolean tensor indicating whether each sequence in the batch is finished.
  • training: boolean indicating whether the decoder should behave in training mode or in inference mode.
  • outputs: instance of tfa.seq2seq.BasicDecoderOutput. Result of the decoding at each time step.
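As a rough illustration of how these concepts map onto the attributes and methods documented below, here is a plain-Python mirror of the interface. It is a sketch only: the class name SketchDecoder is hypothetical, it uses no TensorFlow, and the finalize signature shown is an assumption rather than something stated on this page.

```python
import abc


class SketchDecoder(abc.ABC):
    """Hypothetical plain-Python mirror of the Decoder contract (no TensorFlow)."""

    @property
    @abc.abstractmethod
    def batch_size(self):
        """The batch size of input values."""

    @property
    @abc.abstractmethod
    def output_size(self):
        """Size (or nested structure of sizes) of each step's output."""

    @property
    @abc.abstractmethod
    def output_dtype(self):
        """Dtype (or nested structure of dtypes) of each step's output."""

    @abc.abstractmethod
    def initialize(self, name=None):
        """Return (finished, initial_inputs, initial_state)."""

    @abc.abstractmethod
    def step(self, time, inputs, state, training=None, name=None):
        """Return (outputs, next_state, next_inputs, finished)."""

    def finalize(self, outputs, final_state, sequence_lengths):
        # Optional hook (signature assumed); this sketch does not implement it.
        raise NotImplementedError

    @property
    def tracks_own_finished(self):
        # Default for the common case: the decoding loop ORs the per-step
        # finished flags itself, so the decoder does not track them.
        return False
```

A concrete decoder only has to supply the abstract members; Python refuses to instantiate a subclass that leaves any of them out.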

Attributes
batch_size The batch size of input values.
output_dtype A (possibly nested tuple of...) dtype[s].
output_size A (possibly nested tuple of...) integer[s] or TensorShape object[s].
tracks_own_finished Describes whether the Decoder keeps track of finished states.

Most decoders emit a true/false finished value independently at each time step. In this case, the tfa.seq2seq.dynamic_decode function keeps track of which batch entries are already finished, and performs a logical OR to add newly finished entries to the finished set.

Some decoders, however, shuffle batch entries or beams between time steps. Because tfa.seq2seq.dynamic_decode does not track this reshuffling, it would attach the accumulated finished flags to the wrong entries. In this case, it is up to the decoder to declare that it keeps track of its own finished state by setting this property to True.
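To make the failure mode concrete, the following sketch uses plain Python lists in place of boolean tensors. The helper names merge_finished and reorder, and the parent_ids permutation, are hypothetical illustrations of loop-side and decoder-side bookkeeping, not tfa functions.

```python
def merge_finished(finished_so_far, step_finished):
    """Loop-side bookkeeping: element-wise logical OR, so finished entries stay finished."""
    return [old or new for old, new in zip(finished_so_far, step_finished)]


def reorder(flags, parent_ids):
    """Decoder-side bookkeeping: new entry i came from old entry parent_ids[i]."""
    return [flags[i] for i in parent_ids]


# Two beams: beam 0 finished on an earlier step, beam 1 has not.
finished_so_far = [True, False]

# On this step a hypothetical beam-search decoder swaps the beams,
# and neither surviving hypothesis reports finishing.
parent_ids = [1, 0]
step_finished = [False, False]

# Naive loop-side OR ignores the swap and attaches old flags to the wrong beams.
naive = merge_finished(finished_so_far, step_finished)
# Decoder-side tracking reorders the history first, then merges.
tracked = merge_finished(reorder(finished_so_far, parent_ids), step_finished)

print(naive)    # [True, False]
print(tracked)  # [False, True]
```

Only the decoder knows parent_ids, which is why a reshuffling decoder has to own this bookkeeping and report tracks_own_finished as True.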

Methods

finalize


initialize


Called before any decoding iterations.

This method must compute the initial input values and the initial state.

Args
name Name scope for any created operations.

Returns
(finished, initial_inputs, initial_state): initial values of 'finished' flags, inputs and state.
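A minimal sketch of the initialize contract, using Python lists instead of tensors. CountdownDecoder is a hypothetical toy whose state is one integer counter per batch entry; the start token 0 is likewise an assumption made for illustration.

```python
class CountdownDecoder:
    """Hypothetical toy decoder: each batch entry counts down from its own start value."""

    def __init__(self, start_values):
        self.start_values = list(start_values)

    @property
    def batch_size(self):
        return len(self.start_values)

    def initialize(self, name=None):
        # `name` would open a name scope in TensorFlow; it is unused in this sketch.
        initial_state = list(self.start_values)
        # An entry that starts at zero has nothing to emit, so it is finished immediately.
        finished = [value <= 0 for value in initial_state]
        # Every entry receives the same placeholder start token as its first input.
        initial_inputs = [0] * self.batch_size
        return finished, initial_inputs, initial_state


print(CountdownDecoder([3, 0]).initialize())  # ([False, True], [0, 0], [3, 0])
```

The returned tuple is ordered (finished, initial_inputs, initial_state), matching the Returns entry above.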

step


Called per step of decoding (but only once for dynamic decoding).

Args
time Scalar int32 tensor. Current step number.
inputs RNN cell input (possibly nested tuple of) tensor[s] for this time step.
state RNN cell state (possibly nested tuple of) tensor[s] from previous time step.
training Python boolean. Indicates whether the layer should behave in training mode or in inference mode. Only relevant when dropout or recurrent_dropout is used.
name Name scope for any created operations.

Returns
(outputs, next_state, next_inputs, finished): outputs is an object containing the decoder output; next_state is a (structure of) state tensors and TensorArrays; next_inputs is the tensor that should be used as input for the next step; finished is a boolean tensor telling, for each sequence in the batch, whether the sequence is complete.
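The sketch below shows how initialize and step fit together inside a decoding loop. Both CountdownDecoder and toy_decode are hypothetical and use Python lists instead of tensors; toy_decode is a heavily simplified stand-in, not tfa.seq2seq.dynamic_decode, and its choice to pad already-finished entries with 0 is a simplification of this sketch.

```python
class CountdownDecoder:
    """Hypothetical toy decoder: each entry emits its counter, then decrements it."""

    def __init__(self, start_values):
        self.start_values = list(start_values)

    def initialize(self, name=None):
        state = list(self.start_values)
        return [v <= 0 for v in state], [0] * len(state), state

    def step(self, time, inputs, state, training=None, name=None):
        # `time` and `training` are unused here; a real cell might use `training`
        # to enable dropout only during training.
        outputs = list(state)                  # emit the current counter value
        next_state = [v - 1 for v in state]    # count down by one
        next_inputs = outputs                  # feed this step's output back in
        finished = [v <= 0 for v in next_state]
        return outputs, next_state, next_inputs, finished


def toy_decode(decoder, max_steps=10):
    """Simplified, hypothetical decoding loop driving initialize() and step()."""
    finished, inputs, state = decoder.initialize()
    steps = []
    for time in range(max_steps):
        if all(finished):
            break
        outputs, state, inputs, step_finished = decoder.step(time, inputs, state)
        # Entries that were already finished emit a padding value (0 in this sketch).
        steps.append([0 if done else out for out, done in zip(outputs, finished)])
        # Loop-side bookkeeping: once finished, always finished (logical OR).
        finished = [old or new for old, new in zip(finished, step_finished)]
    return steps


print(toy_decode(CountdownDecoder([3, 1])))  # [[3, 1], [2, 0], [1, 0]]
```

Because this toy decoder never reorders entries, the loop's logical OR is sufficient, which corresponds to tracks_own_finished being False.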