State of a tfa.seq2seq.AttentionWrapper.
tfa.seq2seq.AttentionWrapperState( cell_state, attention, alignments, alignment_history, attention_state )
clone( **kwargs )
Clone this object, overriding components provided by kwargs.
The shape of each new state field must match the shape of the corresponding original field. This is validated, and the original fields' static shapes are propagated to the new fields.
import tensorflow as tf
import tensorflow_addons as tfa

batch_size = 1
# Memory with 3 time steps of depth 100.
memory = tf.random.normal(shape=[batch_size, 3, 100])
encoder_state = [tf.zeros((batch_size, 100)), tf.zeros((batch_size, 100))]
attention_mechanism = tfa.seq2seq.LuongAttention(100, memory=memory, memory_sequence_length=[3] * batch_size)
attention_cell = tfa.seq2seq.AttentionWrapper(tf.keras.layers.LSTMCell(100), attention_mechanism, attention_layer_size=10)
decoder_initial_state = attention_cell.get_initial_state(batch_size=batch_size, dtype=tf.float32)
# Replace the zero cell state with the final encoder state; the shapes match,
# so validation passes.
decoder_initial_state = decoder_initial_state.clone(cell_state=encoder_state)
**kwargs: Any properties of the state object to replace in the returned AttentionWrapperState.
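AttentionWrapperState is a namedtuple-like container, and clone behaves much like namedtuple's _replace with an added shape check. The sketch below illustrates that pattern in plain Python; the State type, the clone helper, and the length-based "shape" check are illustrative stand-ins, not part of the tfa API:

```python
from collections import namedtuple

# Hypothetical stand-in for AttentionWrapperState's fields (illustration only).
State = namedtuple("State", ["cell_state", "attention"])

def clone(state, **kwargs):
    # Validate that each replacement matches the "shape" of the field it
    # replaces (here the length of a list stands in for a tensor's static
    # shape), then build the new state via _replace.
    for name, new_value in kwargs.items():
        old_value = getattr(state, name)
        if len(new_value) != len(old_value):
            raise ValueError(f"Shape mismatch for field '{name}'")
    return state._replace(**kwargs)

initial = State(cell_state=[0.0, 0.0], attention=[0.0, 0.0, 0.0])
updated = clone(initial, cell_state=[1.0, 2.0])  # same "shape": accepted
print(updated.cell_state)  # [1.0, 2.0]
print(updated.attention)   # unchanged: [0.0, 0.0, 0.0]
```

A replacement with a different shape (e.g. cell_state=[1.0]) raises ValueError, mirroring the validation described above; fields not named in kwargs are carried over unchanged.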