tfm.nlp.layers.MultiHeadRelativeAttention

A multi-head attention layer with relative attention + position encoding.

This layer shares the same input/output projections as the common tf.keras.layers.MultiHeadAttention layer.

When it calculates attention logits, the position encoding is projected to form relative keys. The logits are composed of shifted relative logits and content logits.
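The "shifted relative logits" can be illustrated with the relative-shift reshape trick popularized by Transformer-XL. The NumPy sketch below is not the library's code. It makes these assumptions:

- Positional logits have shape [B, N, T, L] (heads before the query axis).
- L = T + klen, where klen is the number of key positions.
- Encoding slot k stands for relative distance klen - k.
- The helper name rel_shift is hypothetical.

```python
import numpy as np


def rel_shift(x, klen):
    """Hypothetical helper: relative shift of positional logits.

    x has shape [B, N, T, L]; the result has shape [B, N, T, klen] and
    satisfies out[..., i, j] == x[..., i, j + T - i].
    """
    b, n, t, l = x.shape
    x = x.transpose(2, 3, 0, 1)      # [T, L, B, N]
    x = x.reshape(l, t, b, n)        # reinterpret the same values as [L, T, B, N]
    x = x[1:]                        # drop the first row
    x = x.reshape(t, l - 1, b, n)    # back to [T, L - 1, B, N]; each row is now offset
    x = x[:, :klen]                  # keep one column per key position
    return x.transpose(2, 3, 0, 1)   # [B, N, T, klen]


# Worked example: fill encoding slot k with the relative distance klen - k.
T, klen = 3, 3
L = T + klen
x = np.tile((klen - np.arange(L)).astype(float), (1, 1, T, 1))
shifted = rel_shift(x, klen)
# shifted[0, 0][i, j] now holds the distance i - j for query i and key j.
```

The reshape-and-drop-a-row step offsets row i by T - i columns, which aligns each query with its own relative distances without an explicit gather.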

Call args
query Query Tensor of shape [B, T, dim].
value Value Tensor of shape [B, S, dim].
content_attention_bias Bias Tensor for content-based attention of shape [num_heads, dim].
positional_attention_bias Bias Tensor for position-based attention of shape [num_heads, dim].
key Optional key Tensor of shape [B, S, dim]. If not given, value is used for both key and value, which is the most common case.
relative_position_encoding Relative positional encoding Tensor of shape [B, L, dim].
segment_matrix Optional Tensor representing segmentation IDs used in XLNet, of shape [B, S, S + M].
segment_encoding Optional Tensor representing the segmentation encoding as used in XLNet, of shape [2, num_heads, dim].
segment_attention_bias Optional trainable bias parameter added to the query head when calculating the segment-based attention score used in XLNet, of shape [num_heads, dim].
state Optional Tensor of shape [B, M, E], where M is the length of the state or memory. If passed, it is also attended over, as in Transformer-XL.
attention_mask A boolean mask of shape [B, T, S] that prevents attention to certain positions.
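The three segment arguments contribute an extra logit term. The NumPy sketch below follows the XLNet convention, which is an assumption here:

- segment_matrix is True where a query and a key lie in different segments.
- Those pairs read slot 1 of segment_encoding; matching pairs read slot 0.
- The trailing size D is the per-head key size, so the einsum lines up with projected queries.
- The function name segment_logits is hypothetical.

```python
import numpy as np


def segment_logits(query, segment_matrix, segment_encoding, segment_attention_bias):
    """Hypothetical sketch of an XLNet-style segment term.

    query:                  [B, T, N, D] projected queries
    segment_matrix:         [B, T, K] boolean, K = S + M key positions
    segment_encoding:       [2, N, D] one vector per (same, different) slot
    segment_attention_bias: [N, D]
    returns:                [B, N, T, K]
    """
    # Score every query against both segment vectors: [B, N, T, 2].
    scores = np.einsum("bind,snd->bnis", query + segment_attention_bias, segment_encoding)
    b, n, t, _ = scores.shape
    target = (b, n, t, segment_matrix.shape[-1])
    mask = np.broadcast_to(segment_matrix[:, None], target)  # add a head axis
    # Slot 1 where the segments differ, slot 0 where they match.
    return np.where(mask,
                    np.broadcast_to(scores[..., 1:], target),
                    np.broadcast_to(scores[..., :1], target))


rng = np.random.default_rng(0)
B, T, K, N, D = 2, 4, 4, 3, 5
seg_ids = np.array([[0, 0, 1, 1], [0, 1, 1, 1]])
segment_matrix = seg_ids[:, :, None] != seg_ids[:, None, :]  # [B, T, K]
q = rng.normal(size=(B, T, N, D))
enc = rng.normal(size=(2, N, D))
bias = rng.normal(size=(N, D))
out = segment_logits(q, segment_matrix, enc, bias)
print(out.shape)  # (2, 3, 4, 4)
```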

Attributes
kernel_initializer The kernel initializer. Defaults to variance_scaling.

Methods

call


Compute multi-head relative attention over inputs.

Size glossary

  • Number of heads (H): the number of attention heads.
  • Value size (V): the size of each value embedding per head.
  • Key size (K): the size of each key embedding per head. Equivalently, the size of each query embedding per head. Typically K <= V.
  • Batch dimensions (B).
  • Query (target) attention axes shape (T).
  • Value (source) attention axes shape (S); its rank must match that of the target.
  • Encoding length (L): The relative positional encoding length.

Args
query Attention input.
value Attention input.
content_attention_bias A trainable bias parameter added to the query head when calculating the content-based attention score.
positional_attention_bias A trainable bias parameter added to the query head when calculating the position-based attention score.
key Attention input.
relative_position_encoding Relative positional encoding for key and value.
segment_matrix Optional Tensor representing segmentation IDs used in XLNet.
segment_encoding Optional Tensor representing the segmentation encoding as used in XLNet.
segment_attention_bias Optional trainable bias parameter added to the query head when calculating the segment-based attention score used in XLNet.
state (default None) Optional state. If passed, it is also attended over, as in Transformer-XL.
attention_mask (default None) Optional mask that is added to attention logits. If state is not None, the mask's source sequence dimension should be extended by M so that it covers S + M positions.
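When state is passed, keys and values grow from S to S + M positions, so the mask needs M extra source columns. The NumPy sketch below is not the library's code. It makes these assumptions:

- The memory is prepended to the current sequence along the sequence axis, as in Transformer-XL.
- Mask entries of 1 mean "may attend", the convention of tf.keras.layers.MultiHeadAttention.
- Every query may see all memory slots.
- The helper name extend_with_state is hypothetical.

```python
import numpy as np


def extend_with_state(state, value, attention_mask):
    """Hypothetical sketch: prepend memory and widen the mask to S + M.

    state:          [B, M, E] cached memory
    value:          [B, S, E] current segment
    attention_mask: [B, T, S] with 1 = may attend
    returns value [B, M + S, E] and mask [B, T, M + S]
    """
    extended_value = np.concatenate([state, value], axis=1)
    b, t, _ = attention_mask.shape
    m = state.shape[1]
    # Let every query see all memory slots; keep the original mask for the rest.
    memory_mask = np.ones((b, t, m), dtype=attention_mask.dtype)
    extended_mask = np.concatenate([memory_mask, attention_mask], axis=2)
    return extended_value, extended_mask


B, M, S, T, E = 2, 3, 4, 4, 8
state = np.zeros((B, M, E))
value = np.ones((B, S, E))
causal = np.tril(np.ones((T, S), dtype=np.int32))[None].repeat(B, axis=0)
v, mask = extend_with_state(state, value, causal)
print(v.shape, mask.shape)  # (2, 7, 8) (2, 4, 7)
```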

Returns
attention_output The result of the computation, of shape [B, T, E], where T is the target sequence shape and E is the last dimension of the query input if output_shape is None. Otherwise, the multi-head outputs are projected to the shape specified by output_shape.
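When output_shape is None, per-head outputs of shape [B, T, N, V] are merged back to the query's last dimension E. Below is a minimal NumPy sketch of that merge. It assumes an output kernel of shape [N, V, E] that is contracted over both the head axis and the value axis. The names project_output and kernel are hypothetical.

```python
import numpy as np


def project_output(heads, kernel, bias):
    """Hypothetical sketch: merge heads [B, T, N, V] into [B, T, E]."""
    # Contract the head (n) and per-head value (v) axes in one step.
    return np.einsum("btnv,nve->bte", heads, kernel) + bias


B, T, N, V, E = 2, 5, 4, 8, 32
rng = np.random.default_rng(0)
heads = rng.normal(size=(B, T, N, V))
kernel = rng.normal(size=(N, V, E))
bias = np.zeros(E)
out = project_output(heads, kernel, bias)
print(out.shape)  # (2, 5, 32)
```

Contracting N and V together is equivalent to two steps: concatenating the heads into a [B, T, N * V] tensor, then applying one dense layer with an [N * V, E] weight matrix.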

compute_attention


Computes the attention.

This function defines the computation inside call, given the projected multi-head Q, K, V, and R (position) inputs.
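The projected inputs listed below come from per-head linear maps of the raw [B, T, dim], [B, S, dim], and [B, L, dim] tensors. The NumPy sketch below shows that step. It assumes a separate kernel of shape [dim, N, key_dim] for each of Q, K, V, and R, and omits biases for brevity. The function name project is hypothetical.

```python
import numpy as np


def project(x, kernel):
    """Hypothetical sketch: [B, X, dim] x [dim, N, key_dim] -> [B, X, N, key_dim]."""
    return np.einsum("bxd,dnk->bxnk", x, kernel)


B, T, S, L, dim, N, key_dim = 2, 4, 4, 8, 16, 2, 8
rng = np.random.default_rng(0)
query_in = rng.normal(size=(B, T, dim))
value_in = rng.normal(size=(B, S, dim))
encoding = rng.normal(size=(B, L, dim))  # relative_position_encoding
w_q, w_k, w_v, w_r = (rng.normal(size=(dim, N, key_dim)) for _ in range(4))

q = project(query_in, w_q)   # [B, T, N, key_dim]
k = project(value_in, w_k)   # key defaults to value when key is not given
v = project(value_in, w_v)   # [B, S, N, key_dim]
r = project(encoding, w_r)   # [B, L, N, key_dim]
print(q.shape, k.shape, v.shape, r.shape)
# (2, 4, 2, 8) (2, 4, 2, 8) (2, 4, 2, 8) (2, 8, 2, 8)
```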

Args
query Projected query Tensor of shape [B, T, N, key_dim].
key Projected key Tensor of shape [B, S + M, N, key_dim].
value Projected value Tensor of shape [B, S + M, N, key_dim].
position Projected position Tensor of shape [B, L, N, key_dim].
content_attention_bias Trainable bias parameter added to the query head when calculating the content-based attention score.
positional_attention_bias Trainable bias parameter added to the query head when calculating the position-based attention score.
segment_matrix Optional Tensor representing segmentation IDs used in XLNet.
segment_encoding Optional trainable Tensor representing the segmentation encoding as used in XLNet.
segment_attention_bias Optional trainable bias parameter added to the query head when calculating the segment-based attention score used in XLNet.
attention_mask (default None) Optional mask that is added to attention logits. If state is not None, the mask's source sequence dimension should be extended by M so that it covers S + M positions.

Returns
attention_output Multi-headed output of attention computation of shape [B, T, N, key_dim].
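Putting the documented steps together, here is a NumPy sketch of the core of compute_attention. It covers only the case without segments, dropout, or a separate value size, and it is not the library's implementation. It makes these assumptions:

- The relative shift is written as an explicit gather that matches the usual reshape trick: query i, key j reads encoding slot j + T - i. This needs L >= T + K for K key positions.
- Logits are scaled by 1/sqrt(key_dim).
- The mask uses 1 for "may attend".
- The name relative_attention is hypothetical.

```python
import numpy as np


def relative_attention(query, key, value, position,
                       content_bias, positional_bias, mask=None):
    """Hypothetical sketch of the core computation, without segments or dropout.

    query [B, T, N, D], key/value [B, K, N, D], position [B, L, N, D],
    content_bias/positional_bias [N, D], mask [B, T, K] with 1 = may attend.
    Requires L >= T + K. Returns [B, T, N, D].
    """
    t, k, l = query.shape[1], key.shape[1], position.shape[1]
    d = query.shape[-1]
    # Content logits: (q + u) . k -> [B, N, T, K]
    content = np.einsum("bknd,bqnd->bnqk", key, query + content_bias)
    # Positional logits against every encoding slot: (q + v) . r -> [B, N, T, L]
    positional = np.einsum("blnd,bqnd->bnql", position, query + positional_bias)
    # Relative shift written as a gather: key j of query i reads slot j + T - i.
    rows = np.arange(t)[:, None]
    idx = np.arange(k)[None, :] + t - rows           # [T, K]
    if idx.max() >= l:
        raise ValueError("position encoding must have length >= T + K")
    positional = positional[..., rows, idx]          # [B, N, T, K]
    scores = (content + positional) / np.sqrt(d)
    if mask is not None:
        # Push disallowed positions to a large negative logit before softmax.
        scores = np.where(mask[:, None] > 0, scores, -1e9)
    scores = scores - scores.max(axis=-1, keepdims=True)
    probs = np.exp(scores)
    probs /= probs.sum(axis=-1, keepdims=True)
    # Weighted sum of values: [B, N, T, K] x [B, K, N, D] -> [B, T, N, D]
    return np.einsum("bnqk,bknd->bqnd", probs, value)


rng = np.random.default_rng(0)
B, T, K, N, D = 2, 3, 3, 2, 4
L = T + K
q = rng.normal(size=(B, T, N, D))
kv = rng.normal(size=(B, K, N, D))
r = rng.normal(size=(B, L, N, D))
u = rng.normal(size=(N, D))
v_bias = rng.normal(size=(N, D))
causal = np.tril(np.ones((T, K)))[None].repeat(B, axis=0)
out = relative_attention(q, kv, kv, r, u, v_bias, causal)
print(out.shape)  # (2, 3, 2, 4)
```

Under the causal mask, the first query can only attend to key 0, so its output row equals the value at position 0.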