

Additive attention layer, a.k.a. Bahdanau-style attention.

Inherits From: Layer, Module

Inputs are a query tensor of shape [batch_size, Tq, dim], a value tensor of shape [batch_size, Tv, dim], and a key tensor of shape [batch_size, Tv, dim]. The calculation follows these steps:

  1. Reshape query and key into shapes [batch_size, Tq, 1, dim] and [batch_size, 1, Tv, dim] respectively.
  2. Calculate scores with shape [batch_size, Tq, Tv] as a non-linear sum: scores = tf.reduce_sum(tf.tanh(query + key), axis=-1)
  3. Use scores to calculate a distribution with shape [batch_size, Tq, Tv]: distribution = tf.nn.softmax(scores).
  4. Use distribution to create a linear combination of value with shape [batch_size, Tq, dim]: return tf.matmul(distribution, value).
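
The four steps above can be sketched in plain NumPy, as a rough illustration rather than the layer's actual implementation. The function name additive_attention and the random inputs are hypothetical; the sketch ignores use_scale, masking, and dropout.

```python
import numpy as np

def additive_attention(query, value, key=None):
    """NumPy sketch of the four steps above; not the Keras implementation."""
    if key is None:
        key = value  # when no key is given, value serves as both key and value

    # Step 1: reshape so that query and key broadcast against each other.
    q = query[:, :, np.newaxis, :]  # [batch_size, Tq, 1, dim]
    k = key[:, np.newaxis, :, :]    # [batch_size, 1, Tv, dim]

    # Step 2: non-linear sum over the feature axis.
    scores = np.sum(np.tanh(q + k), axis=-1)  # [batch_size, Tq, Tv]

    # Step 3: softmax over the Tv axis (max subtracted for numerical stability).
    shifted = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(shifted)
    distribution = weights / weights.sum(axis=-1, keepdims=True)

    # Step 4: linear combination of value.
    output = distribution @ value  # [batch_size, Tq, dim]
    return output, distribution

# Hypothetical inputs: batch_size=2, Tq=3, Tv=5, dim=4.
rng = np.random.default_rng(0)
query = rng.normal(size=(2, 3, 4))
value = rng.normal(size=(2, 5, 4))

output, distribution = additive_attention(query, value)
print(output.shape, distribution.shape)  # (2, 3, 4) (2, 3, 5)
```

Each row of distribution sums to 1 over the Tv axis, so every output position is a weighted average of the value vectors.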

Args:

  • use_scale: If True, will create a variable to scale the attention scores.
  • causal: Boolean. Set to True for decoder self-attention. Adds a mask such that position i cannot attend to positions j > i. This prevents the flow of information from the future towards the past.
  • dropout: Float between 0 and 1. Fraction of the units to drop for the attention scores.
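
The effect of the causal option can be illustrated with a small NumPy sketch. The helper causal_softmax is hypothetical, and using -1e9 as a stand-in for negative infinity is an assumption of this example, not a statement about how the layer builds its mask internally.

```python
import numpy as np

def causal_softmax(scores):
    """Softmax over Tv with positions j > i masked out (illustrative helper)."""
    tq, tv = scores.shape[-2:]
    # Lower-triangular mask: entry [i, j] is True when j <= i.
    allowed = np.tril(np.ones((tq, tv), dtype=bool))
    # Assumption: a large negative score stands in for -inf, so masked
    # positions receive (numerically) zero weight after the softmax.
    masked = np.where(allowed, scores, -1e9)
    masked = masked - masked.max(axis=-1, keepdims=True)
    weights = np.exp(masked)
    return weights / weights.sum(axis=-1, keepdims=True)

# With uniform scores, position i spreads its weight evenly over positions 0..i.
scores = np.zeros((1, 3, 3))
causal_distribution = causal_softmax(scores)
print(np.round(causal_distribution[0], 3))
# [[1.    0.    0.   ]
#  [0.5   0.5   0.   ]
#  [0.333 0.333 0.333]]
```

The upper triangle of the distribution is zero, which is what keeps information from later positions out of earlier ones.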

Call Args:

  • inputs: List of the following tensors:
    • query: Query Tensor of shape [batch_size, Tq, dim].
    • value: Value Tensor of shape [batch_size, Tv, dim].
    • key: Optional key Tensor of shape [batch_size, Tv, dim]. If not given, will use value for both key and value, which is the most common case.
  • mask: List of the following tensors:
    • query_mask: A boolean mask Tensor of shape [batch_size, Tq]. If given, the output will be zero at the positions where mask==False.
    • value_mask: A boolean mask Tensor of shape [batch_size, Tv]. If given, will apply the mask such that values at positions where mask==False do not contribute to the result.
  • training: Python boolean indicating whether the layer should behave in training mode (adding dropout) or in inference mode (no dropout).
  • return_attention_scores: bool, if True, returns the attention scores (after masking and softmax) as an additional output argument.
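
As a rough illustration of the mask semantics described above, the following NumPy sketch applies a value_mask before the softmax and a query_mask to the output. The helper apply_masks and the toy inputs are hypothetical, and the -1e9 fill value is an assumption of this example; the layer's internals may differ.

```python
import numpy as np

def apply_masks(scores, value, query_mask=None, value_mask=None):
    """Illustrative masking around the softmax; not the Keras implementation."""
    if value_mask is not None:
        # Broadcast [batch_size, Tv] to [batch_size, 1, Tv]; masked values
        # get a very low score so they do not contribute to the result.
        scores = np.where(value_mask[:, np.newaxis, :], scores, -1e9)
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    distribution = weights / weights.sum(axis=-1, keepdims=True)
    output = distribution @ value  # [batch_size, Tq, dim]
    if query_mask is not None:
        # Zero the output rows where query_mask is False.
        output = output * query_mask[:, :, np.newaxis]
    return output, distribution

# Toy inputs: batch_size=1, Tq=2, Tv=3, dim=2.
scores = np.zeros((1, 2, 3))
value = np.arange(6, dtype=float).reshape(1, 3, 2)  # rows [0,1], [2,3], [4,5]
value_mask = np.array([[True, True, False]])        # third value is ignored
query_mask = np.array([[True, False]])              # second query is zeroed

masked_output, masked_distribution = apply_masks(scores, value, query_mask, value_mask)
print(masked_output[0])
# [[1. 2.]
#  [0. 0.]]
```

Returning the distribution alongside the output mirrors what return_attention_scores=True is described as doing: the caller receives the post-masking, post-softmax weights as a second result.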