
# tfnlp.layers.MultiHeadAttention

MultiHeadAttention layer.

This is an implementation of multi-headed attention based on "Attention Is All You Need". If `query`, `key`, and `value` are the same, then this is self-attention. Each timestep in `query` attends to the corresponding sequence in `key`, and returns a fixed-width vector.

This layer first projects `query`, `key`, and `value`. These are (effectively) a list of tensors of length `num_attention_heads`, where the corresponding shapes are `[batch_size, <query dimensions>, key_dim]`, `[batch_size, <key/value dimensions>, key_dim]`, and `[batch_size, <key/value dimensions>, value_dim]`.

Then, the query and key tensors are dot-producted and scaled. These are softmaxed to obtain attention probabilities. The value tensors are then interpolated by these probabilities, then concatenated back to a single tensor.

Finally, the result tensor, whose last dimension is `value_dim`, can be passed through a final linear projection and returned.
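
As an illustration of the computation described above, here is a minimal sketch written with plain `tf.einsum` calls. The shapes and random projection weights are hypothetical stand-ins for the layer's learned dense kernels, and masking, dropout, and custom attention axes are omitted:

```
import tensorflow as tf

batch, t, s, dim = 2, 8, 4, 16          # hypothetical sizes
num_heads, key_dim, value_dim = 2, 2, 3

query = tf.random.normal([batch, t, dim])
key = tf.random.normal([batch, s, dim])
value = tf.random.normal([batch, s, dim])

# Random stand-ins for the learned per-head projection kernels.
wq = tf.random.normal([dim, num_heads, key_dim])
wk = tf.random.normal([dim, num_heads, key_dim])
wv = tf.random.normal([dim, num_heads, value_dim])
wo = tf.random.normal([num_heads, value_dim, dim])

# Project query, key and value into per-head spaces: [batch, seq, heads, head_dim].
q = tf.einsum('btd,dhk->bthk', query, wq)
k = tf.einsum('bsd,dhk->bshk', key, wk)
v = tf.einsum('bsd,dhv->bshv', value, wv)

# Scaled dot product of query and key, softmaxed into attention probabilities.
scores = tf.einsum('bthk,bshk->bhts', q, k) / (key_dim ** 0.5)
probs = tf.nn.softmax(scores, axis=-1)          # [batch, heads, T, S]

# Interpolate the values by the probabilities and apply the output projection.
context = tf.einsum('bhts,bshv->bthv', probs, v)
output = tf.einsum('bthv,hvd->btd', context, wo)
print(output.shape)
# (2, 8, 16)
```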

#### Examples:

Performs 1D cross-attention over two sequence inputs with an attention mask. Returns the additional attention weights over heads.

```
layer = MultiHeadAttention(num_heads=2, key_dim=2)
target = tf.keras.Input(shape=[8, 16])
source = tf.keras.Input(shape=[4, 16])
output_tensor, weights = layer(target, source,
                               return_attention_scores=True)
print(output_tensor.shape)
# (None, 8, 16)
print(weights.shape)
# (None, 2, 8, 4)
```

Performs 2D self-attention over a 5D input tensor on axes 2 and 3.

```
layer = MultiHeadAttention(num_heads=2, key_dim=2, attention_axes=(2, 3))
input_tensor = tf.keras.Input(shape=[5, 3, 4, 16])
output_tensor = layer(input_tensor, input_tensor)
print(output_tensor.shape)
# (None, 5, 3, 4, 16)
```

#### Args:

• `num_heads`: Number of attention heads.
• `key_dim`: Size of each attention head for query and key.
• `value_dim`: Size of each attention head for value.
• `dropout`: Dropout probability.
• `use_bias`: Boolean, whether the dense layers use bias vectors/matrices.
• `output_shape`: The expected shape of an output tensor, besides the batch and sequence dims. If not specified, projects back to the query feature dim (the query input's last dimension).
• `attention_axes`: Axes over which the attention is applied. `None` means attention over all axes except batch, heads, and features.
• `kernel_initializer`: Initializer for dense layer kernels.
• `bias_initializer`: Initializer for dense layer biases.
• `kernel_regularizer`: Regularizer for dense layer kernels.
• `bias_regularizer`: Regularizer for dense layer biases.
• `activity_regularizer`: Regularizer for dense layer activity.
• `kernel_constraint`: Constraint for dense layer kernels.
• `bias_constraint`: Constraint for dense layer biases.
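
To show how several of these constructor arguments fit together, here is a hedged sketch; the argument values are arbitrary illustrations rather than recommendations, and `MultiHeadAttention` is assumed to be imported from this module as in the examples above:

```
import tensorflow as tf

# Hypothetical configuration: 4 heads of size 32, dropout on the attention
# probabilities, and an output projection back to 128 features.
layer = MultiHeadAttention(
    num_heads=4,
    key_dim=32,
    value_dim=32,
    dropout=0.1,
    use_bias=True,
    output_shape=128,
    kernel_initializer='glorot_uniform',
    bias_initializer='zeros')

query = tf.keras.Input(shape=[10, 128])
value = tf.keras.Input(shape=[20, 128])
output = layer(query, value)
print(output.shape)
# (None, 10, 128)
```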

#### Call arguments:

• `query`: Query `Tensor` of shape `[B, T, dim]`.
• `value`: Value `Tensor` of shape `[B, S, dim]`.
• `key`: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will use `value` for both `key` and `value`, which is the most common case.
• `attention_mask`: a boolean mask of shape `[B, T, S]`, that prevents attention to certain positions. The boolean mask specifies which query elements can attend to which key elements, 1 indicates attention and 0 indicates no attention. Broadcasting can happen for the missing batch dimensions and the head dimension.
• `return_attention_scores`: A boolean to indicate whether the output should be `(attention_output, attention_scores)` if `True`, or `attention_output` only if `False`. Defaults to `False`.
• `training`: Python boolean indicating whether the layer should behave in training mode (adding dropout) or in inference mode (no dropout). Defaults to either using the training mode of the parent layer/model, or False (inference) if there is no parent layer.
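
For instance, here is a hedged sketch of calling the layer with an explicit `attention_mask` and `return_attention_scores=True`. The mask values are arbitrary, standing in for a real padding or causal mask, and `MultiHeadAttention` is again assumed to be imported as in the examples above:

```
import tensorflow as tf

layer = MultiHeadAttention(num_heads=2, key_dim=2)

batch, t, s, dim = 2, 8, 4, 16
query = tf.random.normal([batch, t, dim])
value = tf.random.normal([batch, s, dim])

# Boolean mask of shape [B, T, S]: True where a query position may attend to a
# key/value position. Here the last two source positions are masked out.
attention_mask = tf.concat(
    [tf.ones([batch, t, 2], dtype=tf.bool),
     tf.zeros([batch, t, s - 2], dtype=tf.bool)], axis=-1)

output, scores = layer(query, value,
                       attention_mask=attention_mask,
                       return_attention_scores=True)
print(output.shape)
# (2, 8, 16)
print(scores.shape)
# (2, 2, 8, 4)
```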

#### Returns:

• `attention_output`: The result of the computation, of shape `[B, T, E]`, where `T` is the target sequence length and `E` is the query input's last dimension if `output_shape` is `None`. Otherwise, the multi-head outputs are projected to the shape specified by `output_shape`.
• `attention_scores`: [Optional] Multi-head attention coefficients over attention axes.

## Methods

### `call`

This is where the layer's logic lives.

Note that the `call()` method in `tf.keras` differs slightly from the `keras` API. In the `keras` API you can pass masking support to layers as additional arguments, whereas `tf.keras` provides a `compute_mask()` method to support masking.

Args

• `inputs`: Input tensor, or list/tuple of input tensors.
• `*args`: Additional positional arguments. Currently unused.
• `**kwargs`: Additional keyword arguments. Currently unused.

Returns
A tensor or list/tuple of tensors.

[{ "type": "thumb-down", "id": "missingTheInformationINeed", "label":"Missing the information I need" },{ "type": "thumb-down", "id": "tooComplicatedTooManySteps", "label":"Too complicated / too many steps" },{ "type": "thumb-down", "id": "outOfDate", "label":"Out of date" },{ "type": "thumb-down", "id": "samplesCodeIssue", "label":"Samples / code issue" },{ "type": "thumb-down", "id": "otherDown", "label":"Other" }]
[{ "type": "thumb-up", "id": "easyToUnderstand", "label":"Easy to understand" },{ "type": "thumb-up", "id": "solvedMyProblem", "label":"Solved my problem" },{ "type": "thumb-up", "id": "otherUp", "label":"Other" }]