

Dot-product attention layer, a.k.a. Luong-style attention.

Inherits From: Layer, Module

Inputs are a query tensor of shape [batch_size, Tq, dim], a value tensor of shape [batch_size, Tv, dim], and a key tensor of shape [batch_size, Tv, dim]. The calculation follows these steps:

  1. Calculate scores with shape [batch_size, Tq, Tv] as a query-key dot product: scores = tf.matmul(query, key, transpose_b=True).
  2. Use scores to calculate a distribution with shape [batch_size, Tq, Tv], taking the softmax over the last (Tv) axis: distribution = tf.nn.softmax(scores).
  3. Use distribution to create a linear combination of value with shape [batch_size, Tq, dim]: return tf.matmul(distribution, value).
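The three steps above can be sketched in plain NumPy as follows. This is a reference re-derivation for checking shapes and arithmetic, not the layer's own implementation; the function name `dot_product_attention` is hypothetical, and the key-defaults-to-value behaviour mirrors the call argument described below.

```python
import numpy as np

def dot_product_attention(query, value, key=None):
    """NumPy sketch of the three steps (not the Keras implementation).

    query: [batch_size, Tq, dim]; value and key: [batch_size, Tv, dim].
    Returns (output [batch_size, Tq, dim], distribution [batch_size, Tq, Tv]).
    """
    if key is None:
        key = value  # the common case: value doubles as key
    # Step 1: scores[b, i, j] = <query[b, i], key[b, j]>, shape [batch_size, Tq, Tv]
    scores = np.matmul(query, np.swapaxes(key, -1, -2))
    # Step 2: softmax over the Tv axis, shifted by the row max for numerical stability
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    distribution = weights / weights.sum(axis=-1, keepdims=True)
    # Step 3: each output row is a weighted sum of value vectors
    return np.matmul(distribution, value), distribution

rng = np.random.default_rng(0)
q = rng.normal(size=(2, 3, 4))  # batch_size=2, Tq=3, dim=4
v = rng.normal(size=(2, 5, 4))  # Tv=5
out, dist = dot_product_attention(q, v)
print(out.shape, dist.shape)  # (2, 3, 4) (2, 3, 5)
```

Each row of `dist` sums to 1, so every output vector is a convex combination of the Tv value vectors.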

Args:

  • use_scale: If True, will create a scalar variable to scale the attention scores.
  • causal: Boolean. Set to True for decoder self-attention. Adds a mask such that position i cannot attend to positions j > i. This prevents the flow of information from the future towards the past.
  • dropout: Float between 0 and 1. Fraction of the units to drop for the attention scores.
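The effect of these three arguments can be sketched in NumPy, assuming the scale multiplies the raw scores, the causal mask replaces future scores with a large negative number before the softmax, and dropout is the usual inverted dropout applied to the distribution only in training mode. The helper names and the `-1e9` constant are illustrative choices, not the layer's API; the real layer creates the scale as a variable rather than taking a fixed float.

```python
import numpy as np

def attention_scores(query, key, scale=1.0, causal=False):
    """Step 1 plus the optional scale and causal mask (sketch, not the Keras code)."""
    # With use_scale=True the layer holds this scalar as a variable; here it is a plain float
    scores = scale * np.matmul(query, np.swapaxes(key, -1, -2))
    if causal:
        tq, tv = scores.shape[-2:]
        # future[i, j] is True when j > i, i.e. key position j lies after query position i
        future = np.triu(np.ones((tq, tv), dtype=bool), k=1)
        # A large negative score drives the softmax weight of future positions to 0
        scores = np.where(future, -1e9, scores)
    return scores

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def attention_dropout(distribution, rate, training, rng):
    """Inverted dropout on the attention distribution; a no-op at inference."""
    if not training or rate == 0.0:
        return distribution
    keep = rng.random(distribution.shape) >= rate
    # Rescale kept weights so their expected value is unchanged
    return distribution * keep / (1.0 - rate)

rng = np.random.default_rng(0)
x = rng.normal(size=(1, 4, 8))  # decoder self-attention: query = key = value
dist = softmax(attention_scores(x, x, scale=0.5, causal=True))
dist = attention_dropout(dist, rate=0.1, training=False, rng=rng)
out = np.matmul(dist, x)  # shape (1, 4, 8)
```

With `causal=True` the upper triangle of each [Tq, Tv] distribution is zero, so output step i depends only on positions 0..i.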

Call Args:

  • inputs: List of the following tensors:
    • query: Query Tensor of shape [batch_size, Tq, dim].
    • value: Value Tensor of shape [batch_size, Tv, dim].
    • key: Optional key Tensor of shape [batch_size, Tv, dim]. If not given, will use value for both key and value, which is the most common case.
  • mask: List of the following tensors:
    • query_mask: A boolean mask Tensor of shape [batch_size, Tq]. If given, the output will be zero at the positions where mask==False.
    • value_mask: A boolean mask Tensor of shape [batch_size, Tv]. If given, will apply the mask such that values at positions where mask==False do not contribute to the result.
  • return_attention_scores: bool. If True, returns the attention scores (after masking and softmax) as an additional output argument.
  • training: Python boolean indicating whether the layer should behave in training mode (adding dropout) or in inference mode (no dropout).
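The two masks act at different points, which a small NumPy sketch makes concrete. It assumes value_mask is applied by giving masked key positions a large negative score before the softmax, and query_mask by zeroing the corresponding output rows afterwards; the function name `masked_attention` is hypothetical. The returned `distribution` plays the role of the optional attention scores (after masking and softmax).

```python
import numpy as np

def masked_attention(query, value, query_mask=None, value_mask=None):
    """Sketch of value_mask / query_mask semantics (key defaults to value)."""
    scores = np.matmul(query, np.swapaxes(value, -1, -2))  # [batch_size, Tq, Tv]
    if value_mask is not None:
        # [batch_size, Tv] -> [batch_size, 1, Tv]: the same key mask for every query step
        scores = np.where(value_mask[:, None, :], scores, -1e9)
    e = np.exp(scores - scores.max(axis=-1, keepdims=True))
    distribution = e / e.sum(axis=-1, keepdims=True)
    output = np.matmul(distribution, value)  # [batch_size, Tq, dim]
    if query_mask is not None:
        # [batch_size, Tq] -> [batch_size, Tq, 1]: zero output rows where the mask is False
        output = output * query_mask[:, :, None]
    return output, distribution

rng = np.random.default_rng(1)
q = rng.normal(size=(1, 2, 3))
v = rng.normal(size=(1, 4, 3))
query_mask = np.array([[True, False]])               # second query step is padding
value_mask = np.array([[True, True, False, False]])  # last two value steps are padding
out, dist = masked_attention(q, v, query_mask, value_mask)
```

Masked value positions receive zero weight, so they cannot contribute to any output row, while a masked query position still gets a distribution but its output row is zero.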


Attention outputs of shape [batch_size, Tq, dim]. [Optional] Attention scores after masking and softmax with shape [batch_size, Tq, Tv].

The meaning of query, value, and key depends on the application. In the case of text similarity, for example, query is the sequence embeddings of the first piece of text and value is the sequence embeddings of the second piece of text.
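For the text-similarity case, a rough NumPy sketch (unbatched for brevity) shows what the attention output represents: for each token of the first text, a weighted summary of the second text's tokens. The random embeddings, the mean pooling, and the cosine comparison are hypothetical illustration choices, not part of the layer.

```python
import numpy as np

def attend(query, value):
    """Unbatched dot-product attention with key = value: [Tq, dim] x [Tv, dim] -> [Tq, dim]."""
    scores = query @ value.T  # [Tq, Tv]
    e = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return (e / e.sum(axis=-1, keepdims=True)) @ value

rng = np.random.default_rng(42)
text_a = rng.normal(size=(6, 16))  # hypothetical embeddings: 6 tokens of the first text
text_b = rng.normal(size=(9, 16))  # hypothetical embeddings: 9 tokens of the second text

# For each token of text_a, a summary of the text_b tokens it aligns with
aligned = attend(text_a, text_b)  # [6, 16]

# Hypothetical downstream step: pool over tokens, then compare with cosine similarity
a_vec = text_a.mean(axis=0)
aligned_vec = aligned.mean(axis=0)
similarity = a_vec @ aligned_vec / (np.linalg.norm(a_vec) * np.linalg.norm(aligned_vec))
```

In a trained model the embeddings would come from an encoder, and the pooled vectors would typically feed further layers rather than a fixed cosine score.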