Transformer decoder.

Like the encoder, the decoder is made up of N identical layers. Each layer is composed of the following sublayers:

  1. Self-attention layer
  2. Multi-headed attention layer combining encoder outputs with results from the previous self-attention layer.
  3. Feedforward network (2 fully-connected layers)
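The three sublayers can be sketched in NumPy roughly as follows. This is a minimal illustration, not the library's implementation: it assumes single-head attention without learned projections, a ReLU feedforward network, no dropout, and layer normalization applied after each residual connection. The names `layer_norm`, `attention`, `decoder_layer`, `w1`, and `w2` are hypothetical.

```python
import numpy as np


def layer_norm(x, eps=1e-6):
    # Normalize over the hidden dimension (no learned scale or offset here).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def attention(query, key, value, mask=None):
    # Single-head scaled dot-product attention; learned projections omitted.
    scores = query @ key.transpose(0, 2, 1) / np.sqrt(query.shape[-1])
    if mask is not None:
        # Assumed convention: 1 means "may attend", 0 means "masked out".
        scores = np.where(mask > 0, scores, -1e9)
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ value


def decoder_layer(target, memory, w1, w2, self_mask=None, cross_mask=None):
    # 1. Self-attention over the target sequence.
    x = layer_norm(target + attention(target, target, target, self_mask))
    # 2. Attention combining encoder outputs (memory) with the result of step 1.
    x = layer_norm(x + attention(x, memory, memory, cross_mask))
    # 3. Feedforward network: two fully-connected layers with ReLU in between.
    hidden = np.maximum(x @ w1, 0.0)
    return layer_norm(x + hidden @ w2)
```

Stacking N such layers, each with its own weights, gives the overall decoder shape described above; the output keeps the shape of `target`.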

num_layers Number of layers.
num_attention_heads Number of attention heads.
intermediate_size Size of the intermediate (Feedforward) layer.
activation Activation for the intermediate layer.
dropout_rate Dropout probability.
attention_dropout_rate Dropout probability for attention layers.
use_bias Whether the attention layer uses bias terms. If set False, bias is disabled in the attention layer.
norm_first Whether to normalize inputs to attention and intermediate dense layers. If set False, output of attention and intermediate dense layers is normalized.
norm_epsilon Epsilon value to initialize normalization layers.
intermediate_dropout Dropout probability for intermediate_dropout_layer.
**kwargs Keyword arguments passed to tf.keras.layers.Layer.
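The effect of norm_first can be illustrated with a small NumPy sketch. It assumes the usual residual wiring around each sublayer, which the argument description implies but does not spell out; `apply_sublayer` and `layer_norm` are hypothetical helpers, and `eps` plays the role of norm_epsilon.

```python
import numpy as np


def layer_norm(x, eps=1e-6):
    # Normalize over the last (hidden) dimension; eps guards against division by zero.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def apply_sublayer(x, sublayer, norm_first, eps=1e-6):
    if norm_first:
        # Pre-norm: normalize the input, then add the residual.
        return x + sublayer(layer_norm(x, eps))
    # Post-norm: add the residual, then normalize the output.
    return layer_norm(x + sublayer(x), eps)
```

Here `sublayer` stands for either the attention layer or the intermediate dense layers. With norm_first=True the residual path carries the unnormalized input, so a final normalization of the stack output is typically applied separately.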




Returns the output of the decoder layer stack.

target A tensor with shape (batch_size, target_length, hidden_size).
memory A tensor with shape (batch_size, input_length, hidden_size).
self_attention_mask A tensor with shape (batch_size, target_length, target_length), the mask for the decoder self-attention layer.
cross_attention_mask A tensor with shape (batch_size, target_length, input_length), the mask for the encoder-decoder attention layer.
cache (Used for fast decoding) A nested dictionary storing previous decoder self-attention values. The items are: {layer_n: {"k": A tensor with shape (batch_size, i, key_channels), "v": A tensor with shape (batch_size, i, value_channels)}, ...}
decode_loop_step An integer, the step number of the decoding loop. Used only for autoregressive inference on TPU.
return_all_decoder_outputs Whether to return the outputs of all decoder layers. Note that the outputs are layer-normalized. This is useful when introducing a per-layer auxiliary loss.
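As a rough guide to the shapes listed above, the following NumPy sketch builds a causal self-attention mask and a cache with the documented nested layout. Several details are assumptions made for illustration: the 1 = attend, 0 = masked convention, the `layer_{n}` key names, and starting the cache at i = 0 and growing it by concatenation. All function names are hypothetical.

```python
import numpy as np


def causal_self_attention_mask(batch_size, target_length):
    # Lower-triangular mask: position t may attend to positions 0..t only.
    mask = np.tril(np.ones((target_length, target_length), dtype=np.float32))
    return np.broadcast_to(mask, (batch_size, target_length, target_length)).copy()


def empty_cache(num_layers, batch_size, key_channels, value_channels):
    # One {"k", "v"} entry per layer, with i = 0 decoded positions so far.
    return {
        f"layer_{n}": {
            "k": np.zeros((batch_size, 0, key_channels), dtype=np.float32),
            "v": np.zeros((batch_size, 0, value_channels), dtype=np.float32),
        }
        for n in range(num_layers)
    }


def append_to_cache(layer_cache, new_k, new_v):
    # Grow the stored keys and values along the sequence axis by one step.
    return {
        "k": np.concatenate([layer_cache["k"], new_k], axis=1),
        "v": np.concatenate([layer_cache["v"], new_v], axis=1),
    }
```

On TPU, where decode_loop_step is used, shapes generally need to be static, so a fixed-size cache written at the given step would be the more likely layout than the growing one sketched here.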

Output of the decoder: a float32 tensor with shape (batch_size, target_length, hidden_size).