tfm.nlp.layers.TransformerXL

Transformer XL.

This layer combines multiple Transformer XL blocks from "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" (https://arxiv.org/abs/1901.02860).

This layer handles the attention biases as well as memory caching and reuse as in Transformer XL and XLNet.

Args
vocab_size: The number of tokens in the vocabulary.
num_layers: The number of layers.
hidden_size: The hidden size.
num_attention_heads: The number of attention heads.
head_size: The dimension size of each attention head.
inner_size: The hidden size in feed-forward layers.
dropout_rate: Dropout rate used in each Transformer XL block.
attention_dropout_rate: Dropout rate on attention probabilities.
two_stream: Whether or not to use TwoStreamRelativeAttention, as used in the XLNet pretrainer. If False, MultiHeadRelativeAttention is used instead, as in Transformer XL.
initializer: The initializer to use for attention biases.
tie_attention_biases: Whether or not to tie the attention biases together. If True, every Transformer XL block shares the same trainable attention bias; if False, each block has its own. This is usually set to True.
memory_length: The number of tokens to cache.
reuse_length: The number of tokens in the current batch to be cached and reused in the future.
inner_activation: The activation to use in the inner feed-forward layers of the Transformer XL blocks. Typically "relu" or "gelu".
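
The following is a minimal construction sketch, not taken from the official documentation: it builds a small Transformer XL stack using keyword arguments matching the parameters above. All hyperparameter values are arbitrary placeholders, and the tensorflow_models import alias is an assumption.

import tensorflow as tf
import tensorflow_models as tfm

# Minimal sketch: a small Transformer XL stack. All values below are
# arbitrary placeholders chosen for illustration, not recommendations.
transformer_xl = tfm.nlp.layers.TransformerXL(
    vocab_size=32000,
    num_layers=6,
    hidden_size=512,
    num_attention_heads=8,
    head_size=64,                # dimension of each attention head
    inner_size=2048,             # feed-forward hidden size
    dropout_rate=0.1,
    attention_dropout_rate=0.1,
    initializer=tf.keras.initializers.RandomNormal(stddev=0.02),
    two_stream=False,            # plain Transformer XL attention, no query stream
    tie_attention_biases=True,   # share attention biases across blocks
    memory_length=128,           # number of tokens to cache
    reuse_length=0,
    inner_activation="relu",
)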

Methods

call

Implements call() for the layer.

Args
content_stream: Tensor, the input content stream. This is the standard input to Transformer XL and is commonly referred to as h in XLNet.
relative_position_encoding: Relative positional encoding Tensor of shape [B, L, dim].
segment_matrix: Optional Tensor of shape [B, S, S + M]. Used in XLNet, but not in Transformer XL.
segment_embedding: Optional Tensor of shape [2, num_heads, dim]. Used in XLNet, but not in Transformer XL.
state: Optional Tensor of shape [B, M, E], where M is the length of the state or memory. If passed, this is also attended over, as in Transformer XL.
content_attention_mask: Optional Tensor representing the mask that is added to the content attention logits. If state is not None, the source sequence dimension of the mask should be extended by M to cover the memory.
query_stream: Optional Tensor, the query stream. This is introduced in TwoStreamRelativeAttention and the XLNet pretrainer, and is ignored if two_stream is False.
query_attention_mask: Optional Tensor representing the mask that is added to the query attention logits. If state is not None, the source sequence dimension of the mask should be extended by M to cover the memory.
target_mapping: Optional Tensor representing the target mapping used when calculating query attention.

Returns
A tuple consisting of the attention output and the list of cached memory states. The attention output is content_attention if two_stream is False, otherwise it is query_attention.
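
As a rough usage sketch (an assumption, not taken from the documentation), the layer built in the construction sketch above can be called on two consecutive segments, feeding the returned memory states back in as state so the second segment also attends over the first. The inputs are random placeholders: in a real model, content_stream would come from a token embedding lookup and relative_position_encoding from a relative position encoding layer, and the exact encoding length may need to account for the cached memory.

import tensorflow as tf

batch_size, seq_len, hidden_size = 2, 16, 512

# Placeholder inputs with the shapes documented above ([B, L, dim]).
content_stream = tf.random.normal([batch_size, seq_len, hidden_size])
relative_position_encoding = tf.random.normal([batch_size, seq_len, hidden_size])

# First segment: no memory yet, so state is left at its default (None).
output, memory_states = transformer_xl(
    content_stream=content_stream,
    relative_position_encoding=relative_position_encoding,
)

# Second segment: pass the cached memory states back in so the new tokens
# can also attend over the previous segment, as in Transformer XL.
next_segment = tf.random.normal([batch_size, seq_len, hidden_size])
next_output, memory_states = transformer_xl(
    content_stream=next_segment,
    relative_position_encoding=relative_position_encoding,
    state=memory_states,
)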