Multi-channel Attention layer.
tfnlp.layers.MultiChannelAttention(
    num_heads, key_dim, value_dim=None, dropout=0.0, use_bias=True,
    output_shape=None, attention_axes=None, kernel_initializer='glorot_uniform',
    bias_initializer='zeros', kernel_regularizer=None, bias_regularizer=None,
    activity_regularizer=None, kernel_constraint=None, bias_constraint=None,
    **kwargs
)
Introduced in Generating Representative Headlines for News Stories. Expects multiple cross-attention target sequences.
Call args:
query: Query Tensor of shape [B, T, dim].
value: Value Tensor of shape [B, A, S, dim], where A denotes the number of channels (one per cross-attention target sequence).
context_attention_weights: Context weights of shape [B, N, T, A], where N is the number of attention heads. Combines the context tensors of the multi-channel sources according to the distribution among channels.
key: Optional key Tensor of shape [B, A, S, dim]. If not given, will use value, which is the most common case.
attention_mask: A boolean mask of shape [B, T, S] that prevents attention to certain positions.
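A minimal usage sketch tying these shapes together. The import path (the TF Model Garden's official.nlp.modeling.layers, taken here to correspond to tfnlp.layers above) and the uniform channel distribution are illustrative assumptions, not prescribed by the API:

```python
import tensorflow as tf
from official.nlp.modeling import layers as tfnlp_layers  # assumed import path

B, A, T, S, dim = 2, 3, 8, 16, 64  # batch, channels, target len, source len, width
num_heads = 4

layer = tfnlp_layers.MultiChannelAttention(
    num_heads=num_heads, key_dim=dim // num_heads)

query = tf.random.normal([B, T, dim])     # [B, T, dim]
value = tf.random.normal([B, A, S, dim])  # [B, A, S, dim]
# Uniform distribution over the A channels, purely for illustration: [B, N, T, A].
context_weights = tf.fill([B, num_heads, T, A], 1.0 / A)

output = layer(query, value, context_attention_weights=context_weights)
print(output.shape)  # (2, 8, 64) -- [B, T, dim]
```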
call( query, value, key=None, context_attention_weights=None, attention_mask=None )
This is where the layer's logic lives.
Note that the call() method in tf.keras is a little different from the keras API: in the keras API, you can pass masking support to layers as additional arguments, whereas tf.keras provides a compute_mask() method to support masking.
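For illustration, one common way to build the boolean attention_mask described above (True marks positions that may be attended to) is from the true, unpadded lengths of the source sequences; source_lengths here is a made-up input:

```python
import tensorflow as tf

B, T, S = 2, 8, 16
# Hypothetical true lengths of each padded source sequence.
source_lengths = tf.constant([12, 16])              # [B]
source_valid = tf.sequence_mask(source_lengths, S)  # [B, S], True = real token
# Broadcast over the target dimension to get [B, T, S].
attention_mask = tf.tile(source_valid[:, tf.newaxis, :], [1, T, 1])
```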
Args
|inputs|Input tensor, or list/tuple of input tensors.|
|*args|Additional positional arguments. Currently unused.|
|**kwargs|Additional keyword arguments. Currently unused.|
Returns
|A tensor or list/tuple of tensors.|