

Multi-channel Attention layer.

Inherits From: MultiHeadAttention

Introduced in "Generating Representative Headlines for News Stories". This layer expects multiple cross-attention target sequences.

Call args:

  • query: Query Tensor of shape [B, T, dim].
  • value: Value Tensor of shape [B, A, S, dim], where A denotes the number of channels (cross-attention target sequences).
  • context_attention_weights: Context weights of shape [B, N, T, A], where N is the number of attention heads. Used to combine the context tensors of the multi-channel sources according to this distribution over channels.
  • key: Optional key Tensor of shape [B, A, S, dim]. If not given, will use value for both key and value, which is the most common case.
  • attention_mask: A boolean mask of shape [B, T, S] that prevents attention to certain positions.
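The shape contract above can be illustrated with a NumPy sketch: attention is computed against each of the A channels separately, and the per-channel results are then combined using the context distribution. The projection matrices and the exact combination rule here are assumptions based on the argument descriptions, not the layer's actual internals:

```python
import numpy as np

B, T, S, A, N, dim = 2, 3, 4, 5, 2, 8   # batch, target len, source len, channels, heads, model dim
head_dim = dim // N

rng = np.random.default_rng(0)
query = rng.standard_normal((B, T, dim))
value = rng.standard_normal((B, A, S, dim))

# Context distribution over channels, per head and target position: [B, N, T, A].
ctx = rng.standard_normal((B, N, T, A))
ctx = np.exp(ctx) / np.exp(ctx).sum(-1, keepdims=True)   # softmax over A

# Random projections stand in for the learned per-head weights.
Wq = rng.standard_normal((dim, N, head_dim))
Wk = rng.standard_normal((dim, N, head_dim))
Wv = rng.standard_normal((dim, N, head_dim))
q = np.einsum("btd,dnh->bnth", query, Wq)     # [B, N, T, head_dim]
k = np.einsum("basd,dnh->bnash", value, Wk)   # [B, N, A, S, head_dim]
v = np.einsum("basd,dnh->bnash", value, Wv)

# Per-channel scaled dot-product attention: scores are [B, N, T, A, S].
scores = np.einsum("bnth,bnash->bntas", q, k) / np.sqrt(head_dim)

# A [B, T, S] boolean attention_mask would broadcast across heads and channels.
mask = np.ones((B, T, S), dtype=bool)
scores = np.where(mask[:, None, :, None, :], scores, -1e9)

probs = np.exp(scores) / np.exp(scores).sum(-1, keepdims=True)   # softmax over S
per_channel = np.einsum("bntas,bnash->bntah", probs, v)          # [B, N, T, A, head_dim]

# Combine the channels using the context distribution.
out = np.einsum("bnta,bntah->bnth", ctx, per_channel)            # [B, N, T, head_dim]
print(out.shape)  # (2, 2, 3, 4)
```

Note how context_attention_weights acts as a second, externally supplied attention distribution: the softmax over S is computed inside the layer, while the distribution over A comes in as an argument.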




This is where the layer's logic lives.

Note that the call() method in tf.keras differs slightly from the Keras API: in the Keras API you can pass masking support to layers as additional arguments, whereas tf.keras uses the compute_mask() method to support masking.
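The distinction can be sketched in plain Python (a minimal stand-in, not the actual Keras base class; the class name and behavior are hypothetical):

```python
class MaskedDouble:
    """Keras-style layer sketch: call() transforms values, compute_mask() propagates validity."""

    def call(self, inputs, mask=None):
        # call() receives the mask only as an optional argument ...
        return [x * 2 for x in inputs]

    def compute_mask(self, inputs, mask=None):
        # ... while compute_mask() is the hook that tells downstream layers
        # which positions remain valid (here the mask passes through unchanged).
        return mask


layer = MaskedDouble()
print(layer.call([1, 2, 3]))                               # [2, 4, 6]
print(layer.compute_mask([1, 2, 3], [True, False, True]))  # [True, False, True]
```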

Args:

  • inputs: Input tensor, or list/tuple of input tensors.
  • *args: Additional positional arguments. Currently unused.
  • **kwargs: Additional keyword arguments. Currently unused.

Returns:

A tensor or list/tuple of tensors.