Transformer-XL block.
tfm.nlp.layers.TransformerXLBlock(
    vocab_size,
    hidden_size,
    num_attention_heads,
    head_size,
    inner_size,
    dropout_rate,
    attention_dropout_rate,
    two_stream=False,
    norm_epsilon=1e-12,
    inner_activation='relu',
    kernel_initializer='variance_scaling',
    inner_dropout=0.0,
    **kwargs
)
This implements a Transformer-XL block from "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" (https://arxiv.org/abs/1901.02860).
The block is further extended to allow for the Transformer-XL re-parameterization introduced in "XLNet: Generalized Autoregressive Pretraining for Language Understanding" (https://arxiv.org/abs/1906.08237).
Given an input stream, this block computes attention, applies dropout and layer normalization, and feeds the result into the feed-forward network (FFN).
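As a quick orientation, the sketch below constructs the block with small, illustrative hyperparameters. The specific values and the `tensorflow_models` import alias are assumptions for this example, not requirements of the layer.

```python
import tensorflow_models as tfm

# Small, illustrative hyperparameters (num_attention_heads * head_size == hidden_size here).
block = tfm.nlp.layers.TransformerXLBlock(
    vocab_size=32000,
    hidden_size=64,
    num_attention_heads=4,
    head_size=16,
    inner_size=256,
    dropout_rate=0.1,
    attention_dropout_rate=0.1,
    two_stream=False)
```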
Methods
call
call(
    content_stream,
    content_attention_bias,
    positional_attention_bias,
    relative_position_encoding=None,
    segment_matrix=None,
    segment_encoding=None,
    segment_attention_bias=None,
    state=None,
    content_attention_mask=None,
    query_stream=None,
    query_attention_mask=None,
    target_mapping=None
)
Implements `call` for the Layer.
Args | |
---|---|
`content_stream` | `Tensor`, the input content stream. This is the standard input to Transformer-XL and is commonly referred to as `h` in XLNet. |
`content_attention_bias` | Bias `Tensor` for content-based attention of shape `[num_heads, dim]`. |
`positional_attention_bias` | Bias `Tensor` for position-based attention of shape `[num_heads, dim]`. |
`relative_position_encoding` | Relative positional encoding `Tensor` of shape `[B, L, dim]`. |
`segment_matrix` | Optional `Tensor` of shape `[B, S, S + M]`. Used in XLNet, but not in Transformer-XL. |
`segment_encoding` | Optional `Tensor` of shape `[2, num_heads, dim]`. Used in XLNet, but not in Transformer-XL. |
`segment_attention_bias` | Optional bias `Tensor` for segment-based attention of shape `[num_heads, dim]`. |
`state` | Optional `Tensor` of shape `[B, M, E]`, where M is the length of the state or memory. If passed, this is also attended over, as in Transformer-XL. |
`content_attention_mask` | Optional `Tensor` representing the mask added to the content attention logits. If `state` is not `None`, the mask's source sequence dimension should be extended by M to cover the memory. |
`query_stream` | Optional `Tensor`, the query stream. This is introduced in `TwoStreamRelativeAttention`/the XLNet pretrainer. Ignored if `two_stream` is `False`. |
`query_attention_mask` | Optional `Tensor` representing the mask added to the query attention logits. If `state` is not `None`, the mask's source sequence dimension should be extended by M to cover the memory. |
`target_mapping` | Optional `Tensor` representing the target mapping used when calculating query attention. |
Returns | |
---|---|
A `dict` containing the key-value pairs for `content_attention` and (if `two_stream` is `True`) `query_attention`. |
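For completeness, here is a hedged sketch of a single forward pass through the block constructed above. The shapes follow the argument table: `dim` is taken to be `head_size` for the bias terms and `hidden_size` for the relative position encoding, and the encoding length of `2 * seq_len` is an assumption for the no-memory (`state=None`) case; treat all of these as illustrative choices rather than documented requirements.

```python
import tensorflow as tf

batch_size, seq_len = 2, 8
hidden_size, num_heads, head_size = 64, 4, 16  # must match the block above

# Content stream: the standard Transformer-XL input ("h" in XLNet).
content_stream = tf.random.normal([batch_size, seq_len, hidden_size])

# Per-head biases for content- and position-based attention, shape
# [num_heads, dim]; random here purely for illustration (normally learned).
content_bias = tf.random.normal([num_heads, head_size])
positional_bias = tf.random.normal([num_heads, head_size])

# Relative position encoding of shape [B, L, dim]; L = 2 * seq_len is an
# assumption for the case with no memory/state.
rel_pos_enc = tf.random.normal([batch_size, 2 * seq_len, hidden_size])

outputs = block(
    content_stream=content_stream,
    content_attention_bias=content_bias,
    positional_attention_bias=positional_bias,
    relative_position_encoding=rel_pos_enc)

# With two_stream=False the returned dict only carries the content stream;
# its shape matches content_stream: (2, 8, 64).
print(outputs['content_attention'].shape)
```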