Transformer XL.
tfm.nlp.layers.TransformerXL(
    vocab_size,
    num_layers,
    hidden_size,
    num_attention_heads,
    head_size,
    inner_size,
    dropout_rate,
    attention_dropout_rate,
    initializer,
    two_stream=False,
    tie_attention_biases=True,
    memory_length=None,
    reuse_length=None,
    inner_activation='relu',
    **kwargs
)
This layer combines multiple Transformer XL blocks from "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" (https://arxiv.org/abs/1901.02860).
This layer handles the attention biases as well as memory caching and reuse as in Transformer XL and XLNet.
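A minimal construction sketch, assuming the `tensorflow_models` package is installed and imported as `tfm`; the hyperparameter values below are illustrative, not canonical defaults:

```python
import tensorflow as tf
import tensorflow_models as tfm

# Illustrative configuration; real values depend on the model being built.
transformer_xl = tfm.nlp.layers.TransformerXL(
    vocab_size=32000,
    num_layers=6,
    hidden_size=512,
    num_attention_heads=8,
    head_size=64,
    inner_size=2048,
    dropout_rate=0.1,
    attention_dropout_rate=0.1,
    initializer=tf.keras.initializers.RandomNormal(stddev=0.02),
    two_stream=False,
    memory_length=128)
```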
Methods
call
call(
    content_stream,
    relative_position_encoding,
    segment_matrix=None,
    segment_embedding=None,
    state=None,
    content_attention_mask=None,
    query_stream=None,
    query_attention_mask=None,
    target_mapping=None
)
Implements call() for the layer.
| Args | |
|---|---|
| `content_stream` | `Tensor`, the input content stream. This is the standard input to Transformer XL and is commonly referred to as `h` in XLNet. |
| `relative_position_encoding` | Relative positional encoding `Tensor` of shape `[B, L, dim]`. |
| `segment_matrix` | Optional `Tensor` of shape `[B, S, S + M]`. Used in XLNet, but not in Transformer XL. |
| `segment_embedding` | Optional `Tensor` of shape `[2, num_heads, dim]`. Used in XLNet, but not in Transformer XL. |
| `state` | Optional `Tensor` of shape `[B, M, E]`, where `M` is the length of the state or memory. If passed, this is also attended over, as in Transformer XL. |
| `content_attention_mask` | Optional `Tensor` representing the mask that is added to the content attention logits. If `state` is not `None`, the source sequence dimension of the mask should be extended by `M`. |
| `query_stream` | Optional `Tensor`, the query stream. This is introduced in `TwoStreamRelativeAttention` and the XLNet pretrainer. Ignored if `two_stream` is `False`. |
| `query_attention_mask` | Optional `Tensor` representing the mask that is added to the query attention logits. If `state` is not `None`, the source sequence dimension of the mask should be extended by `M`. |
| `target_mapping` | Optional `Tensor` representing the target mapping used when calculating query attention. |
| Returns |
|---|
| A tuple consisting of the attention output and the list of cached memory states. The attention output is `content_attention` if `two_stream` is `False`; otherwise it is `query_attention`. |
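A minimal call sketch with assumed shapes: the relative-position length of `2 * seq_length` below mirrors an XLNet-style bidirectional encoding with no memory attended over; the exact length required depends on how the encoding is computed upstream, so treat it as an assumption.

```python
import tensorflow as tf
import tensorflow_models as tfm

batch_size, seq_length, hidden_size = 2, 16, 128

# Small illustrative configuration so the snippet runs standalone.
layer = tfm.nlp.layers.TransformerXL(
    vocab_size=1000,
    num_layers=2,
    hidden_size=hidden_size,
    num_attention_heads=4,
    head_size=32,
    inner_size=256,
    dropout_rate=0.0,
    attention_dropout_rate=0.0,
    initializer=tf.keras.initializers.RandomNormal(stddev=0.02),
    memory_length=seq_length)

content_stream = tf.random.normal([batch_size, seq_length, hidden_size])
# Assumed relative positional encoding of shape [B, L, dim]; L = 2 * seq_length
# is an assumption matching a bidirectional XLNet-style encoding with no state.
relative_position_encoding = tf.random.normal(
    [batch_size, 2 * seq_length, hidden_size])

attention_output, cached_memory_states = layer(
    content_stream,
    relative_position_encoding=relative_position_encoding)

print(attention_output.shape)     # (2, 16, 128)
print(len(cached_memory_states))  # 2 -- one cached memory state per layer
```

Since `two_stream=False` here, the first returned value is the content attention output; the cached memory states can be fed back in as `state` on the next segment to attend over previous context.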