Wraps another RNNCell with attention.
__init__( cell, attention_mechanism, attention_layer_size=None, alignment_history=False, cell_input_fn=None, output_attention=True, initial_cell_state=None, name=None, attention_layer=None, attention_fn=None )
NOTE If you are using the BeamSearchDecoder with a cell wrapped in AttentionWrapper, then you must ensure that:
- The encoder output has been tiled to beam_width via tf.contrib.seq2seq.tile_batch (not tf.tile).
- The batch_size argument passed to the zero_state method of this wrapper is equal to true_batch_size * beam_width.
- The initial state created with zero_state above contains a cell_state value holding the properly tiled final state of the encoder.
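For example:

    import tensorflow as tf  # TensorFlow 1.x; tf.contrib does not exist in 2.x

    # Tile every encoder-side tensor so each source entry appears beam_width times.
    tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch(
        encoder_outputs, multiplier=beam_width)
    tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch(
        encoder_final_state, multiplier=beam_width)
    tiled_sequence_length = tf.contrib.seq2seq.tile_batch(
        sequence_length, multiplier=beam_width)
    # MyFavoriteAttentionMechanism is a placeholder for any AttentionMechanism,
    # e.g. tf.contrib.seq2seq.LuongAttention.
    attention_mechanism = MyFavoriteAttentionMechanism(
        num_units=attention_depth,
        memory=tiled_encoder_outputs,
        memory_sequence_length=tiled_sequence_length)
    attention_cell = AttentionWrapper(cell, attention_mechanism, ...)
    decoder_initial_state = attention_cell.zero_state(
        dtype, batch_size=true_batch_size * beam_width)
    # Replace the zero cell state with the tiled final encoder state.
    decoder_initial_state = decoder_initial_state.clone(
        cell_state=tiled_encoder_final_state)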
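The reason tile_batch is required, rather than tiling the whole batch, can be illustrated without TensorFlow. The plain-Python sketch below uses the hypothetical helpers tile_batch_like and tile_whole_batch to mimic the two layouts on a list standing in for the batch axis. tile_batch_like repeats each entry beam_width times in a row, which is the layout that matches reshaping a [true_batch_size * beam_width, ...] tensor into [true_batch_size, beam_width, ...]; it illustrates the ordering only and is not the TensorFlow implementation.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def tile_batch_like(batch: Sequence[T], multiplier: int) -> List[T]:
    """Repeat each batch entry `multiplier` times in a row (tile_batch-style)."""
    return [entry for entry in batch for _ in range(multiplier)]


def tile_whole_batch(batch: Sequence[T], multiplier: int) -> List[T]:
    """Repeat the whole batch `multiplier` times (tf.tile-style along axis 0)."""
    return list(batch) * multiplier


encoder_outputs = ["enc0", "enc1"]  # true_batch_size = 2
beam_width = 3

grouped = tile_batch_like(encoder_outputs, beam_width)
interleaved = tile_whole_batch(encoder_outputs, beam_width)
print(grouped)      # ['enc0', 'enc0', 'enc0', 'enc1', 'enc1', 'enc1']
print(interleaved)  # ['enc0', 'enc1', 'enc0', 'enc1', 'enc0', 'enc1']

# Beams of batch entry b live at rows b * beam_width ... b * beam_width + beam_width - 1.
for b, entry in enumerate(encoder_outputs):
    rows = grouped[b * beam_width:(b + 1) * beam_width]
    assert all(row == entry for row in rows)
```

With the interleaved layout, the first beam_width rows would mix entries from different source sequences, so beams would attend to the wrong encoder output.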
Args:
cell: An instance of RNNCell.
attention_mechanism: A list of AttentionMechanism instances or a single instance.
attention_layer_size: A list of Python integers or a single Python integer, the depth of the attention (output) layer(s). If None (default), use the context as attention at each time step. Otherwise, feed the context and cell output into the attention layer to generate attention at each time step. If attention_mechanism is a list, attention_layer_size must be a list of the same length. If attention_layer is set, this must be None. If attention_fn is set, you must ensure that the outputs of attention_fn also meet the above requirements.
alignment_history: Python boolean, whether to store alignment history from all time steps in the final output state (currently stored as a time-major TensorArray on which you must call stack()).
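cell_input_fn: (optional) A callable that combines the current input with the attention from the previous time step to form the input of the wrapped cell. The default is: lambda inputs, attention: array_ops.concat([inputs, attention], -1).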
output_attention: Python bool. If True (default), the output at each time step is the attention value. This is the behavior of Luong-style attention mechanisms. If False, the output at each time step is the output of cell. This is the behavior of Bahdanau-style attention mechanisms. In both cases, the attention tensor is propagated to the next time step via the state and is used there. This flag only controls whether the attention, rather than the cell output, is passed up to the next cell in an RNN stack or to the top RNN output.
initial_cell_state: The initial state value to use for the cell when the user calls zero_state(). Note that if this value is provided now and the user later calls zero_state() with a batch size that does not match the batch size of initial_cell_state, proper behavior is not guaranteed.
name: Name to use when creating ops.
attention_layer: A list of tf.compat.v1.layers.Layer instances or a single tf.compat.v1.layers.Layer instance that takes the context and cell output as inputs to generate attention at each time step. If None (default), use the context as attention at each time step. If attention_mechanism is a list, attention_layer must be a list of the same length. If attention_layer_size is set, this must be None.
attention_fn: An optional callable that lets users provide their own customized attention function. It takes (attention_mechanism, cell_output, attention_state, attention_layer) as input and returns (attention, alignments, next_attention_state). If provided, attention_layer_size should be the size of the outputs of attention_fn.
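The arguments above interact within a single decoding step. The plain-Python sketch below works on one example, with lists in place of batched tensors, and shows one plausible ordering under these assumptions: cell_input_fn combines the step input with the previous attention, the wrapped cell runs, an attention_fn with the documented (attention_mechanism, cell_output, attention_state, attention_layer) signature returns (attention, alignments, next_attention_state), alignments are appended to a history list when alignment history is requested, and output_attention selects what is emitted. toy_cell, sum_layer, dot_product_attention_fn, and wrapper_step are hypothetical names, and the dot-product scoring is a stand-in rather than any particular TensorFlow attention mechanism.

```python
import math
from typing import Callable, List, Optional, Tuple

Vector = List[float]
Layer = Callable[[Vector], Vector]


def default_cell_input_fn(inputs: Vector, attention: Vector) -> Vector:
    """Concatenate along the last axis, like the documented default."""
    return inputs + attention


def softmax(scores: Vector) -> Vector:
    peak = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def dot_product_attention_fn(attention_mechanism: List[Vector], cell_output: Vector,
                             attention_state: Vector, attention_layer: Optional[Layer]
                             ) -> Tuple[Vector, Vector, Vector]:
    """Hypothetical attention_fn; the "mechanism" is just one example's memory."""
    memory = attention_mechanism  # [memory_time, depth]
    scores = [sum(q * m for q, m in zip(cell_output, slot)) for slot in memory]
    alignments = softmax(scores)
    context = [sum(a * slot[d] for a, slot in zip(alignments, memory))
               for d in range(len(memory[0]))]
    # Without a layer the context is the attention; otherwise project
    # the concatenation [cell_output, context].
    if attention_layer is None:
        attention = context
    else:
        attention = attention_layer(cell_output + context)
    # This toy mechanism ignores attention_state and carries its alignments forward.
    return attention, alignments, alignments


def wrapper_step(inputs: Vector, cell, cell_state, prev_attention: Vector,
                 memory: List[Vector], attention_state: Vector,
                 history: Optional[List[Vector]],
                 attention_fn=dot_product_attention_fn,
                 cell_input_fn=default_cell_input_fn,
                 attention_layer: Optional[Layer] = None,
                 output_attention: bool = True):
    """One hypothetical decoding step; returns (output, next_state)."""
    cell_inputs = cell_input_fn(inputs, prev_attention)
    cell_output, next_cell_state = cell(cell_inputs, cell_state)
    attention, alignments, next_attention_state = attention_fn(
        memory, cell_output, attention_state, attention_layer)
    if history is not None:  # alignment history requested
        history.append(alignments)  # stand-in for a time-major TensorArray write
    output = attention if output_attention else cell_output
    # The attention goes into the state either way, for use at the next step.
    next_state = (next_cell_state, attention, next_attention_state)
    return output, next_state


def toy_cell(cell_inputs: Vector, state: float) -> Tuple[Vector, float]:
    """Hypothetical stand-in for the wrapped RNNCell."""
    total = sum(cell_inputs) + state
    return [total, 0.0], total


def sum_layer(features: Vector) -> Vector:
    """Hypothetical attention layer with output depth 1."""
    return [sum(features)]


memory = [[0.0, 1.0], [0.0, 3.0]]  # memory_time=2, depth=2
history: List[Vector] = []
out_luong, state_luong = wrapper_step([1.0], toy_cell, 0.0, [0.0, 0.0], memory,
                                      [0.0, 0.0], history, output_attention=True)
out_bahdanau, state_bahdanau = wrapper_step([1.0], toy_cell, 0.0, [0.0, 0.0], memory,
                                            [0.0, 0.0], None, output_attention=False)
out_layer, _ = wrapper_step([1.0], toy_cell, 0.0, [0.0, 0.0], memory,
                            [0.0, 0.0], None, attention_layer=sum_layer)
print(out_luong)     # [0.0, 2.0] (the attention)
print(out_bahdanau)  # [1.0, 0.0] (the cell output)
print(out_layer)     # [3.0]
```

Both output_attention settings produce the same next state; only the emitted output differs, which is the behavior the output_attention description states.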
Raises:
- if attention_layer_size is not None and either attention_mechanism is a list but attention_layer_size is not, or vice versa;
- if attention_layer_size is not None, attention_mechanism is a list, and its length does not match that of attention_layer_size;
- if attention_layer_size and attention_layer are set simultaneously.
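The argument conditions listed above can be restated as a small validation function. This is a hypothetical sketch, not the TensorFlow source: raising TypeError for the list/non-list mismatch and ValueError for the other two cases is an assumption of the sketch, and placeholder strings stand in for the mechanism and layer objects.

```python
from typing import Any, Callable


def validate_attention_args(attention_mechanism: Any,
                            attention_layer_size: Any = None,
                            attention_layer: Any = None) -> None:
    """Hypothetical restatement of the documented argument checks."""
    mechanism_is_list = isinstance(attention_mechanism, (list, tuple))
    if attention_layer_size is not None and attention_layer is not None:
        raise ValueError("attention_layer_size and attention_layer cannot both be set")
    if attention_layer_size is None:
        return
    size_is_list = isinstance(attention_layer_size, (list, tuple))
    if mechanism_is_list != size_is_list:
        raise TypeError("attention_mechanism and attention_layer_size must both be "
                        "lists or both be single values")
    if mechanism_is_list and len(attention_layer_size) != len(attention_mechanism):
        raise ValueError("attention_layer_size needs one entry per attention mechanism")


def raises(exc_type: type, fn: Callable[..., Any], *args: Any, **kwargs: Any) -> bool:
    """Return True if fn(*args, **kwargs) raises exc_type."""
    try:
        fn(*args, **kwargs)
    except exc_type:
        return True
    return False


mechs = ["luong", "bahdanau"]  # placeholders for two mechanism objects
print(raises(TypeError, validate_attention_args, mechs, 128))              # True
print(raises(ValueError, validate_attention_args, mechs, [128]))           # True
print(raises(ValueError, validate_attention_args, "luong", 128, "layer"))  # True
validate_attention_args(mechs, [128, 64])  # consistent arguments pass silently
```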
state_size
The state_size property of AttentionWrapper.
Returns: An AttentionWrapperState tuple containing shapes used by this object.
get_initial_state( inputs=None, batch_size=None, dtype=None )
zero_state( batch_size, dtype )
Return an initial (zero) state tuple for this AttentionWrapper.
NOTE Please see the initializer documentation for details of how to call zero_state if using an AttentionWrapper with a BeamSearchDecoder.
Args:
batch_size: 0D integer tensor: the batch size.
dtype: The internal state data type.
Returns: An AttentionWrapperState tuple containing zeroed-out tensors and, possibly, empty TensorArray objects.
Raises:
ValueError: (or, possibly at runtime, InvalidArgumentError) if batch_size does not match the batch size of the encoder output (the attention mechanism's memory) passed to the wrapper object at initialization time.
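The zero_state contract, together with the clone pattern from the beam search example, can be sketched in plain Python. AttentionWrapperStateSketch, zeros, and zero_state_sketch are hypothetical names; the field names are assumed to mirror AttentionWrapperState, nested lists stand in for tensors, and an empty list stands in for an empty TensorArray. The sketch rejects a batch size that differs from the memory batch size with ValueError, and uses initial_cell_state unchecked when one is given, which is why the initializer warns that a mismatch is not guaranteed to behave.

```python
from typing import Any, List, NamedTuple, Optional

Matrix = List[List[float]]


class AttentionWrapperStateSketch(NamedTuple):
    """Hypothetical stand-in for AttentionWrapperState (field names assumed)."""
    cell_state: Any
    attention: Matrix
    time: int
    alignments: Matrix
    alignment_history: Any
    attention_state: Matrix

    def clone(self, **kwargs: Any) -> "AttentionWrapperStateSketch":
        """Copy with some fields replaced; the original tuple is unchanged."""
        return self._replace(**kwargs)


def zeros(batch_size: int, depth: int) -> Matrix:
    return [[0.0] * depth for _ in range(batch_size)]


def zero_state_sketch(batch_size: int, memory_batch_size: int, cell_depth: int,
                      attention_depth: int, memory_time: int,
                      alignment_history: bool = False,
                      initial_cell_state: Optional[Matrix] = None
                      ) -> AttentionWrapperStateSketch:
    # The requested batch size must match the memory (encoder output) batch size.
    if batch_size != memory_batch_size:
        raise ValueError(f"batch_size {batch_size} does not match "
                         f"memory batch size {memory_batch_size}")
    # A provided initial_cell_state is used as-is; its batch size is not checked.
    cell_state = (initial_cell_state if initial_cell_state is not None
                  else zeros(batch_size, cell_depth))
    return AttentionWrapperStateSketch(
        cell_state=cell_state,
        attention=zeros(batch_size, attention_depth),
        time=0,
        alignments=zeros(batch_size, memory_time),
        alignment_history=[] if alignment_history else (),  # empty history stand-in
        attention_state=zeros(batch_size, memory_time),
    )


true_batch_size, beam_width = 2, 3
memory_batch_size = true_batch_size * beam_width  # encoder outputs were tiled
state = zero_state_sketch(true_batch_size * beam_width, memory_batch_size,
                          cell_depth=4, attention_depth=4, memory_time=5)
tiled_final_state = [[1.0] * 4 for _ in range(memory_batch_size)]
decoder_initial_state = state.clone(cell_state=tiled_final_state)

rejected = False
try:
    zero_state_sketch(true_batch_size, memory_batch_size, 4, 4, 5)  # forgot beam_width
except ValueError as err:
    rejected = True
    print("rejected:", err)
```

Passing true_batch_size instead of true_batch_size * beam_width is the mistake the beam search note warns about, and the sketch surfaces it at construction time.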