MultiHeadAttention layer.
tfm.nlp.layers.ReuseMultiHeadAttention(
    num_heads,
    key_dim,
    value_dim=None,
    dropout=0.0,
    reuse_attention=0,
    use_relative_pe=False,
    pe_max_seq_length=512,
    use_bias=True,
    output_shape=None,
    attention_axes=None,
    kernel_initializer='glorot_uniform',
    bias_initializer='zeros',
    kernel_regularizer=None,
    bias_regularizer=None,
    activity_regularizer=None,
    kernel_constraint=None,
    bias_constraint=None,
    **kwargs
)
This is an implementation of multi-headed attention as described in the paper
"Attention Is All You Need" (Vaswani et al., 2017). If query, key, and value
are the same, then this is self-attention. Each timestep in query attends to
the corresponding sequence in key, and returns a fixed-width vector.

This layer first projects query, key and value. These are (effectively) a list
of tensors of length num_attention_heads, where the corresponding shapes are
(batch_size, <query dimensions>, key_dim),
(batch_size, <key/value dimensions>, key_dim), and
(batch_size, <key/value dimensions>, value_dim).

Then, the query and key tensors are dot-producted and scaled. These are
softmaxed to obtain attention probabilities. The value tensors are then
interpolated by these probabilities, then concatenated back to a single tensor.

Finally, the result tensor, whose last dimension is value_dim, can take a
linear projection and be returned.
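The steps just described (per-head projection, scaled dot-product, softmax, interpolation of the values, concatenation of heads) can be sketched with plain TensorFlow ops. This is an illustration only, not the layer's implementation; it omits the learned projections' biases, dropout, and masking, and the tensor names are hypothetical.

import tensorflow as tf

batch, tgt_len, src_len = 2, 8, 4
num_heads, key_dim, value_dim = 2, 16, 16

# Pretend these are the already-projected per-head query/key/value tensors.
query = tf.random.normal([batch, tgt_len, num_heads, key_dim])
key = tf.random.normal([batch, src_len, num_heads, key_dim])
value = tf.random.normal([batch, src_len, num_heads, value_dim])

# Dot-product query against key per head, scaled by sqrt(key_dim).
scores = tf.einsum('bthk,bshk->bhts', query, key) / tf.sqrt(float(key_dim))

# Softmax over the source axis yields attention probabilities.
probs = tf.nn.softmax(scores, axis=-1)

# Interpolate the values by the probabilities, then merge the heads back into
# a single feature axis (the layer would apply an output projection here).
context = tf.einsum('bhts,bshv->bthv', probs, value)
context = tf.reshape(context, [batch, tgt_len, num_heads * value_dim])
print(context.shape)  # (2, 8, 32)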
Examples:
Performs 1D cross-attention over two sequence inputs (an attention mask may also be passed). Returns the additional attention weights over heads.
layer = ReuseMultiHeadAttention(num_heads=2, key_dim=2)
target = tf.keras.Input(shape=[8, 16])
source = tf.keras.Input(shape=[4, 16])
output_tensor, weights = layer(target, source,
                               return_attention_scores=True)
print(output_tensor.shape)
(None, 8, 16)
print(weights.shape)
(None, 2, 8, 4)
Performs 2D self-attention over a 5D input tensor on axes 2 and 3.
layer = ReuseMultiHeadAttention(num_heads=2, key_dim=2, attention_axes=(2, 3))
input_tensor = tf.keras.Input(shape=[5, 3, 4, 16])
output_tensor = layer(input_tensor, input_tensor)
print(output_tensor.shape)
(None, 5, 3, 4, 16)
Args | |
---|---|
num_heads | Number of attention heads. |
key_dim | Size of each attention head for query and key. |
value_dim | Size of each attention head for value. |
dropout | Dropout probability. |
reuse_attention | An integer specifying the number of heads to reuse; -1 reuses all heads. |
use_relative_pe | Whether to use relative position bias. |
pe_max_seq_length | Used to set the size of the relative position encodings. |
use_bias | Boolean, whether the dense layers use bias vectors/matrices. |
output_shape | The expected shape of an output tensor, besides the batch and sequence dims. If not specified, projects back to the query feature dim (the query input's last dimension). |
attention_axes | Axes over which the attention is applied. None means attention over all axes but batch, heads, and features. |
kernel_initializer | Initializer for dense layer kernels. |
bias_initializer | Initializer for dense layer biases. |
kernel_regularizer | Regularizer for dense layer kernels. |
bias_regularizer | Regularizer for dense layer biases. |
activity_regularizer | Regularizer for dense layer activity. |
kernel_constraint | Constraint for dense layer kernels. |
bias_constraint | Constraint for dense layer biases. |
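A minimal construction sketch using the reuse- and position-related arguments above. The argument names come from the signature at the top of this page; the values are illustrative only, and the output shape follows the output_shape rule (projection back to the query feature dim).

import tensorflow as tf
import tensorflow_models as tfm

# Enable relative position bias, with room for sequences up to 128 tokens.
layer = tfm.nlp.layers.ReuseMultiHeadAttention(
    num_heads=4,
    key_dim=32,
    dropout=0.1,
    use_relative_pe=True,
    pe_max_seq_length=128,
)

inputs = tf.keras.Input(shape=[64, 128])
outputs = layer(inputs, inputs)  # self-attention: query and value are the same
print(outputs.shape)  # (None, 64, 128)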
Call arguments:
query: Query Tensor of shape (B, T, dim).
value: Value Tensor of shape (B, S, dim).
key: Optional key Tensor of shape (B, S, dim). If not given, will use value for both key and value, which is the most common case.
attention_mask: A boolean mask of shape (B, T, S) that prevents attention to certain positions. The boolean mask specifies which query elements can attend to which key elements; 1 indicates attention and 0 indicates no attention. Broadcasting can happen for the missing batch dimensions and the head dimension. (A small example follows this list.)
return_attention_scores: A boolean to indicate whether the output should be (attention_output, attention_scores) if True, or attention_output alone if False. Defaults to False.
training: Python boolean indicating whether the layer should behave in training mode (adding dropout) or in inference mode (no dropout). Defaults to either using the training mode of the parent layer/model, or False (inference) if there is no parent layer.
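For example, a padding mask of shape (B, T, S) can be built from per-token validity of the source sequence and passed through attention_mask. This is a small sketch only; the shapes match Example 1 above and the layer is left in its default (no reuse) configuration.

import tensorflow as tf
import tensorflow_models as tfm

layer = tfm.nlp.layers.ReuseMultiHeadAttention(num_heads=2, key_dim=2)

target = tf.random.normal([3, 8, 16])  # (B, T, dim)
source = tf.random.normal([3, 4, 16])  # (B, S, dim)

# Mark the last source position of every example as padding (0 = no attention).
source_valid = tf.cast(tf.constant([[1, 1, 1, 0]] * 3), tf.bool)     # (B, S)
attention_mask = tf.tile(source_valid[:, tf.newaxis, :], [1, 8, 1])  # (B, T, S)

output = layer(target, source, attention_mask=attention_mask)
print(output.shape)  # (3, 8, 16)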
Returns | |
---|---|
attention_output | The result of the computation, of shape (B, T, E), where T is for target sequence shapes and E is the query input last dimension if output_shape is None. Otherwise, the multi-head outputs are projected to the shape specified by output_shape. |
attention_scores | [Optional] Multi-head attention coefficients over attention axes. |
Methods
call
call(
    query,
    value,
    key=None,
    attention_mask=None,
    return_attention_scores=False,
    training=None,
    reuse_attention_scores=None
)
This is where the layer's logic lives.

The call() method may not create state (except in its first invocation, wrapping the creation of variables or other resources in tf.init_scope()). It is recommended to create state in __init__(), or the build() method that is called automatically before call() executes the first time.
Args | |
---|---|
inputs | Input tensor, or dict/list/tuple of input tensors. The first positional inputs argument is subject to special rules. |
*args | Additional positional arguments. May contain tensors, although this is not recommended, for the reasons above. |
**kwargs | Additional keyword arguments. May contain tensors, although this is not recommended, for the reasons above. The following optional keyword arguments are reserved: training: Boolean scalar tensor or Python boolean indicating whether the call is meant for training or inference. mask: Boolean input mask. If the layer's call() method takes a mask argument, its default value will be set to the mask generated for inputs by the previous layer (if input did come from a layer that generated a corresponding mask, i.e. if it came from a Keras layer with masking support). |
Returns | |
---|---|
A tensor or list/tuple of tensors. |
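The reuse_attention constructor argument and the reuse_attention_scores call argument work together: a layer built with a nonzero reuse_attention consumes attention scores computed elsewhere instead of recomputing them for the reused heads. The sketch below is an assumption based on the signatures shown on this page (not verified against the implementation): scores returned by one layer with return_attention_scores=True are fed to a second layer that reuses all heads.

import tensorflow as tf
import tensorflow_models as tfm

# First layer computes attention normally and exposes its probabilities.
first = tfm.nlp.layers.ReuseMultiHeadAttention(num_heads=2, key_dim=2)
# Second layer is configured to reuse all heads (-1) instead of recomputing them.
second = tfm.nlp.layers.ReuseMultiHeadAttention(
    num_heads=2, key_dim=2, reuse_attention=-1)

x = tf.keras.Input(shape=[8, 16])

out1, scores = first(x, x, return_attention_scores=True)
# Pass the first layer's (B, num_heads, T, S) scores to the second layer.
out2 = second(out1, out1, reuse_attention_scores=scores)
print(out2.shape)  # (None, 8, 16)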