Implements Talking-Heads Attention.

Inherits From: MultiHeadAttention

This is an implementation of Talking-Heads Attention based on the paper "Talking-Heads Attention" (https://arxiv.org/abs/2003.02436). It enhances multi-head attention by inserting learned linear projections across the attention-heads dimension, immediately before and immediately after the softmax operation.

See the base class MultiHeadAttention for more details.
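The core idea can be sketched in a few lines of NumPy: ordinary multi-head attention applies the softmax to each head's logits independently, whereas talking-heads attention first mixes the logits across the heads axis with a learned linear map, applies the softmax, then mixes the resulting attention weights across heads again. The sketch below uses random matrices in place of the learned projection weights, purely to illustrate the tensor contractions involved; it is not the layer's actual implementation.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

# Toy dimensions, for illustration only.
batch, heads, seq = 2, 3, 4
logits = np.random.randn(batch, heads, seq, seq)  # per-head attention logits

# Talking-heads: learned linear maps across the heads axis.
# Random stand-ins here; in the layer these are trainable weights.
proj_pre = np.random.randn(heads, heads)   # applied before the softmax
proj_post = np.random.randn(heads, heads)  # applied after the softmax

# Mix logits across heads, normalize, then mix the weights across heads.
mixed_logits = np.einsum('bhqk,hH->bHqk', logits, proj_pre)
weights = softmax(mixed_logits, axis=-1)
mixed_weights = np.einsum('bhqk,hH->bHqk', weights, proj_post)
```

Note that after the post-softmax projection the rows of `mixed_weights` no longer sum to one; the paper applies the projection anyway and lets the value aggregation absorb the rescaling.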

num_heads Number of attention heads.
key_dim Size of each attention head for query and key.
value_dim Size of each attention head for value.
dropout Dropout probability.
use_bias Boolean, whether the dense layers use bias vectors/matrices.
output_shape The expected shape of an output tensor, besides the batch and sequence dims. If not specified, projects back to the key feature dim.
attention_axes Axes over which the attention is applied. None means attention over all axes except batch, heads, and features.
return_attention_scores bool, if True, returns the multi-head attention scores as an additional output argument.
kernel_initializer Initializer for dense layer kernels.
bias_initializer Initializer for dense layer biases.
kernel_regularizer Regularizer for dense layer kernels.
bias_regularizer Regularizer for dense layer biases.
activity_regularizer Regularizer for dense layer activity.
kernel_constraint Constraint for dense layer kernels.
bias_constraint Constraint for dense layer biases.
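Since the layer inherits its constructor signature from MultiHeadAttention, instantiation follows the base class exactly. The sketch below constructs the base class so it runs with stock TensorFlow; to use the talking-heads variant, substitute the layer from your installed Model Garden package (the exact import path depends on your package version).

```python
import tensorflow as tf

# Construct with the shared MultiHeadAttention arguments. Swap in
# TalkingHeadsAttention from the Model Garden to get the talking-heads
# projections; the constructor arguments are the same.
layer = tf.keras.layers.MultiHeadAttention(
    num_heads=4,
    key_dim=16,      # size of each head for query and key
    dropout=0.1,
    use_bias=True,
)

query = tf.random.normal((2, 8, 64))  # (batch, seq, features)
value = tf.random.normal((2, 8, 64))

# With no output_shape given, the output projects back to the
# query's feature dimension.
output = layer(query, value)
```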



The call() method is where the layer's logic lives.

Note that the call() method in tf.keras differs slightly from the standalone Keras API: in standalone Keras, masking support is passed to layers as additional arguments, whereas tf.keras layers support masking through the compute_mask() method.

inputs Input tensor, or list/tuple of input tensors.
*args Additional positional arguments. Currently unused.
**kwargs Additional keyword arguments. Currently unused.

A tensor or list/tuple of tensors.
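In practice the layer is called with separate query and value tensors rather than a single inputs argument, and the attention scores can be requested at call time. The sketch below again uses the MultiHeadAttention base class for runnability; the talking-heads variant is called the same way.

```python
import tensorflow as tf

layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=8)

query = tf.random.normal((1, 5, 32))  # query sequence length 5
value = tf.random.normal((1, 6, 32))  # key/value sequence length 6

# Requesting the scores makes the call return a (output, scores) pair;
# scores has shape (batch, num_heads, query_seq, key_seq).
output, scores = layer(query, value, return_attention_scores=True)
```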