Attention layer with cache used for autoregressive decoding.
tfm.nlp.layers.CachedAttention( num_heads, key_dim, value_dim=None, dropout=0.0, use_bias=True, output_shape=None, attention_axes=None, kernel_initializer='glorot_uniform', bias_initializer='zeros', kernel_regularizer=None, bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, bias_constraint=None, **kwargs )
Arguments are the same as
call( query, value, key=None, attention_mask=None, cache=None, decode_loop_step=None, return_attention_scores=False )
This is where the layer's logic lives.
call() method may not create state (except in its first invocation,
wrapping the creation of variables or other resources in
It is recommended to create state in
__init__(), or the
that is called automatically before
call() executes the first time.
Input tensor, or dict/list/tuple of input tensors.
The first positional
Additional positional arguments. May contain tensors, although this is not recommended, for the reasons above.
Additional keyword arguments. May contain tensors, although
this is not recommended, for the reasons above.
The following optional keyword arguments are reserved:
A tensor or list/tuple of tensors.
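The `cache` and `decode_loop_step` arguments of `call()` exist to support autoregressive decoding, where keys and values from earlier steps are stored instead of being recomputed. The sketch below illustrates that general idea in plain Python. It is not the layer's implementation: it uses a single head with no learned projections, the helper names `attend` and `decode_step` are hypothetical, and the cache layout (a dict with `"key"` and `"value"` entries preallocated to a maximum length) is an assumption made for illustration.

```python
import math


def attend(query, keys, values):
    """Scaled dot-product attention of one query vector over the given keys/values."""
    scale = 1.0 / math.sqrt(len(query))
    scores = [scale * sum(q * k for q, k in zip(query, key)) for key in keys]
    peak = max(scores)  # subtract the max for a numerically stable softmax
    weights = [math.exp(s - peak) for s in scores]
    total = sum(weights)
    return [
        sum(w / total * value[i] for w, value in zip(weights, values))
        for i in range(len(values[0]))
    ]


def decode_step(query, new_key, new_value, cache, step):
    """Stores this step's key/value in the cache, then attends over steps 0..step."""
    cache["key"][step] = new_key
    cache["value"][step] = new_value
    return attend(query, cache["key"][: step + 1], cache["value"][: step + 1])


# Hypothetical toy setup: up to 4 decoding steps, vectors of size 2.
max_len, dim = 4, 2
cache = {
    "key": [[0.0] * dim for _ in range(max_len)],
    "value": [[0.0] * dim for _ in range(max_len)],
}
tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

# Each step supplies only the newest token; earlier keys/values come from the cache.
outputs = [decode_step(t, t, t, cache, step) for step, t in enumerate(tokens)]

# Reference for the last step: attention recomputed over the whole prefix at once.
full_last = attend(tokens[-1], tokens, tokens)
```

In this sketch each step does work proportional to the number of tokens decoded so far, because only the newest key and value are computed and written at index `step`. The unused tail of the preallocated cache stays zero and is never attended to.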