A variant of efficient transformers which replaces softmax with kernels.

This module combines ideas from the following two papers:

Rethinking Attention with Performers

  • exp (positive random features, Lemma 1), relu
  • random/deterministic projection

Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention

  • elu

with the theory of approximating angular Performer kernels from go/performer.
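The exp entry above refers to Lemma 1 of the Performer paper. That lemma writes the softmax kernel exp(x·y) as an expectation, over Gaussian directions ω ~ N(0, I), of a product of strictly positive features exp(ω·x − |x|²/2). The following is a minimal pure-Python Monte Carlo check of that identity. It assumes single unbatched vectors. The helper names draw_omegas and positive_random_features are hypothetical, and this is not the layer's implementation.

```python
import math
import random


def draw_omegas(num_features, dim, seed):
    """Draws num_features Gaussian directions omega ~ N(0, I) in R^dim."""
    rng = random.Random(seed)  # fixed seed -> deterministic projection
    return [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(num_features)]


def positive_random_features(x, omegas):
    """phi(x)_i = exp(omega_i . x - |x|^2 / 2) / sqrt(m); always > 0."""
    m = len(omegas)
    sq_norm = sum(t * t for t in x)
    return [
        math.exp(sum(w * t for w, t in zip(omega, x)) - sq_norm / 2) / math.sqrt(m)
        for omega in omegas
    ]


x = [0.3, -0.2]
y = [0.1, 0.4]
omegas = draw_omegas(num_features=20000, dim=2, seed=0)

# phi(x) . phi(y) is a Monte Carlo estimate of the softmax kernel exp(x . y).
approx = sum(a * b for a, b in zip(positive_random_features(x, omegas),
                                   positive_random_features(y, omegas)))
exact = math.exp(sum(a * b for a, b in zip(x, y)))
print(f"exact={exact:.4f} approx={approx:.4f}")
```

Every feature is positive, so the estimated kernel value can never be negative. Estimators built from sin/cos features do not have this guarantee.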

The module supports efficient attention in both the long-sequence and the short-sequence regimes. In the long-sequence setting, the attention matrix is never computed explicitly. Instead, the module uses its low-rank decomposition, obtained from the given kernel feature maps, to carry out the attention computation. In the short-sequence setting, the attention matrix is constructed, but kernel features that provide dimensionality reduction are still applied, which makes computing the attention matrix more efficient.
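The two regimes give the same result because matrix multiplication is associative. With a feature map φ, (φ(Q)φ(K)ᵀ)V equals φ(Q)(φ(K)ᵀV), and only the first form materializes the T×S attention matrix. The following sketch computes the same normalized output both ways. It assumes a single head, no batch dimension, and the elu(x) + 1 feature map from the "Transformers are RNNs" paper. All function names are hypothetical.

```python
import math


def elu_plus_one(t):
    # elu(t) + 1: t + 1 for t > 0, exp(t) otherwise; always positive.
    return t + 1.0 if t > 0 else math.exp(t)


def feature_map(rows):
    return [[elu_plus_one(t) for t in row] for row in rows]


def matmul(a, b):
    return [[sum(a[i][p] * b[p][j] for p in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a):
    return [list(col) for col in zip(*a)]


def explicit_attention(q, k, v):
    """Short-sequence path: materialize the [T, S] attention matrix."""
    scores = matmul(feature_map(q), transpose(feature_map(k)))  # [T, S]
    out = matmul(scores, v)                                     # [T, d_v]
    norms = [sum(s_row) for s_row in scores]                    # row sums
    return [[o / n for o in o_row] for o_row, n in zip(out, norms)]


def linear_attention(q, k, v):
    """Long-sequence path: the [T, S] matrix is never formed."""
    qf, kf = feature_map(q), feature_map(k)
    kv = matmul(transpose(kf), v)                 # [r, d_v], independent of T
    k_sum = [sum(col) for col in zip(*kf)]        # [r], sum of key features
    out = matmul(qf, kv)                          # [T, d_v]
    norms = [sum(a * b for a, b in zip(row, k_sum)) for row in qf]
    return [[o / n for o in o_row] for o_row, n in zip(out, norms)]


q = [[0.1, -0.5], [1.2, 0.3], [-0.7, 0.8]]  # T = 3 queries, r = 2 features
k = [[0.4, 0.0], [-1.0, 0.6]]               # S = 2 keys
v = [[1.0, 2.0, 3.0], [-1.0, 0.5, 0.0]]     # S = 2 values, d_v = 3
explicit = explicit_attention(q, k, v)
linear = linear_attention(q, k, v)
```

The explicit path costs on the order of T·S·(r + d_v) operations for r features. The linear path costs on the order of (T + S)·r·d_v, which is why it is preferable once T and S grow large.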

feature_transform A non-linear transform applied to the keys and queries. Possible transforms are "elu", "relu", "square", "exp", "expmod", "identity".
num_random_features Number of random features used for the projection. If num_random_features <= 0, no projection is applied before the transform.
seed The seed from which random features are drawn. Once the seed is set, pseudo-random number generation is deterministic. Users should pass a different seed to each layer. With multiple workers, each layer uses the same projection at each step.
redraw Whether to redraw the projection on every forward pass during training. This argument only takes effect when num_random_features > 0.
is_short_seq Boolean indicating whether the input data consists of very short sequences; in most cases this should be False (the default).
begin_kernel Apply kernel attention after this sequence index and softmax attention before it.
scale The value by which the dot product is scaled, as described in Attention Is All You Need. If None, 1/sqrt(d_k) is used, as in the paper.
**kwargs The same arguments as the MultiHeadAttention layer.
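The following is a hedged sketch of how feature_transform, num_random_features and seed could interact for a single query or key vector. It has these assumptions and limits:

  • It covers only the "elu", "relu", "square" and "identity" transforms.
  • It treats "elu" as elu(x) + 1, as in the linear-attention paper.
  • It uses an unscaled Gaussian projection.

TRANSFORMS, make_projection and kernel_features are hypothetical names, not the layer's API.

```python
import math
import random

# Hypothetical element-wise maps for a subset of the documented transforms.
# "exp" and "expmod" are omitted because their exact form is not given here.
TRANSFORMS = {
    "elu": lambda t: t + 1.0 if t > 0 else math.exp(t),  # assumed: elu(t) + 1
    "relu": lambda t: max(t, 0.0),
    "square": lambda t: t * t,
    "identity": lambda t: t,
}


def make_projection(num_random_features, dim, seed):
    """Gaussian [num_random_features, dim] matrix, or None when disabled."""
    if num_random_features <= 0:
        return None  # no projection before the transform
    rng = random.Random(seed)  # same seed -> same projection on every call
    return [[rng.gauss(0.0, 1.0) for _ in range(dim)]
            for _ in range(num_random_features)]


def kernel_features(x, feature_transform, projection=None):
    """Optionally projects one query/key vector, then applies the transform."""
    if projection is not None:
        x = [sum(w * t for w, t in zip(row, x)) for row in projection]
    transform = TRANSFORMS[feature_transform]
    return [transform(t) for t in x]


proj = make_projection(num_random_features=8, dim=4, seed=42)
feats = kernel_features([0.5, -1.0, 0.2, 0.0], "relu", proj)  # 8 non-negative values
no_proj = kernel_features([1.0, -2.0], "square", make_projection(0, 2, seed=42))
```

Under this sketch, redraw=True would correspond to building a new projection with fresh randomness on each training step instead of reusing the seeded one.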


Compute attention with the kernel mechanism.

query Query Tensor of shape [B, T, dim].
value Value Tensor of shape [B, S, dim].
key Optional key Tensor of shape [B, S, dim]. If not given, value is used for both key and value, which is the most common case.
attention_mask A boolean mask of shape [B, S] that prevents attending to masked positions. Note that the mask is applied only to the keys; users may want to mask the output if the query contains padding.
training Python boolean indicating whether the layer should behave in training mode (applying dropout) or in inference mode (no dropout).
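In the linear (long-sequence) form, one straightforward way to apply a key-only mask is to zero the kernel features of masked keys before aggregating. The minimal sketch below has these assumptions:

  • a single head and no batch dimension
  • the elu(x) + 1 feature map
  • at least one unmasked key

masked_linear_attention is a hypothetical helper, not the layer's code. The sketch also shows why padded queries still produce output.

```python
import math


def elu_plus_one(t):
    # Assumed feature map: elu(t) + 1, always positive.
    return t + 1.0 if t > 0 else math.exp(t)


def masked_linear_attention(q, k, v, key_mask):
    """Single-head linear attention; key_mask[j] is False for masked keys.

    Only keys are masked, so every query row (padded or not) still receives an
    output. Assumes at least one key is unmasked, otherwise norms are zero.
    """
    qf = [[elu_plus_one(t) for t in row] for row in q]
    # Zero the features of masked keys so they add nothing to kv or k_sum.
    kf = [[elu_plus_one(t) if keep else 0.0 for t in row]
          for row, keep in zip(k, key_mask)]
    r, d_v = len(kf[0]), len(v[0])
    kv = [[sum(kf[j][a] * v[j][c] for j in range(len(kf))) for c in range(d_v)]
          for a in range(r)]
    k_sum = [sum(kf[j][a] for j in range(len(kf))) for a in range(r)]
    out = []
    for row in qf:
        norm = sum(x * s for x, s in zip(row, k_sum))
        out.append([sum(row[a] * kv[a][c] for a in range(r)) / norm
                    for c in range(d_v)])
    return out


q = [[0.2, -0.3], [1.0, 0.5]]
k = [[0.1, 0.4], [2.0, -1.0]]
v = [[1.0, -2.0], [5.0, 5.0]]
# With only key 0 visible, every query attends fully to v[0].
out = masked_linear_attention(q, k, v, key_mask=[True, False])
```

Masking the output rows that belong to padded queries remains a separate step, as the attention_mask description notes.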

Multi-headed outputs of the attention computation.