A variant of efficient transformer attention that replaces softmax with kernel feature maps.
```python
tfm.nlp.layers.KernelAttention(
    feature_transform='exp',
    num_random_features=256,
    seed=0,
    redraw=False,
    is_short_seq=False,
    begin_kernel=0,
    scale=None,
    scale_by_length=False,
    use_causal_windowed=False,
    causal_chunk_length=1,
    causal_window_length=3,
    causal_window_decay=None,
    causal_padding=None,
    **kwargs
)
```
This module combines ideas from the following papers:

- Rethinking Attention with Performers (https://arxiv.org/abs/2009.14794)
  - exp (Lemma 1, positive), relu
  - random/deterministic projection
- Chefs' Random Tables: Non-Trigonometric Random Features (https://arxiv.org/abs/2205.15317)
  - expplus (OPRF mechanism)
- Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention (https://arxiv.org/abs/2006.16236)
  - elu

together with the theory of approximating angular Performer kernels from go/performer.
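As a rough illustration of the `exp` feature transform above (Performers, Lemma 1), positive random features of the form φ(x) = exp(wᵀx − ‖x‖²/2)/√m satisfy E[φ(q)ᵀφ(k)] = exp(qᵀk), i.e. they approximate the softmax kernel in expectation. A minimal NumPy sketch, not the library's implementation (function names here are illustrative):

```python
import numpy as np

def positive_random_features(x, w):
    # Performer "exp" features (Lemma 1): phi(x)_i = exp(w_i . x - ||x||^2 / 2) / sqrt(m)
    m = w.shape[0]
    return np.exp(x @ w.T - 0.5 * np.sum(x**2, axis=-1, keepdims=True)) / np.sqrt(m)

rng = np.random.default_rng(0)
d, m = 4, 100_000                       # query/key dim, number of random features
w = rng.standard_normal((m, d))         # w_i ~ N(0, I_d)
q = 0.1 * rng.standard_normal(d)
k = 0.1 * rng.standard_normal(d)

exact = np.exp(q @ k)                   # softmax kernel value exp(q . k)
approx = positive_random_features(q[None], w) @ positive_random_features(k[None], w).T
print(exact, approx[0, 0])              # the two values should be close
```

With `num_random_features=256` (the layer's default) the variance is higher than with the 100k features used here; `redraw` controls whether the projection matrix is resampled during training.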
The module supports efficient attention in two regimes: long sequences and shorter sequences. In the long-sequence setting, the attention matrix is never explicitly computed; instead, its low-rank decomposition, obtained from the given kernel feature maps, is used to carry out the attention computation (see https://arxiv.org/abs/2006.16236). In the short-sequence setting, the attention matrix is constructed explicitly, but the kernel features provide dimensionality reduction, resulting in a more efficient computation of the attention matrix.
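The two regimes compute the same quantity, just with the matrix products reordered. A minimal NumPy sketch using the elu-style feature map from the "Transformers are RNNs" paper (illustrative only, not the layer's code):

```python
import numpy as np

def feature_map(x):
    # elu(x) + 1 keeps features positive (Katharopoulos et al., 2020)
    return np.where(x > 0, x + 1.0, np.exp(x))

rng = np.random.default_rng(0)
T, S, dim = 8, 10, 16
Q = rng.standard_normal((T, dim))
K = rng.standard_normal((S, dim))
V = rng.standard_normal((S, dim))

phi_q, phi_k = feature_map(Q), feature_map(K)    # [T, dim], [S, dim]

# Short-sequence regime: build the [T, S] attention matrix explicitly.
A = phi_q @ phi_k.T                              # unnormalized attention scores
short = (A / A.sum(axis=-1, keepdims=True)) @ V

# Long-sequence regime: reorder the matmuls so the [T, S] matrix never exists.
kv = phi_k.T @ V                                 # [dim, dim], cost O(S * dim^2)
normalizer = phi_q @ phi_k.sum(axis=0)           # [T]
long = (phi_q @ kv) / normalizer[:, None]

print(np.allclose(short, long))                  # same result, different cost
```

The reordering is just associativity of matrix multiplication; it pays off once the sequence length exceeds the feature dimension, which is what the `is_short_seq` flag trades off.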
Methods
call
```python
call(
    query, value, key=None, attention_mask=None, cache=None, training=False
)
```
Compute attention with kernel mechanism.
| Args | |
|---|---|
| `query` | Query `Tensor` of shape `[B, T, dim]`. |
| `value` | Value `Tensor` of shape `[B, S, dim]`. |
| `key` | Optional key `Tensor` of shape `[B, S, dim]`. If not given, `value` will be used for both `key` and `value`, which is the most common case. |
| `attention_mask` | A boolean mask of shape `[B, S]` that prevents attending to masked positions. Note that the mask is only applied to the keys. Users may want to mask the output if `query` contains pads. |
| `cache` | Cache to accumulate history in memory. Used at inference time (streaming, decoding) for causal attention. |
| `training` | Python boolean indicating whether the layer should behave in training mode (adding dropout) or in inference mode (doing nothing). |
| Returns |
|---|
| Multi-headed outputs of attention computation. |
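For causal attention, the same reordering admits a streaming form, which is what a decode-time `cache` can exploit: instead of storing all past keys and values, one maintains running sums of φ(k)vᵀ and φ(k). A hypothetical sketch of this idea (not the layer's actual cache format):

```python
import numpy as np

def feature_map(x):
    return np.where(x > 0, x + 1.0, np.exp(x))  # elu(x) + 1, positive features

rng = np.random.default_rng(0)
T, dim = 6, 8
Q = rng.standard_normal((T, dim))
K = rng.standard_normal((T, dim))
V = rng.standard_normal((T, dim))

# Streaming causal attention: the "cache" is just two running sums.
kv_state = np.zeros((dim, dim))    # sum over past steps of phi(k_i) v_i^T
z_state = np.zeros(dim)            # sum over past steps of phi(k_i)
stream_out = []
for t in range(T):
    phi_k = feature_map(K[t])
    kv_state += np.outer(phi_k, V[t])          # update state with step t
    z_state += phi_k
    phi_q = feature_map(Q[t])
    stream_out.append(phi_q @ kv_state / (phi_q @ z_state))
stream_out = np.array(stream_out)

# Reference: full attention with a causal (lower-triangular) mask.
A = feature_map(Q) @ feature_map(K).T
A *= np.tri(T)                     # zero out attention to future positions
full_out = (A / A.sum(axis=-1, keepdims=True)) @ V

print(np.allclose(stream_out, full_out))
```

Each decode step costs O(dim²) regardless of how many tokens precede it, which is why causal kernel attention suits streaming inference.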