
tfnlp.models.TransformerDecoder

Transformer decoder.

Like the encoder, the decoder is made up of N identical layers. Each layer is composed of three sublayers:

  1. Self-attention layer
  2. Multi-headed attention layer combining encoder outputs with results from the previous self-attention layer.
  3. Feedforward network (2 fully-connected layers)
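The three sublayers can be sketched in NumPy as follows. This is a minimal, framework-agnostic illustration, not the library's implementation. It assumes pre-normalization (the norm_first=True arrangement), a ReLU feedforward activation, layer normalization without learned scale or offset, and no dropout. All helper names (layer_norm, multi_head_attention, decoder_layer, decoder_stack, make_params) are hypothetical.

```python
import numpy as np

# Hypothetical NumPy sketch of the decoder layer stack; not the tfnlp implementation.
# Assumes pre-normalization, ReLU, no dropout, and no learned norm scale/offset.


def layer_norm(x, eps=1e-6):
    # Normalize each position over the hidden (last) axis.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)


def multi_head_attention(q_in, kv_in, w, num_heads, mask=None):
    # q_in: [batch, q_len, hidden]; kv_in: [batch, kv_len, hidden].
    # w: dict of [hidden, hidden] projections named "q", "k", "v", "o".
    batch, q_len, hidden = q_in.shape
    kv_len = kv_in.shape[1]
    head_dim = hidden // num_heads

    def split_heads(x, length):
        # [batch, length, hidden] -> [batch, heads, length, head_dim]
        return x.reshape(batch, length, num_heads, head_dim).transpose(0, 2, 1, 3)

    q = split_heads(q_in @ w["q"], q_len)
    k = split_heads(kv_in @ w["k"], kv_len)
    v = split_heads(kv_in @ w["v"], kv_len)
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(head_dim)
    if mask is not None:
        # mask: [batch, q_len, kv_len], 1 = may attend, 0 = blocked.
        scores = np.where(mask[:, None, :, :] > 0, scores, -1e9)
    out = softmax(scores) @ v  # [batch, heads, q_len, head_dim]
    out = out.transpose(0, 2, 1, 3).reshape(batch, q_len, hidden)
    return out @ w["o"]


def decoder_layer(x, memory, params, num_heads, self_mask=None, cross_mask=None):
    # 1. Self-attention over the target sequence.
    h = layer_norm(x)
    x = x + multi_head_attention(h, h, params["self"], num_heads, self_mask)
    # 2. Encoder-decoder attention: queries from the decoder, keys/values from memory.
    h = layer_norm(x)
    x = x + multi_head_attention(h, memory, params["cross"], num_heads, cross_mask)
    # 3. Feedforward network: two dense layers with a ReLU in between.
    h = layer_norm(x)
    return x + np.maximum(h @ params["w1"], 0.0) @ params["w2"]


def decoder_stack(target, memory, layers, num_heads, self_mask=None, cross_mask=None):
    # N layers with identical structure but separate weights, applied in sequence.
    x = target
    for params in layers:
        x = decoder_layer(x, memory, params, num_heads, self_mask, cross_mask)
    return x


rng = np.random.default_rng(0)
batch, target_length, input_length, hidden, heads, inter = 2, 5, 7, 16, 4, 32


def make_params():
    def attn():
        return {name: rng.normal(scale=0.1, size=(hidden, hidden)) for name in "qkvo"}
    return {"self": attn(), "cross": attn(),
            "w1": rng.normal(scale=0.1, size=(hidden, inter)),
            "w2": rng.normal(scale=0.1, size=(inter, hidden))}


target = rng.normal(size=(batch, target_length, hidden))
memory = rng.normal(size=(batch, input_length, hidden))
out = decoder_stack(target, memory, [make_params() for _ in range(3)], heads)
print(out.shape)  # (2, 5, 16)
```

Each sublayer keeps the [batch_size, target_length, hidden_size] shape, which is what allows the N layers to be chained and the residual connections to be added.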

Args
num_layers Number of layers.
num_attention_heads Number of attention heads.
intermediate_size Size of the intermediate (feedforward) layer.
activation Activation for the intermediate layer.
dropout_rate Dropout probability.
attention_dropout_rate Dropout probability for attention layers.
use_bias Whether to enable use_bias in the attention layer. If set to False, use_bias in the attention layer is disabled.
norm_first Whether to normalize the inputs to the attention and intermediate dense layers. If set to False, the outputs of the attention and intermediate dense layers are normalized instead.
norm_epsilon Epsilon value to initialize normalization layers.
intermediate_dropout Dropout probability for intermediate_dropout_layer.
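The difference between norm_first=True and norm_first=False can be sketched with a single residual sublayer. This NumPy sketch assumes the common pre-norm and post-norm arrangements (in post-norm, the residual sum is normalized) and uses eps in the role of norm_epsilon; the helper names are hypothetical and this is not the library's code.

```python
import numpy as np

# Hypothetical helpers illustrating norm_first and norm_epsilon; not the tfnlp code.


def layer_norm(x, eps=1e-6):
    # Normalize each position over the hidden (last) axis; eps plays the role
    # of norm_epsilon.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def residual_sublayer(x, sublayer, norm_first=True, eps=1e-6):
    if norm_first:
        # Pre-norm: normalize the sublayer's input; the residual path is untouched.
        return x + sublayer(layer_norm(x, eps))
    # Post-norm: add the residual first, then normalize the sum.
    return layer_norm(x + sublayer(x), eps)


rng = np.random.default_rng(0)
x = rng.normal(size=(2, 3, 8))
w = rng.normal(size=(8, 8))


def dense(h):
    # Stand-in for an attention or feedforward sublayer.
    return h @ w


pre = residual_sublayer(x, dense, norm_first=True)
post = residual_sublayer(x, dense, norm_first=False)
```

With norm_first=False, every position of post has (approximately) zero mean and unit variance over the hidden axis, while pre carries the unnormalized residual stream forward to the next sublayer.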

Methods

call

Return the output of the decoder layer stacks.

Args
target A tensor with shape [batch_size, target_length, hidden_size].
memory A tensor with shape [batch_size, input_length, hidden_size].
memory_mask A tensor with shape [batch_size, target_length, target_length], the mask for the decoder self-attention layer.
target_mask A tensor with shape [batch_size, target_length, input_length], the mask for the encoder-decoder attention layer.
cache (Used for fast decoding) A nested dictionary storing previous decoder self-attention values. The items are: {layer_n: {"k": A tensor with shape [batch_size, i, key_channels], "v": A tensor with shape [batch_size, i, value_channels]}, ...}
decode_loop_step An integer, the step number of the decoding loop. Used only for autoregressive inference on TPU.
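Per the descriptions above, memory_mask is the decoder self-attention mask and target_mask is the encoder-decoder attention mask. A minimal NumPy sketch of building masks with those shapes follows, assuming the Keras attention-mask convention where 1 means "may attend" and 0 means "masked". The helper names and the input_padding array are hypothetical.

```python
import numpy as np

# Hypothetical mask builders; assumes 1 = may attend, 0 = masked.


def causal_self_attention_mask(batch_size, target_length):
    # [batch_size, target_length, target_length]: position i may attend
    # only to target positions 0..i.
    causal = np.tril(np.ones((target_length, target_length), dtype=np.float32))
    return np.broadcast_to(causal, (batch_size, target_length, target_length)).copy()


def cross_attention_mask(input_padding, target_length):
    # input_padding: [batch_size, input_length], 1 for real tokens, 0 for padding.
    # Returns [batch_size, target_length, input_length]: every target position
    # sees every non-padded encoder position.
    pad = input_padding.astype(np.float32)[:, None, :]
    return np.repeat(pad, target_length, axis=1)


input_padding = np.array([[1, 1, 1, 0], [1, 1, 0, 0]])
self_mask = causal_self_attention_mask(batch_size=2, target_length=3)  # memory_mask
cross_mask = cross_attention_mask(input_padding, target_length=3)  # target_mask
print(self_mask.shape, cross_mask.shape)  # (2, 3, 3) (2, 3, 4)
```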

Returns
Output of the decoder: a float32 tensor with shape [batch_size, target_length, hidden_size].
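The cache argument lets each decoding step reuse keys and values computed at earlier steps instead of recomputing them. The sketch below is a simplification assuming a single attention head and a single layer; it checks that step-by-step decoding with a cache that grows to [batch_size, i, key_channels] reproduces full causal self-attention. The dictionary mirrors the {layer_n: {"k", "v"}} layout described for cache, but the key name "layer_0" and all helper functions are hypothetical, and the real layer's cache handling, including the TPU path selected by decode_loop_step, may differ.

```python
import numpy as np

# Hypothetical single-head sketch of cached (incremental) self-attention.


def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)


def self_attention(x, wq, wk, wv, mask=None):
    # Attention over the full sequence; x: [batch, length, hidden].
    q, k, v = x @ wq, x @ wk, x @ wv
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(k.shape[-1])
    if mask is not None:
        scores = np.where(mask > 0, scores, -1e9)
    return softmax(scores) @ v


def cached_step(x_t, layer_cache, wq, wk, wv):
    # x_t: [batch, 1, hidden], the newest target position only.
    # Append this step's key/value so the cache grows to [batch, i + 1, channels].
    layer_cache["k"] = np.concatenate([layer_cache["k"], x_t @ wk], axis=1)
    layer_cache["v"] = np.concatenate([layer_cache["v"], x_t @ wv], axis=1)
    k, v = layer_cache["k"], layer_cache["v"]
    # No mask needed: the cache holds only the current and earlier positions.
    scores = (x_t @ wq) @ k.transpose(0, 2, 1) / np.sqrt(k.shape[-1])
    return softmax(scores) @ v


rng = np.random.default_rng(0)
batch, target_length, hidden, key_channels, value_channels = 2, 4, 8, 8, 6
wq = rng.normal(size=(hidden, key_channels))
wk = rng.normal(size=(hidden, key_channels))
wv = rng.normal(size=(hidden, value_channels))
target = rng.normal(size=(batch, target_length, hidden))

causal = np.tril(np.ones((target_length, target_length)))[None]
full = self_attention(target, wq, wk, wv, causal)

cache = {"layer_0": {"k": np.zeros((batch, 0, key_channels)),
                     "v": np.zeros((batch, 0, value_channels))}}
steps = [cached_step(target[:, t:t + 1], cache["layer_0"], wq, wk, wv)
         for t in range(target_length)]
incremental = np.concatenate(steps, axis=1)
print(np.allclose(full, incremental))  # True
```

Each cached step processes only one new position, so the per-step cost grows with i rather than requiring the whole prefix to be re-projected.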