A variant of efficient transformers which replaces softmax with kernels.

This module combines ideas from the following papers:

Rethinking Attention with Performers

  • exp (Lemma 1, positive), relu
  • random/deterministic projection

Chefs' Random Tables: Non-Trigonometric Random Features
  • expplus (OPRF mechanism)

Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention

  • elu

with the theory of approximating angular Performer kernels from go/performer.

The module computes efficient attention in both the long-sequence and the short-sequence regimes. In the former setting, the attention matrix is never explicitly computed; instead, its low-rank decomposition, obtained with the given kernel feature maps, is used to carry out the attention computation. In the latter setting, the attention matrix is constructed, but kernel features providing dimensionality reduction are applied, making the computation of the attention matrix more efficient.
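The two regimes differ only in the order of matrix products. The NumPy sketch below assumes the "elu" feature map phi(x) = elu(x) + 1 and no random projection; `quadratic_attention` and `linear_attention` are illustrative names, not part of the layer's API. The quadratic form materializes the [T, S] matrix, while the linear form first reduces keys and values to a [features, dim] summary, so its cost grows linearly in T and S.

```python
import numpy as np

def elu_plus_one(x):
    # phi(x) = elu(x) + 1: x + 1 for x > 0, exp(x) otherwise (always positive).
    return np.where(x > 0, x + 1.0, np.exp(np.minimum(x, 0.0)))

def quadratic_attention(q, k, v, phi):
    # Short-sequence style: build the full [T, S] matrix phi(Q) phi(K)^T.
    a = phi(q) @ phi(k).T
    a = a / a.sum(axis=-1, keepdims=True)  # row-normalize
    return a @ v

def linear_attention(q, k, v, phi):
    # Long-sequence style: never form the [T, S] matrix.
    qf, kf = phi(q), phi(k)
    kv = kf.T @ v                       # [F, dim_v] summary of keys and values
    normalizer = qf @ kf.sum(axis=0)    # [T] row sums of the implicit matrix
    return (qf @ kv) / normalizer[:, None]

rng = np.random.default_rng(0)
q = rng.normal(size=(6, 4))
k = rng.normal(size=(8, 4))
v = rng.normal(size=(8, 3))
out_quad = quadratic_attention(q, k, v, elu_plus_one)
out_lin = linear_attention(q, k, v, elu_plus_one)
print(np.allclose(out_quad, out_lin))  # True: same result, different cost
```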

feature_transform A non-linear transform of the keys and queries. Possible transforms are "elu", "relu", "square", "exp", "expplus", "expmod", "identity".
num_random_features Number of random features to be used for projection. If num_random_features <= 0, no projection is applied before the transform.
seed The seed from which random features are drawn. Once the seed is set, pseudo-random number generation is deterministic. Users should pass a different seed for each layer. In multi-worker training, each layer uses the same projection at each step.
redraw Whether to redraw the projection on every forward pass during training. The argument is only effective when num_random_features > 0.
is_short_seq boolean predicate indicating whether input data consists of very short sequences or not; in most cases this should be False (default option).
begin_kernel Apply kernel_attention after this sequence id and apply softmax attention before this.
scale The value to scale the dot product as described in Attention Is All You Need. If None, we use 1/sqrt(dk) as described in the paper.
scale_by_length boolean predicate indicating whether to additionally scale the dot product based on key length. The extra factor is log_512(n), where n is the key length, which stabilizes attention entropy across lengths.
use_causal_windowed If True, perform windowed causal attention. See the causal_windowed_performer_attention function docstring for more details.
causal_chunk_length Length of each chunk in tokens.
causal_window_length Length of attention window in chunks.
causal_window_decay Float window decay factor or None. If set, exponentially decay past attention window values by this factor before summation.
causal_padding Pad the query, key and value input tensors along the sequence axis, on the left or on the right, when set to "left" or "right"; apply no padding when set to None. In the latter case, the sequence length of the query, key and value input tensors must be divisible by causal_chunk_length.
**kwargs The same arguments as the MultiHeadAttention layer.
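The sketch below illustrates how feature_transform, num_random_features and seed interact. It is a simplified NumPy version, not the layer's code: `draw_projection`, `feature_map` and the 1/sqrt(m) scaling are assumptions for illustration, and the library may apply extra normalization. For "exp" it uses positive random features exp(w . x - |x|^2 / 2), whose inner products estimate the softmax kernel exp(x . y); a fixed seed yields the same projection on every call.

```python
import numpy as np

def draw_projection(num_random_features, dim, seed):
    # Gaussian projection matrix [m, dim]; the same seed always gives the same matrix.
    rng = np.random.default_rng(seed)
    return rng.normal(size=(num_random_features, dim))

# f is applied to the (optionally projected) input.
_F = {
    "elu": lambda t: np.where(t > 0, t + 1.0, np.exp(np.minimum(t, 0.0))),  # elu(t) + 1
    "relu": lambda t: np.maximum(t, 0.0),
    "square": np.square,
    "exp": np.exp,
    "identity": lambda t: t,
}

def feature_map(x, transform, projection=None):
    """x: [N, dim] -> features [N, m], or [N, dim] when projection is None."""
    z = x if projection is None else x @ projection.T  # no projection if m <= 0
    feats = _F[transform](z)
    if transform == "exp":
        # Positive random features: exp(w . x - |x|^2 / 2), so that
        # E[phi(x) . phi(y)] = exp(x . y) for Gaussian w.
        feats = feats * np.exp(-0.5 * np.sum(x**2, axis=-1, keepdims=True))
    if projection is not None:
        feats = feats / np.sqrt(projection.shape[0])  # average over m features
    return feats

# Monte Carlo check of the "exp" kernel against exp(x . y).
x = np.array([[0.1, 0.2, -0.1, 0.05]])
y = np.array([[0.2, -0.1, 0.1, 0.0]])
P = draw_projection(100_000, 4, seed=0)
estimate = (feature_map(x, "exp", P) @ feature_map(y, "exp", P).T).item()
print(estimate, np.exp(x @ y.T).item())  # the two values should be close
```

With redraw=True the layer would draw a fresh matrix during training; in this sketch that corresponds to calling draw_projection again with a new seed.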

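The scale_by_length factor can be checked with a few lines of arithmetic. This sketch assumes the factor multiplies the default 1/sqrt(dk) scale and that n is the (unmasked) key length; `length_scale` is a hypothetical helper, not part of the library.

```python
import math

def length_scale(base_scale, key_length):
    # Multiply the base scale by log_512(n): exactly 1 at n = 512,
    # growing slowly (logarithmically) for longer keys.
    return base_scale * math.log(key_length) / math.log(512)

dk = 64
base = 1.0 / math.sqrt(dk)         # default scale 1/sqrt(dk) = 0.125
print(length_scale(base, 512))     # 0.125
print(length_scale(base, 1024))    # 0.125 * 10 / 9, about 0.1389
```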

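The chunking parameters can be pictured as an attention pattern. The sketch below assumes that tokens attend bidirectionally within their own chunk and to the causal_window_length - 1 preceding chunks, and that padding adds zeros along the sequence axis; `pad_to_chunks` and `windowed_chunk_mask` are hypothetical helpers. The layer itself never builds this explicit mask; the mask only shows which pairs interact. causal_window_decay is not modeled here.

```python
import numpy as np

def pad_to_chunks(x, chunk_length, padding=None):
    """x: [T, dim]. Zero-pads the sequence axis up to a multiple of chunk_length."""
    remainder = x.shape[0] % chunk_length
    if remainder == 0:
        return x
    if padding not in ("left", "right"):
        raise ValueError(
            "sequence length must be divisible by chunk_length "
            "unless padding is 'left' or 'right'")
    pad = chunk_length - remainder
    widths = ((pad, 0), (0, 0)) if padding == "left" else ((0, pad), (0, 0))
    return np.pad(x, widths)

def windowed_chunk_mask(seq_len, chunk_length, window_length):
    # 1 where query token i may attend to key token j: j lies in i's chunk
    # (bidirectional) or in one of the window_length - 1 preceding chunks.
    chunk = np.arange(seq_len) // chunk_length
    diff = chunk[:, None] - chunk[None, :]
    return ((diff >= 0) & (diff < window_length)).astype(int)

x = np.ones((7, 2))
x = pad_to_chunks(x, chunk_length=3, padding="right")  # shape (9, 2)
print(windowed_chunk_mask(x.shape[0], chunk_length=3, window_length=2))
```

For T = 9, a chunk length of 3 and a window length of 2, the first chunk sees only itself, the second chunk sees chunks 0 and 1, and the third chunk sees chunks 1 and 2.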


Compute attention with kernel mechanism.

query Query Tensor of shape [B, T, dim].
value Value Tensor of shape [B, S, dim].
key Optional key Tensor of shape [B, S, dim]. If not given, will use value for both key and value, which is the most common case.
attention_mask A boolean mask of shape [B, S] that prevents attending to masked positions. Note that the mask is only applied to the keys. Users may want to mask the output if the query contains padding.
cache Cache to accumulate history in memory. Used at inference time (streaming, decoding) for causal attention.
training Python boolean indicating whether the layer should behave in training mode (adding dropout) or in inference mode (doing nothing).
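Because attention_mask acts only on the keys, padded query positions still produce outputs. The single-head NumPy sketch below shows this under the assumption that masking zeroes the feature vectors of masked keys; `masked_linear_attention` is a hypothetical helper, and the layer may apply the mask differently internally.

```python
import numpy as np

def phi(x):
    # "elu" feature map: elu(x) + 1 (strictly positive).
    return np.where(x > 0, x + 1.0, np.exp(np.minimum(x, 0.0)))

def masked_linear_attention(q, k, v, key_mask):
    """q: [T, d], k: [S, d], v: [S, dv], key_mask: [S] bool (True = keep)."""
    kf = phi(k) * key_mask[:, None]  # masked keys contribute zero features
    qf = phi(q)
    return (qf @ (kf.T @ v)) / (qf @ kf.sum(axis=0))[:, None]

rng = np.random.default_rng(0)
q = rng.normal(size=(4, 3))
k = rng.normal(size=(5, 3))
v = rng.normal(size=(5, 2))
key_mask = np.array([True, True, True, False, False])
query_mask = np.array([True, True, True, False])  # last query is padding

out = masked_linear_attention(q, k, v, key_mask)
ref = masked_linear_attention(q, k[:3], v[:3], np.ones(3, dtype=bool))
print(np.allclose(out, ref))  # True: masked keys are ignored
# The padded query row still gets a (meaningless) output; zero it explicitly.
out = out * query_mask[:, None]
```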

Multi-headed outputs of attention computation.
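For causal attention, the cache used at inference time can be understood as running sums over past keys and values, which makes kernel attention behave like an RNN. In the sketch below, the cache layout (a dict holding sum phi(k) v^T and sum phi(k)) and the helpers `causal_linear_attention` and `decode_step` are hypothetical, chosen to show that token-by-token decoding reproduces the full-sequence causal result.

```python
import numpy as np

def phi(x):
    # "elu" feature map: elu(x) + 1 (strictly positive).
    return np.where(x > 0, x + 1.0, np.exp(np.minimum(x, 0.0)))

def causal_linear_attention(q, k, v):
    # Full-sequence causal attention: position t sees keys 0..t via prefix sums.
    qf, kf = phi(q), phi(k)
    kv = np.cumsum(kf[:, :, None] * v[:, None, :], axis=0)  # [T, F, d]
    z = np.cumsum(kf, axis=0)                               # [T, F]
    num = np.einsum("tf,tfd->td", qf, kv)
    den = np.einsum("tf,tf->t", qf, z)
    return num / den[:, None]

def decode_step(q_t, k_t, v_t, cache):
    # Hypothetical cache: update running sums, then attend with one query.
    kf = phi(k_t)
    cache["kv"] = cache["kv"] + np.outer(kf, v_t)  # [F, d]
    cache["z"] = cache["z"] + kf                   # [F]
    qf = phi(q_t)
    return (qf @ cache["kv"]) / (qf @ cache["z"]), cache

rng = np.random.default_rng(0)
T, F, d = 5, 3, 2
q, k, v = rng.normal(size=(T, F)), rng.normal(size=(T, F)), rng.normal(size=(T, d))
cache = {"kv": np.zeros((F, d)), "z": np.zeros(F)}
steps = []
for t in range(T):
    out_t, cache = decode_step(q[t], k[t], v[t], cache)
    steps.append(out_t)
print(np.allclose(np.stack(steps), causal_linear_attention(q, k, v)))  # True
```

The cache keeps a constant-size state per head regardless of how many tokens have been decoded, which is what makes streaming inference cheap.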