A namedtuple storing the state of a tfa.seq2seq.AttentionWrapper.

tfa.seq2seq.AttentionWrapperState( cell_state, attention, alignments, alignment_history, attention_state )
cell_state: The state of the wrapped RNNCell at the previous time step.
attention: The attention emitted at the previous time step.
alignments: A single or tuple of Tensor(s) containing the alignments emitted at the previous time step for each attention mechanism.
alignment_history: (if enabled) a single or tuple of TensorArray(s) containing alignment matrices from all time steps for each attention mechanism. Call stack() on each to convert it to a Tensor.
attention_state: A single or tuple of nested objects containing attention mechanism state for each attention mechanism. The objects may contain Tensors or TensorArrays.
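Since AttentionWrapperState behaves like a Python namedtuple, its structure can be sketched with a plain stand-in class. This is illustrative only, not the TFA implementation; the field names come from the attribute list above, and the placeholder values are assumptions:

```python
from collections import namedtuple

# Illustrative stand-in for tfa.seq2seq.AttentionWrapperState
# (same field names as documented above; not the TFA implementation).
AttentionWrapperState = namedtuple(
    "AttentionWrapperState",
    ["cell_state", "attention", "alignments", "alignment_history", "attention_state"],
)

state = AttentionWrapperState(
    cell_state=None,          # wrapped RNN cell state from the previous step
    attention=None,           # attention emitted at the previous step
    alignments=None,          # per-mechanism alignments from the previous step
    alignment_history=None,   # per-mechanism TensorArray(s) of all alignments, if enabled
    attention_state=None,     # per-mechanism internal attention state
)
print(state._fields)
```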
clone( **kwargs )
Clone this object, overriding components provided by kwargs.
The new state fields' shapes must match the original state fields' shapes. This will be validated, and the original fields' shapes will be propagated to the new fields.
initial_state = attention_wrapper.get_initial_state(
    batch_size=..., dtype=...)
initial_state = initial_state.clone(cell_state=encoder_state)
**kwargs: Any properties of the state object to replace in the returned AttentionWrapperState.
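The override-and-validate contract of clone can be sketched in pure Python: only the named fields are replaced, and each replacement must keep the original field's shape. The clone helper below is illustrative (not the TFA implementation), and list lengths stand in for tensor shapes:

```python
from collections import namedtuple

# Two-field stand-in state; the real AttentionWrapperState has five fields.
State = namedtuple("State", ["cell_state", "attention"])

def clone(state, **kwargs):
    # Mimic the documented behavior: an overriding field must match the
    # original field's shape (here, list length stands in for tensor shape).
    for name, new_value in kwargs.items():
        old_value = getattr(state, name)
        assert len(new_value) == len(old_value), f"shape mismatch for {name}"
    return state._replace(**kwargs)

initial = State(cell_state=[0.0, 0.0], attention=[0.0])
# Override only cell_state, keeping attention unchanged --
# analogous to initial_state.clone(cell_state=encoder_state).
updated = clone(initial, cell_state=[1.0, 2.0])
```

A replacement with a different "shape" (e.g. clone(initial, attention=[1.0, 2.0])) fails the validation step, mirroring the shape check described above.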