tfm.nlp.layers.TransformerXLBlock

Transformer XL block.

This implements a Transformer XL block from "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" (https://arxiv.org/abs/1901.02860).

This block is further extended to allow for the Transformer-XL re-parameterization in "XLNet: Generalized Autoregressive Pretraining for Language Understanding" (https://arxiv.org/abs/1906.08237).

Given an input stream, this block computes attention, applies dropout and layer normalization, and feeds the result into a position-wise feed-forward network (FFN).

Args

vocab_size: The size of the token vocabulary.
hidden_size: The size of the transformer hidden layers.
num_attention_heads: The number of attention heads.
head_size: The dimension size of each attention head.
inner_size: The inner size for the transformer layers.
dropout_rate: Dropout rate for the output of this layer.
attention_dropout_rate: Dropout rate on attention probabilities.
two_stream: Whether to use TwoStreamRelativeAttention, as in the XLNet pretrainer. If False, the block uses MultiHeadRelativeAttention, as in Transformer XL.
norm_epsilon: Epsilon value to initialize normalization layers.
inner_activation: The activation to use for the inner FFN layers.
kernel_initializer: Initializer for dense layer kernels.
inner_dropout: Dropout probability for the inner dropout layer.
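A minimal construction sketch, assuming the tensorflow_models (tfm) package is installed; the hyperparameter values below are illustrative choices, not defaults:

```python
import tensorflow_models as tfm

# Illustrative hyperparameters; pick values to match your model.
block = tfm.nlp.layers.TransformerXLBlock(
    vocab_size=32000,
    hidden_size=64,
    num_attention_heads=4,
    head_size=16,
    inner_size=256,
    dropout_rate=0.1,
    attention_dropout_rate=0.1,
    two_stream=False,        # MultiHeadRelativeAttention, as in Transformer XL
    norm_epsilon=1e-12,
    inner_activation="relu")
```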

Methods

call


Implements call for the Layer.

Args
content_stream: Tensor, the input content stream. This is the standard input to Transformer XL and is commonly referred to as h in XLNet.
content_attention_bias: Bias Tensor for content-based attention of shape [num_heads, dim].
positional_attention_bias: Bias Tensor for position-based attention of shape [num_heads, dim].
relative_position_encoding: Relative positional encoding Tensor of shape [B, L, dim].
segment_matrix: Optional Tensor of shape [B, S, S + M]. Used in XLNet, but not in Transformer XL.
segment_encoding: Optional Tensor of shape [2, num_heads, dim]. Used in XLNet, but not in Transformer XL.
segment_attention_bias: Optional bias Tensor for segment-based attention of shape [num_heads, dim].
state: Optional Tensor of shape [B, M, E], where M is the length of the state or memory. If passed, this is also attended over, as in Transformer XL.
content_attention_mask: Optional Tensor representing the mask that is added to content attention logits. If state is not None, the mask's source sequence dimension should also extend over the memory length M.
query_stream: Optional Tensor, the query stream. Introduced in TwoStreamRelativeAttention and used by the XLNet pretrainer; ignored if two_stream is False.
query_attention_mask: Optional Tensor representing the mask that is added to query attention logits. If state is not None, the mask's source sequence dimension should also extend over the memory length M.
target_mapping: Optional Tensor representing the target mapping used when calculating query attention.

Returns
A dict containing the key-value pairs for content_attention and (if two_stream is True) query_attention.
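Continuing from the construction sketch above, a hedged call example with random stand-in tensors. Shapes follow the argument descriptions; the relative position encoding length of 2 * seq_length is an assumption corresponding to a bidirectional relative-position range with no memory, and the attention biases would normally be learned model variables rather than random values:

```python
import tensorflow as tf

batch_size, seq_length = 2, 8
hidden_size, num_heads, head_size = 64, 4, 16  # must match the block above

# Content stream: the standard Transformer-XL input (`h` in XLNet).
content_stream = tf.random.normal([batch_size, seq_length, hidden_size])

# Learned biases in a real model; random stand-ins of shape [num_heads, head_size].
content_attention_bias = tf.random.normal([num_heads, head_size])
positional_attention_bias = tf.random.normal([num_heads, head_size])

# Relative position encoding. Its length depends on how it is generated;
# 2 * seq_length is assumed here (bidirectional range, no memory/state).
relative_position_encoding = tf.random.normal(
    [batch_size, 2 * seq_length, hidden_size])

outputs = block(
    content_stream,
    content_attention_bias=content_attention_bias,
    positional_attention_bias=positional_attention_bias,
    relative_position_encoding=relative_position_encoding)

print(outputs["content_attention"].shape)  # (batch_size, seq_length, hidden_size)
```

With two_stream=True, a query_stream (and typically a target_mapping) would also be passed, and the returned dict would additionally contain query_attention.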