Transformer model for language understanding


This tutorial trains a Transformer model to translate Portuguese to English. This is an advanced example that assumes knowledge of text generation and attention.

The core idea behind the Transformer model is self-attention: the ability to attend to different positions of the input sequence to compute a representation of that sequence. The Transformer creates stacks of self-attention layers, which are explained below in the sections Scaled dot product attention and Multi-head attention.

A transformer model handles variable-sized input using stacks of self-attention layers instead of RNNs or CNNs. This general architecture has a number of advantages:

  • It makes no assumptions about the temporal/spatial relationships across the data. This is ideal for processing a set of objects (for example, StarCraft units).
  • Layer outputs can be calculated in parallel, instead of in series like an RNN.
  • Distant items can affect each other's output without passing through many RNN steps or convolution layers (see Scene Memory Transformer, for example).
  • It can learn long-range dependencies. This is a challenge in many sequence tasks.

The downsides of this architecture are:

  • For a time-series, the output for a time-step is calculated from the entire history instead of only the inputs and current hidden-state. This may be less efficient.
  • If the input does have a temporal/spatial relationship, like text, some positional encoding must be added or the model will effectively see a bag of words.

After training the model in this notebook, you will be able to input a Portuguese sentence and get back the English translation.

[Figure: attention heatmap]

pip install -q tfds-nightly

# Pin matplotlib version to 3.2.2 since in the latest version
# transformer.ipynb fails with the following error:
# https://stackoverflow.com/questions/62953704/valueerror-the-number-of-fixedlocator-locations-5-usually-from-a-call-to-set
pip install matplotlib==3.2.2
Collecting matplotlib==3.2.2
  Downloading matplotlib-3.2.2-cp36-cp36m-manylinux1_x86_64.whl (12.4 MB)
     |████████████████████████████████| 12.4 MB 2.8 MB/s 
Requirement already satisfied: numpy>=1.11 in /tmpfs/src/tf_docs_env/lib/python3.6/site-packages (from matplotlib==3.2.2) (1.18.5)
Requirement already satisfied: cycler>=0.10 in /home/kbuilder/.local/lib/python3.6/site-packages (from matplotlib==3.2.2) (0.10.0)
Requirement already satisfied: kiwisolver>=1.0.1 in /home/kbuilder/.local/lib/python3.6/site-packages (from matplotlib==3.2.2) (1.2.0)
Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /home/kbuilder/.local/lib/python3.6/site-packages (from matplotlib==3.2.2) (2.4.7)
Requirement already satisfied: python-dateutil>=2.1 in /home/kbuilder/.local/lib/python3.6/site-packages (from matplotlib==3.2.2) (2.8.1)
Requirement already satisfied: six in /home/kbuilder/.local/lib/python3.6/site-packages (from cycler>=0.10->matplotlib==3.2.2) (1.15.0)
Installing collected packages: matplotlib
  Attempting uninstall: matplotlib
    Found existing installation: matplotlib 3.3.2
    Not uninstalling matplotlib at /home/kbuilder/.local/lib/python3.6/site-packages, outside environment /tmpfs/src/tf_docs_env
    Can't uninstall 'matplotlib'. No files were found to uninstall.
Successfully installed matplotlib-3.2.2

import tensorflow_datasets as tfds
import tensorflow as tf

import time
import numpy as np
import matplotlib.pyplot as plt

Set up the input pipeline

Use TFDS to load the Portuguese-English translation dataset from the TED Talks Open Translation Project.

This dataset contains approximately 50000 training examples, 1100 validation examples, and 2000 test examples.

examples, metadata = tfds.load('ted_hrlr_translate/pt_to_en', with_info=True,
                               as_supervised=True)
train_examples, val_examples = examples['train'], examples['validation']
Downloading and preparing dataset ted_hrlr_translate/pt_to_en/1.0.0 (download: 124.94 MiB, generated: Unknown size, total: 124.94 MiB) to /home/kbuilder/tensorflow_datasets/ted_hrlr_translate/pt_to_en/1.0.0...
Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/ted_hrlr_translate/pt_to_en/1.0.0.incomplete9RVVAD/ted_hrlr_translate-train.tfrecord
Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/ted_hrlr_translate/pt_to_en/1.0.0.incomplete9RVVAD/ted_hrlr_translate-validation.tfrecord
Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/ted_hrlr_translate/pt_to_en/1.0.0.incomplete9RVVAD/ted_hrlr_translate-test.tfrecord
Dataset ted_hrlr_translate downloaded and prepared to /home/kbuilder/tensorflow_datasets/ted_hrlr_translate/pt_to_en/1.0.0. Subsequent calls will reuse this data.

Create a custom subword tokenizer from the training dataset.

tokenizer_en = tfds.deprecated.text.SubwordTextEncoder.build_from_corpus(
    (en.numpy() for pt, en in train_examples), target_vocab_size=2**13)

tokenizer_pt = tfds.deprecated.text.SubwordTextEncoder.build_from_corpus(
    (pt.numpy() for pt, en in train_examples), target_vocab_size=2**13)
sample_string = 'Transformer is awesome.'

tokenized_string = tokenizer_en.encode(sample_string)
print ('Tokenized string is {}'.format(tokenized_string))

original_string = tokenizer_en.decode(tokenized_string)
print ('The original string: {}'.format(original_string))

assert original_string == sample_string
Tokenized string is [7915, 1248, 7946, 7194, 13, 2799, 7877]
The original string: Transformer is awesome.

The tokenizer encodes the string by breaking it into subwords if the word is not in its dictionary.

for ts in tokenized_string:
  print ('{} ----> {}'.format(ts, tokenizer_en.decode([ts])))
7915 ----> T
1248 ----> ran
7946 ----> s
7194 ----> former 
13 ----> is 
2799 ----> awesome
7877 ----> .

BUFFER_SIZE = 20000
BATCH_SIZE = 64

Add a start and end token to the input and target.

def encode(lang1, lang2):
  lang1 = [tokenizer_pt.vocab_size] + tokenizer_pt.encode(
      lang1.numpy()) + [tokenizer_pt.vocab_size+1]

  lang2 = [tokenizer_en.vocab_size] + tokenizer_en.encode(
      lang2.numpy()) + [tokenizer_en.vocab_size+1]

  return lang1, lang2

You want to use Dataset.map to apply this function to each element of the dataset. Dataset.map runs in graph mode.

  • Graph tensors do not have a value.
  • In graph mode you can only use TensorFlow Ops and functions.

So you can't .map this function directly: You need to wrap it in a tf.py_function. The tf.py_function will pass regular tensors (with a value and a .numpy() method to access it) to the wrapped Python function.

def tf_encode(pt, en):
  result_pt, result_en = tf.py_function(encode, [pt, en], [tf.int64, tf.int64])
  result_pt.set_shape([None])
  result_en.set_shape([None])

  return result_pt, result_en
Note: To keep this example small and relatively fast, drop examples with a length of over 40 tokens.

MAX_LENGTH = 40
def filter_max_length(x, y, max_length=MAX_LENGTH):
  return tf.logical_and(tf.size(x) <= max_length,
                        tf.size(y) <= max_length)
train_dataset = train_examples.map(tf_encode)
train_dataset = train_dataset.filter(filter_max_length)
# cache the dataset to memory to get a speedup while reading from it.
train_dataset = train_dataset.cache()
train_dataset = train_dataset.shuffle(BUFFER_SIZE).padded_batch(BATCH_SIZE)
train_dataset = train_dataset.prefetch(tf.data.experimental.AUTOTUNE)


val_dataset = val_examples.map(tf_encode)
val_dataset = val_dataset.filter(filter_max_length).padded_batch(BATCH_SIZE)
pt_batch, en_batch = next(iter(val_dataset))
pt_batch, en_batch
(<tf.Tensor: shape=(64, 38), dtype=int64, numpy=
 array([[8214,  342, 3032, ...,    0,    0,    0],
        [8214,   95,  198, ...,    0,    0,    0],
        [8214, 4479, 7990, ...,    0,    0,    0],
        ...,
        [8214,  584,   12, ...,    0,    0,    0],
        [8214,   59, 1548, ...,    0,    0,    0],
        [8214,  118,   34, ...,    0,    0,    0]])>,
 <tf.Tensor: shape=(64, 40), dtype=int64, numpy=
 array([[8087,   98,   25, ...,    0,    0,    0],
        [8087,   12,   20, ...,    0,    0,    0],
        [8087,   12, 5453, ...,    0,    0,    0],
        ...,
        [8087,   18, 2059, ...,    0,    0,    0],
        [8087,   16, 1436, ...,    0,    0,    0],
        [8087,   15,   57, ...,    0,    0,    0]])>)

Positional encoding

Since this model doesn't contain any recurrence or convolution, positional encoding is added to give the model some information about the relative position of the words in the sentence.

The positional encoding vector is added to the embedding vector. Embeddings represent a token in a d-dimensional space where tokens with similar meaning will be closer to each other. But the embeddings do not encode the relative position of words in a sentence. So after adding the positional encoding, words will be closer to each other based on the similarity of their meaning and their position in the sentence, in the d-dimensional space.

See the notebook on positional encoding to learn more about it. The formula for calculating the positional encoding is as follows:

$$\Large{PE_{(pos, 2i)} = \sin(pos / 10000^{2i / d_{model}})}$$
$$\Large{PE_{(pos, 2i+1)} = \cos(pos / 10000^{2i / d_{model}})}$$
def get_angles(pos, i, d_model):
  angle_rates = 1 / np.power(10000, (2 * (i//2)) / np.float32(d_model))
  return pos * angle_rates
def positional_encoding(position, d_model):
  angle_rads = get_angles(np.arange(position)[:, np.newaxis],
                          np.arange(d_model)[np.newaxis, :],
                          d_model)

  # apply sin to even indices in the array; 2i
  angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])

  # apply cos to odd indices in the array; 2i+1
  angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])

  pos_encoding = angle_rads[np.newaxis, ...]

  return tf.cast(pos_encoding, dtype=tf.float32)
pos_encoding = positional_encoding(50, 512)
print (pos_encoding.shape)

plt.pcolormesh(pos_encoding[0], cmap='RdBu')
plt.xlabel('Depth')
plt.xlim((0, 512))
plt.ylabel('Position')
plt.colorbar()
plt.show()
(1, 50, 512)

[Figure: the positional encoding matrix, with depth on the x-axis and position on the y-axis]

Masking

Mask all the pad tokens in the batch of sequences. This ensures that the model does not treat padding as input. The mask indicates where the pad value 0 is present: it outputs a 1 at those locations, and a 0 otherwise.

def create_padding_mask(seq):
  seq = tf.cast(tf.math.equal(seq, 0), tf.float32)

  # add extra dimensions to add the padding
  # to the attention logits.
  return seq[:, tf.newaxis, tf.newaxis, :]  # (batch_size, 1, 1, seq_len)
x = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
create_padding_mask(x)
<tf.Tensor: shape=(3, 1, 1, 5), dtype=float32, numpy=
array([[[[0., 0., 1., 1., 0.]]],


       [[[0., 0., 0., 1., 1.]]],


       [[[1., 1., 1., 0., 0.]]]], dtype=float32)>

The look-ahead mask is used to mask the future tokens in a sequence. In other words, the mask indicates which entries should not be used.

This means that to predict the third word, only the first and second words will be used. Similarly, to predict the fourth word, only the first, second, and third words will be used, and so on.

def create_look_ahead_mask(size):
  mask = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)
  return mask  # (seq_len, seq_len)
x = tf.random.uniform((1, 3))
temp = create_look_ahead_mask(x.shape[1])
temp
<tf.Tensor: shape=(3, 3), dtype=float32, numpy=
array([[0., 1., 1.],
       [0., 0., 1.],
       [0., 0., 0.]], dtype=float32)>

Scaled dot product attention

[Figure: scaled dot product attention]

The attention function used by the transformer takes three inputs: Q (query), K (key), V (value). The equation used to calculate the attention weights is:

$$\Large{Attention(Q, K, V) = softmax_k(\frac{QK^T}{\sqrt{d_k} }) V} $$

The dot-product attention is scaled by a factor of the square root of the depth. This is done because for large values of depth, the dot product grows large in magnitude, pushing the softmax into regions where it has small gradients and producing a very hard (nearly one-hot) softmax.

For example, consider that Q and K have a mean of 0 and variance of 1. Their matrix multiplication will have a mean of 0 and a variance of dk. Hence, the square root of dk is used for scaling (and not any other number), so that the matmul of Q and K keeps a mean of 0 and a variance of 1, and you get a gentler softmax.

The mask is multiplied with -1e9 (close to negative infinity). This is done because the mask is summed with the scaled matrix multiplication of Q and K and is applied immediately before a softmax. The goal is to zero out these cells, and large negative inputs to softmax are near zero in the output.
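
To make the variance and masking claims concrete, here is a small numerical check (illustrative only, not part of the model code; the sample sizes are arbitrary and the entries of q and k are assumed i.i.d. standard normal):

import numpy as np
import tensorflow as tf

# Raw dot products have variance close to d_k; dividing by sqrt(d_k)
# brings the variance back to about 1.
d_k = 512
q = np.random.normal(size=(1000, d_k)).astype(np.float32)
k = np.random.normal(size=(1000, d_k)).astype(np.float32)

logits = q @ k.T
print(np.var(logits))                  # ~512 (about d_k)
print(np.var(logits / np.sqrt(d_k)))   # ~1

# A logit of -1e9 contributes essentially zero weight after the softmax.
print(tf.nn.softmax(tf.constant([1.0, 2.0, -1e9])).numpy())  # ~[0.27 0.73 0.  ]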

def scaled_dot_product_attention(q, k, v, mask):
  """Calculate the attention weights.
  q, k, v must have matching leading dimensions.
  k, v must have matching penultimate dimension, i.e.: seq_len_k = seq_len_v.
  The mask has different shapes depending on its type(padding or look ahead) 
  but it must be broadcastable for addition.

  Args:
    q: query shape == (..., seq_len_q, depth)
    k: key shape == (..., seq_len_k, depth)
    v: value shape == (..., seq_len_v, depth_v)
    mask: Float tensor with shape broadcastable 
          to (..., seq_len_q, seq_len_k). Defaults to None.

  Returns:
    output, attention_weights
  """

  matmul_qk = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)

  # scale matmul_qk
  dk = tf.cast(tf.shape(k)[-1], tf.float32)
  scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

  # add the mask to the scaled tensor.
  if mask is not None:
    scaled_attention_logits += (mask * -1e9)  

  # softmax is normalized on the last axis (seq_len_k) so that the scores
  # add up to 1.
  attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)  # (..., seq_len_q, seq_len_k)

  output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)

  return output, attention_weights

As the softmax normalization is done along the K (key) dimension, its values decide the amount of importance given to each key for a given query.

The output represents the multiplication of the attention weights and the V (value) vector. This ensures that the words you want to focus on are kept as-is and the irrelevant words are flushed out.

def print_out(q, k, v):
  temp_out, temp_attn = scaled_dot_product_attention(
      q, k, v, None)
  print ('Attention weights are:')
  print (temp_attn)
  print ('Output is:')
  print (temp_out)
np.set_printoptions(suppress=True)

temp_k = tf.constant([[10,0,0],
                      [0,10,0],
                      [0,0,10],
                      [0,0,10]], dtype=tf.float32)  # (4, 3)

temp_v = tf.constant([[   1,0],
                      [  10,0],
                      [ 100,5],
                      [1000,6]], dtype=tf.float32)  # (4, 2)

# This `query` aligns with the second `key`,
# so the second `value` is returned.
temp_q = tf.constant([[0, 10, 0]], dtype=tf.float32)  # (1, 3)
print_out(temp_q, temp_k, temp_v)
Attention weights are:
tf.Tensor([[0. 1. 0. 0.]], shape=(1, 4), dtype=float32)
Output is:
tf.Tensor([[10.  0.]], shape=(1, 2), dtype=float32)

# This query aligns with a repeated key (third and fourth), 
# so all associated values get averaged.
temp_q = tf.constant([[0, 0, 10]], dtype=tf.float32)  # (1, 3)
print_out(temp_q, temp_k, temp_v)
Attention weights are:
tf.Tensor([[0.  0.  0.5 0.5]], shape=(1, 4), dtype=float32)
Output is:
tf.Tensor([[550.    5.5]], shape=(1, 2), dtype=float32)

# This query aligns equally with the first and second key, 
# so their values get averaged.
temp_q = tf.constant([[10, 10, 0]], dtype=tf.float32)  # (1, 3)
print_out(temp_q, temp_k, temp_v)
Attention weights are:
tf.Tensor([[0.5 0.5 0.  0. ]], shape=(1, 4), dtype=float32)
Output is:
tf.Tensor([[5.5 0. ]], shape=(1, 2), dtype=float32)

Pass all the queries together.

temp_q = tf.constant([[0, 0, 10], [0, 10, 0], [10, 10, 0]], dtype=tf.float32)  # (3, 3)
print_out(temp_q, temp_k, temp_v)
Attention weights are:
tf.Tensor(
[[0.  0.  0.5 0.5]
 [0.  1.  0.  0. ]
 [0.5 0.5 0.  0. ]], shape=(3, 4), dtype=float32)
Output is:
tf.Tensor(
[[550.    5.5]
 [ 10.    0. ]
 [  5.5   0. ]], shape=(3, 2), dtype=float32)

Multi-head attention

[Figure: multi-head attention]

Multi-head attention consists of four parts:

  • Linear layers and split into heads.
  • Scaled dot-product attention.
  • Concatenation of heads.
  • Final linear layer.

Each multi-head attention block gets three inputs: Q (query), K (key), V (value). These are put through linear (Dense) layers and split up into multiple heads.

The scaled_dot_product_attention defined above is applied to each head (broadcasted for efficiency). An appropriate mask must be used in the attention step. The attention output for each head is then concatenated (using tf.transpose, and tf.reshape) and put through a final Dense layer.

Instead of one single attention head, Q, K, and V are split into multiple heads because it allows the model to jointly attend to information at different positions from different representational spaces. After the split each head has a reduced dimensionality, so the total computation cost is the same as a single head attention with full dimensionality.

class MultiHeadAttention(tf.keras.layers.Layer):
  def __init__(self, d_model, num_heads):
    super(MultiHeadAttention, self).__init__()
    self.num_heads = num_heads
    self.d_model = d_model

    assert d_model % self.num_heads == 0

    self.depth = d_model // self.num_heads

    self.wq = tf.keras.layers.Dense(d_model)
    self.wk = tf.keras.layers.Dense(d_model)
    self.wv = tf.keras.layers.Dense(d_model)

    self.dense = tf.keras.layers.Dense(d_model)

  def split_heads(self, x, batch_size):
    """Split the last dimension into (num_heads, depth).
    Transpose the result such that the shape is (batch_size, num_heads, seq_len, depth)
    """
    x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
    return tf.transpose(x, perm=[0, 2, 1, 3])

  def call(self, v, k, q, mask):
    batch_size = tf.shape(q)[0]

    q = self.wq(q)  # (batch_size, seq_len, d_model)
    k = self.wk(k)  # (batch_size, seq_len, d_model)
    v = self.wv(v)  # (batch_size, seq_len, d_model)

    q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
    k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
    v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth)

    # scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth)
    # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)
    scaled_attention, attention_weights = scaled_dot_product_attention(
        q, k, v, mask)

    scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])  # (batch_size, seq_len_q, num_heads, depth)

    concat_attention = tf.reshape(scaled_attention, 
                                  (batch_size, -1, self.d_model))  # (batch_size, seq_len_q, d_model)

    output = self.dense(concat_attention)  # (batch_size, seq_len_q, d_model)

    return output, attention_weights

Create a MultiHeadAttention layer to try out. At each location in the sequence, y, the MultiHeadAttention runs all 8 attention heads across all other locations in the sequence, returning a new vector of the same length at each location.

temp_mha = MultiHeadAttention(d_model=512, num_heads=8)
y = tf.random.uniform((1, 60, 512))  # (batch_size, encoder_sequence, d_model)
out, attn = temp_mha(y, k=y, q=y, mask=None)
out.shape, attn.shape
(TensorShape([1, 60, 512]), TensorShape([1, 8, 60, 60]))

Point wise feed forward network

The point wise feed forward network consists of two fully-connected layers with a ReLU activation in between.

def point_wise_feed_forward_network(d_model, dff):
  return tf.keras.Sequential([
      tf.keras.layers.Dense(dff, activation='relu'),  # (batch_size, seq_len, dff)
      tf.keras.layers.Dense(d_model)  # (batch_size, seq_len, d_model)
  ])
sample_ffn = point_wise_feed_forward_network(512, 2048)
sample_ffn(tf.random.uniform((64, 50, 512))).shape
TensorShape([64, 50, 512])

Encoder and decoder

[Figure: the transformer architecture]

The transformer model follows the same general pattern as a standard sequence-to-sequence with attention model.

  • The input sentence is passed through N encoder layers that generate an output for each word/token in the sequence.
  • The decoder attends on the encoder's output and its own input (self-attention) to predict the next word.

Encoder layer

Each encoder layer consists of sublayers:

  1. Multi-head attention (with padding mask)
  2. Point wise feed forward networks.

Each of these sublayers has a residual connection around it followed by a layer normalization. Residual connections help in avoiding the vanishing gradient problem in deep networks.

The output of each sublayer is LayerNorm(x + Sublayer(x)). The normalization is done on the d_model (last) axis. There are N encoder layers in the transformer.

class EncoderLayer(tf.keras.layers.Layer):
  def __init__(self, d_model, num_heads, dff, rate=0.1):
    super(EncoderLayer, self).__init__()

    self.mha = MultiHeadAttention(d_model, num_heads)
    self.ffn = point_wise_feed_forward_network(d_model, dff)

    self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

    self.dropout1 = tf.keras.layers.Dropout(rate)
    self.dropout2 = tf.keras.layers.Dropout(rate)

  def call(self, x, training, mask):

    attn_output, _ = self.mha(x, x, x, mask)  # (batch_size, input_seq_len, d_model)
    attn_output = self.dropout1(attn_output, training=training)
    out1 = self.layernorm1(x + attn_output)  # (batch_size, input_seq_len, d_model)

    ffn_output = self.ffn(out1)  # (batch_size, input_seq_len, d_model)
    ffn_output = self.dropout2(ffn_output, training=training)
    out2 = self.layernorm2(out1 + ffn_output)  # (batch_size, input_seq_len, d_model)

    return out2
sample_encoder_layer = EncoderLayer(512, 8, 2048)

sample_encoder_layer_output = sample_encoder_layer(
    tf.random.uniform((64, 43, 512)), False, None)

sample_encoder_layer_output.shape  # (batch_size, input_seq_len, d_model)
TensorShape([64, 43, 512])

Decoder layer

Each decoder layer consists of sublayers:

  1. Masked multi-head attention (with look ahead mask and padding mask)
  2. Multi-head attention (with padding mask). V (value) and K (key) receive the encoder output as inputs. Q (query) receives the output from the masked multi-head attention sublayer.
  3. Point wise feed forward networks

Each of these sublayers has a residual connection around it followed by a layer normalization. The output of each sublayer is LayerNorm(x + Sublayer(x)). The normalization is done on the d_model (last) axis.

There are N decoder layers in the transformer.

As Q receives the output from decoder's first attention block, and K receives the encoder output, the attention weights represent the importance given to the decoder's input based on the encoder's output. In other words, the decoder predicts the next word by looking at the encoder output and self-attending to its own output. See the demonstration above in the scaled dot product attention section.

class DecoderLayer(tf.keras.layers.Layer):
  def __init__(self, d_model, num_heads, dff, rate=0.1):
    super(DecoderLayer, self).__init__()

    self.mha1 = MultiHeadAttention(d_model, num_heads)
    self.mha2 = MultiHeadAttention(d_model, num_heads)

    self.ffn = point_wise_feed_forward_network(d_model, dff)

    self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    self.layernorm3 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

    self.dropout1 = tf.keras.layers.Dropout(rate)
    self.dropout2 = tf.keras.layers.Dropout(rate)
    self.dropout3 = tf.keras.layers.Dropout(rate)


  def call(self, x, enc_output, training, 
           look_ahead_mask, padding_mask):
    # enc_output.shape == (batch_size, input_seq_len, d_model)

    attn1, attn_weights_block1 = self.mha1(x, x, x, look_ahead_mask)  # (batch_size, target_seq_len, d_model)
    attn1 = self.dropout1(attn1, training=training)
    out1 = self.layernorm1(attn1 + x)

    attn2, attn_weights_block2 = self.mha2(
        enc_output, enc_output, out1, padding_mask)  # (batch_size, target_seq_len, d_model)
    attn2 = self.dropout2(attn2, training=training)
    out2 = self.layernorm2(attn2 + out1)  # (batch_size, target_seq_len, d_model)

    ffn_output = self.ffn(out2)  # (batch_size, target_seq_len, d_model)
    ffn_output = self.dropout3(ffn_output, training=training)
    out3 = self.layernorm3(ffn_output + out2)  # (batch_size, target_seq_len, d_model)

    return out3, attn_weights_block1, attn_weights_block2
sample_decoder_layer = DecoderLayer(512, 8, 2048)

sample_decoder_layer_output, _, _ = sample_decoder_layer(
    tf.random.uniform((64, 50, 512)), sample_encoder_layer_output, 
    False, None, None)

sample_decoder_layer_output.shape  # (batch_size, target_seq_len, d_model)
TensorShape([64, 50, 512])

Encoder

The Encoder consists of:

  1. Input Embedding
  2. Positional Encoding
  3. N encoder layers

The input is put through an embedding which is summed with the positional encoding. The output of this summation is the input to the encoder layers. The output of the encoder is the input to the decoder.

class Encoder(tf.keras.layers.Layer):
  def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size,
               maximum_position_encoding, rate=0.1):
    super(Encoder, self).__init__()

    self.d_model = d_model
    self.num_layers = num_layers

    self.embedding = tf.keras.layers.Embedding(input_vocab_size, d_model)
    self.pos_encoding = positional_encoding(maximum_position_encoding, 
                                            self.d_model)


    self.enc_layers = [EncoderLayer(d_model, num_heads, dff, rate) 
                       for _ in range(num_layers)]

    self.dropout = tf.keras.layers.Dropout(rate)

  def call(self, x, training, mask):

    seq_len = tf.shape(x)[1]

    # adding embedding and position encoding.
    x = self.embedding(x)  # (batch_size, input_seq_len, d_model)
    x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
    x += self.pos_encoding[:, :seq_len, :]

    x = self.dropout(x, training=training)

    for i in range(self.num_layers):
      x = self.enc_layers[i](x, training, mask)

    return x  # (batch_size, input_seq_len, d_model)
sample_encoder = Encoder(num_layers=2, d_model=512, num_heads=8, 
                         dff=2048, input_vocab_size=8500,
                         maximum_position_encoding=10000)
temp_input = tf.random.uniform((64, 62), dtype=tf.int64, minval=0, maxval=200)

sample_encoder_output = sample_encoder(temp_input, training=False, mask=None)

print (sample_encoder_output.shape)  # (batch_size, input_seq_len, d_model)
(64, 62, 512)

Decoder

The Decoder consists of:

  1. Output Embedding
  2. Positional Encoding
  3. N decoder layers

The target is put through an embedding which is summed with the positional encoding. The output of this summation is the input to the decoder layers. The output of the decoder is the input to the final linear layer.

class Decoder(tf.keras.layers.Layer):
  def __init__(self, num_layers, d_model, num_heads, dff, target_vocab_size,
               maximum_position_encoding, rate=0.1):
    super(Decoder, self).__init__()

    self.d_model = d_model
    self.num_layers = num_layers

    self.embedding = tf.keras.layers.Embedding(target_vocab_size, d_model)
    self.pos_encoding = positional_encoding(maximum_position_encoding, d_model)

    self.dec_layers = [DecoderLayer(d_model, num_heads, dff, rate) 
                       for _ in range(num_layers)]
    self.dropout = tf.keras.layers.Dropout(rate)

  def call(self, x, enc_output, training, 
           look_ahead_mask, padding_mask):

    seq_len = tf.shape(x)[1]
    attention_weights = {}

    x = self.embedding(x)  # (batch_size, target_seq_len, d_model)
    x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
    x += self.pos_encoding[:, :seq_len, :]

    x = self.dropout(x, training=training)

    for i in range(self.num_layers):
      x, block1, block2 = self.dec_layers[i](x, enc_output, training,
                                             look_ahead_mask, padding_mask)

      attention_weights['decoder_layer{}_block1'.format(i+1)] = block1
      attention_weights['decoder_layer{}_block2'.format(i+1)] = block2

    # x.shape == (batch_size, target_seq_len, d_model)
    return x, attention_weights
sample_decoder = Decoder(num_layers=2, d_model=512, num_heads=8, 
                         dff=2048, target_vocab_size=8000,
                         maximum_position_encoding=5000)
temp_input = tf.random.uniform((64, 26), dtype=tf.int64, minval=0, maxval=200)

output, attn = sample_decoder(temp_input, 
                              enc_output=sample_encoder_output, 
                              training=False,
                              look_ahead_mask=None, 
                              padding_mask=None)

output.shape, attn['decoder_layer2_block2'].shape
(TensorShape([64, 26, 512]), TensorShape([64, 8, 26, 62]))

Create the Transformer

The Transformer consists of the encoder, the decoder, and a final linear layer. The output of the decoder is the input to the linear layer, and its output is returned.

class Transformer(tf.keras.Model):
  def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size, 
               target_vocab_size, pe_input, pe_target, rate=0.1):
    super(Transformer, self).__init__()

    self.encoder = Encoder(num_layers, d_model, num_heads, dff, 
                           input_vocab_size, pe_input, rate)

    self.decoder = Decoder(num_layers, d_model, num_heads, dff, 
                           target_vocab_size, pe_target, rate)

    self.final_layer = tf.keras.layers.Dense(target_vocab_size)

  def call(self, inp, tar, training, enc_padding_mask, 
           look_ahead_mask, dec_padding_mask):

    enc_output = self.encoder(inp, training, enc_padding_mask)  # (batch_size, inp_seq_len, d_model)

    # dec_output.shape == (batch_size, tar_seq_len, d_model)
    dec_output, attention_weights = self.decoder(
        tar, enc_output, training, look_ahead_mask, dec_padding_mask)

    final_output = self.final_layer(dec_output)  # (batch_size, tar_seq_len, target_vocab_size)

    return final_output, attention_weights
sample_transformer = Transformer(
    num_layers=2, d_model=512, num_heads=8, dff=2048, 
    input_vocab_size=8500, target_vocab_size=8000, 
    pe_input=10000, pe_target=6000)

temp_input = tf.random.uniform((64, 38), dtype=tf.int64, minval=0, maxval=200)
temp_target = tf.random.uniform((64, 36), dtype=tf.int64, minval=0, maxval=200)

fn_out, _ = sample_transformer(temp_input, temp_target, training=False, 
                               enc_padding_mask=None, 
                               look_ahead_mask=None,
                               dec_padding_mask=None)

fn_out.shape  # (batch_size, tar_seq_len, target_vocab_size)
TensorShape([64, 36, 8000])

Set hyperparameters

To keep this example small and relatively fast, the values for num_layers, d_model, and dff have been reduced.

The values used in the base model of the transformer were: num_layers=6, d_model=512, dff=2048. See the paper for all the other versions of the transformer.

num_layers = 4
d_model = 128
dff = 512
num_heads = 8

input_vocab_size = tokenizer_pt.vocab_size + 2
target_vocab_size = tokenizer_en.vocab_size + 2
dropout_rate = 0.1

Optimizer

Use the Adam optimizer with a custom learning rate scheduler according to the formula in the paper.

$$\Large{lrate = d_{model}^{-0.5} \cdot \min(step{\_}num^{-0.5},\ step{\_}num \cdot warmup{\_}steps^{-1.5})}$$
class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
  def __init__(self, d_model, warmup_steps=4000):
    super(CustomSchedule, self).__init__()

    self.d_model = d_model
    self.d_model = tf.cast(self.d_model, tf.float32)

    self.warmup_steps = warmup_steps

  def __call__(self, step):
    arg1 = tf.math.rsqrt(step)
    arg2 = step * (self.warmup_steps ** -1.5)

    return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)
learning_rate = CustomSchedule(d_model)

optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98, 
                                     epsilon=1e-9)
temp_learning_rate_schedule = CustomSchedule(d_model)

plt.plot(temp_learning_rate_schedule(tf.range(40000, dtype=tf.float32)))
plt.ylabel("Learning Rate")
plt.xlabel("Train Step")
Text(0.5, 0, 'Train Step')

[Figure: the learning rate schedule plotted against the train step]
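
The two arguments inside the min are equal at step = warmup_steps, so the schedule increases linearly for the first warmup_steps steps and then decays proportionally to the inverse square root of the step number. The peak value is d_model**-0.5 * warmup_steps**-0.5, which for d_model=128 and warmup_steps=4000 is roughly 1.4e-3. A quick check (illustrative only):

# Evaluate the schedule at its peak, where step == warmup_steps.
print(CustomSchedule(d_model=128)(tf.constant(4000.0)).numpy())  # ~0.0014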

Loss and metrics

Since the target sequences are padded, it is important to apply a padding mask when calculating the loss.

loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction='none')
def loss_function(real, pred):
  mask = tf.math.logical_not(tf.math.equal(real, 0))
  loss_ = loss_object(real, pred)

  mask = tf.cast(mask, dtype=loss_.dtype)
  loss_ *= mask

  return tf.reduce_sum(loss_)/tf.reduce_sum(mask)


def accuracy_function(real, pred):
  accuracies = tf.equal(real, tf.argmax(pred, axis=2))

  mask = tf.math.logical_not(tf.math.equal(real, 0))
  accuracies = tf.math.logical_and(mask, accuracies)

  accuracies = tf.cast(accuracies, dtype=tf.float32)
  mask = tf.cast(mask, dtype=tf.float32)
  return tf.reduce_sum(accuracies)/tf.reduce_sum(mask)
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.Mean(name='train_accuracy')

Training and checkpointing

transformer = Transformer(num_layers, d_model, num_heads, dff,
                          input_vocab_size, target_vocab_size, 
                          pe_input=input_vocab_size, 
                          pe_target=target_vocab_size,
                          rate=dropout_rate)
def create_masks(inp, tar):
  # Encoder padding mask
  enc_padding_mask = create_padding_mask(inp)

  # Used in the 2nd attention block in the decoder.
  # This padding mask is used to mask the encoder outputs.
  dec_padding_mask = create_padding_mask(inp)

  # Used in the 1st attention block in the decoder.
  # It is used to pad and mask future tokens in the input received by 
  # the decoder.
  look_ahead_mask = create_look_ahead_mask(tf.shape(tar)[1])
  dec_target_padding_mask = create_padding_mask(tar)
  combined_mask = tf.maximum(dec_target_padding_mask, look_ahead_mask)

  return enc_padding_mask, combined_mask, dec_padding_mask

Create the checkpoint path and the checkpoint manager. This will be used to save checkpoints every n epochs.

checkpoint_path = "./checkpoints/train"

ckpt = tf.train.Checkpoint(transformer=transformer,
                           optimizer=optimizer)

ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# if a checkpoint exists, restore the latest checkpoint.
if ckpt_manager.latest_checkpoint:
  ckpt.restore(ckpt_manager.latest_checkpoint)
  print ('Latest checkpoint restored!!')

The target is divided into tar_inp and tar_real. tar_inp is passed as an input to the decoder. tar_real is that same input shifted by 1: at each location in tar_inp, tar_real contains the next token that should be predicted.

For example, sentence = "SOS A lion in the jungle is sleeping EOS"

tar_inp = "SOS A lion in the jungle is sleeping"

tar_real = "A lion in the jungle is sleeping EOS"

The transformer is an auto-regressive model: it makes predictions one part at a time, and uses its output so far to decide what to do next.

During training this example uses teacher-forcing (like in the text generation tutorial). Teacher forcing is passing the true output to the next time step regardless of what the model predicts at the current time step.

As the transformer predicts each word, self-attention allows it to look at the previous words in the input sequence to better predict the next word.

To prevent the model from peeking at the expected output, the model uses a look-ahead mask.
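
To see the target shift and the combined mask on a toy example (the token IDs are made up; 0 is the pad token), you can run the helpers defined above by hand:

# Illustrative only: one padded target sequence of length 5.
tar = tf.constant([[5, 17, 23, 0, 0]], dtype=tf.int64)

tar_inp = tar[:, :-1]   # [[ 5 17 23  0]] -- fed to the decoder
tar_real = tar[:, 1:]   # [[17 23  0  0]] -- what the decoder should predict

look_ahead = create_look_ahead_mask(tf.shape(tar_inp)[1])
padding = create_padding_mask(tar_inp)
combined = tf.maximum(padding, look_ahead)
print(combined[0, 0].numpy())
# [[0. 1. 1. 1.]
#  [0. 0. 1. 1.]
#  [0. 0. 0. 1.]
#  [0. 0. 0. 1.]]   1 = masked: future positions and the trailing pad position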

EPOCHS = 20
# The @tf.function trace-compiles train_step into a TF graph for faster
# execution. The function specializes to the precise shape of the argument
# tensors. To avoid re-tracing due to the variable sequence lengths or variable
# batch sizes (the last batch is smaller), use input_signature to specify
# more generic shapes.

train_step_signature = [
    tf.TensorSpec(shape=(None, None), dtype=tf.int64),
    tf.TensorSpec(shape=(None, None), dtype=tf.int64),
]

@tf.function(input_signature=train_step_signature)
def train_step(inp, tar):
  tar_inp = tar[:, :-1]
  tar_real = tar[:, 1:]

  enc_padding_mask, combined_mask, dec_padding_mask = create_masks(inp, tar_inp)

  with tf.GradientTape() as tape:
    predictions, _ = transformer(inp, tar_inp, 
                                 True, 
                                 enc_padding_mask, 
                                 combined_mask, 
                                 dec_padding_mask)
    loss = loss_function(tar_real, predictions)

  gradients = tape.gradient(loss, transformer.trainable_variables)    
  optimizer.apply_gradients(zip(gradients, transformer.trainable_variables))

  train_loss(loss)
  train_accuracy(accuracy_function(tar_real, predictions))

Portuguese is used as the input language and English is the target language.

for epoch in range(EPOCHS):
  start = time.time()

  train_loss.reset_states()
  train_accuracy.reset_states()

  # inp -> portuguese, tar -> english
  for (batch, (inp, tar)) in enumerate(train_dataset):
    train_step(inp, tar)

    if batch % 50 == 0:
      print ('Epoch {} Batch {} Loss {:.4f} Accuracy {:.4f}'.format(
          epoch + 1, batch, train_loss.result(), train_accuracy.result()))

  if (epoch + 1) % 5 == 0:
    ckpt_save_path = ckpt_manager.save()
    print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,
                                                         ckpt_save_path))

  print ('Epoch {} Loss {:.4f} Accuracy {:.4f}'.format(epoch + 1, 
                                                train_loss.result(), 
                                                train_accuracy.result()))

  print ('Time taken for 1 epoch: {} secs\n'.format(time.time() - start))
Epoch 1 Batch 0 Loss 9.0198 Accuracy 0.0000
Epoch 1 Batch 50 Loss 8.9566 Accuracy 0.0060
Epoch 1 Batch 100 Loss 8.8567 Accuracy 0.0306
Epoch 1 Batch 150 Loss 8.7498 Accuracy 0.0381
Epoch 1 Batch 200 Loss 8.6213 Accuracy 0.0431
Epoch 1 Batch 250 Loss 8.4676 Accuracy 0.0528
Epoch 1 Batch 300 Loss 8.2941 Accuracy 0.0630
Epoch 1 Batch 350 Loss 8.1103 Accuracy 0.0710
Epoch 1 Batch 400 Loss 7.9283 Accuracy 0.0773
Epoch 1 Batch 450 Loss 7.7631 Accuracy 0.0825
Epoch 1 Batch 500 Loss 7.6183 Accuracy 0.0881
Epoch 1 Batch 550 Loss 7.4868 Accuracy 0.0939
Epoch 1 Batch 600 Loss 7.3640 Accuracy 0.1007
Epoch 1 Batch 650 Loss 7.2441 Accuracy 0.1076
Epoch 1 Batch 700 Loss 7.1291 Accuracy 0.1146
Epoch 1 Loss 7.1246 Accuracy 0.1149
Time taken for 1 epoch: 356.5545802116394 secs

Epoch 2 Batch 0 Loss 5.3808 Accuracy 0.2259
Epoch 2 Batch 50 Loss 5.4921 Accuracy 0.2191
Epoch 2 Batch 100 Loss 5.4470 Accuracy 0.2231
Epoch 2 Batch 150 Loss 5.4008 Accuracy 0.2274
Epoch 2 Batch 200 Loss 5.3479 Accuracy 0.2311
Epoch 2 Batch 250 Loss 5.3001 Accuracy 0.2356
Epoch 2 Batch 300 Loss 5.2542 Accuracy 0.2389
Epoch 2 Batch 350 Loss 5.2172 Accuracy 0.2424
Epoch 2 Batch 400 Loss 5.1819 Accuracy 0.2459
Epoch 2 Batch 450 Loss 5.1527 Accuracy 0.2487
Epoch 2 Batch 500 Loss 5.1213 Accuracy 0.2519
Epoch 2 Batch 550 Loss 5.0927 Accuracy 0.2546
Epoch 2 Batch 600 Loss 5.0653 Accuracy 0.2572
Epoch 2 Batch 650 Loss 5.0412 Accuracy 0.2595
Epoch 2 Batch 700 Loss 5.0185 Accuracy 0.2620
Epoch 2 Loss 5.0170 Accuracy 0.2621
Time taken for 1 epoch: 329.01824021339417 secs

Epoch 3 Batch 0 Loss 4.6641 Accuracy 0.2917
Epoch 3 Batch 50 Loss 4.6428 Accuracy 0.2983
Epoch 3 Batch 100 Loss 4.6202 Accuracy 0.3008
Epoch 3 Batch 150 Loss 4.6122 Accuracy 0.3018
Epoch 3 Batch 200 Loss 4.5992 Accuracy 0.3031
Epoch 3 Batch 250 Loss 4.5891 Accuracy 0.3047
Epoch 3 Batch 300 Loss 4.5715 Accuracy 0.3064
Epoch 3 Batch 350 Loss 4.5585 Accuracy 0.3076
Epoch 3 Batch 400 Loss 4.5441 Accuracy 0.3094
Epoch 3 Batch 450 Loss 4.5279 Accuracy 0.3112
Epoch 3 Batch 500 Loss 4.5145 Accuracy 0.3128
Epoch 3 Batch 550 Loss 4.5004 Accuracy 0.3142
Epoch 3 Batch 600 Loss 4.4880 Accuracy 0.3159
Epoch 3 Batch 650 Loss 4.4744 Accuracy 0.3177
Epoch 3 Batch 700 Loss 4.4601 Accuracy 0.3193
Epoch 3 Loss 4.4592 Accuracy 0.3194
Time taken for 1 epoch: 329.6653571128845 secs

Epoch 4 Batch 0 Loss 4.1498 Accuracy 0.3394
Epoch 4 Batch 50 Loss 4.1801 Accuracy 0.3493
Epoch 4 Batch 100 Loss 4.1397 Accuracy 0.3521
Epoch 4 Batch 150 Loss 4.1312 Accuracy 0.3539
Epoch 4 Batch 200 Loss 4.1132 Accuracy 0.3569
Epoch 4 Batch 250 Loss 4.1008 Accuracy 0.3584
Epoch 4 Batch 300 Loss 4.0835 Accuracy 0.3612
Epoch 4 Batch 350 Loss 4.0667 Accuracy 0.3635
Epoch 4 Batch 400 Loss 4.0514 Accuracy 0.3653
Epoch 4 Batch 450 Loss 4.0328 Accuracy 0.3679
Epoch 4 Batch 500 Loss 4.0187 Accuracy 0.3701
Epoch 4 Batch 550 Loss 4.0050 Accuracy 0.3720
Epoch 4 Batch 600 Loss 3.9889 Accuracy 0.3741
Epoch 4 Batch 650 Loss 3.9719 Accuracy 0.3764
Epoch 4 Batch 700 Loss 3.9573 Accuracy 0.3787
Epoch 4 Loss 3.9571 Accuracy 0.3787
Time taken for 1 epoch: 329.5320518016815 secs

Epoch 5 Batch 0 Loss 3.5488 Accuracy 0.4171
Epoch 5 Batch 50 Loss 3.6103 Accuracy 0.4169
Epoch 5 Batch 100 Loss 3.5825 Accuracy 0.4221
Epoch 5 Batch 150 Loss 3.5914 Accuracy 0.4221
Epoch 5 Batch 200 Loss 3.5879 Accuracy 0.4233
Epoch 5 Batch 250 Loss 3.5773 Accuracy 0.4248
Epoch 5 Batch 300 Loss 3.5681 Accuracy 0.4263
Epoch 5 Batch 350 Loss 3.5531 Accuracy 0.4286
Epoch 5 Batch 400 Loss 3.5416 Accuracy 0.4302
Epoch 5 Batch 450 Loss 3.5302 Accuracy 0.4319
Epoch 5 Batch 500 Loss 3.5215 Accuracy 0.4332
Epoch 5 Batch 550 Loss 3.5112 Accuracy 0.4347
Epoch 5 Batch 600 Loss 3.5006 Accuracy 0.4361
Epoch 5 Batch 650 Loss 3.4935 Accuracy 0.4369
Epoch 5 Batch 700 Loss 3.4868 Accuracy 0.4376
Saving checkpoint for epoch 5 at ./checkpoints/train/ckpt-1
Epoch 5 Loss 3.4866 Accuracy 0.4377
Time taken for 1 epoch: 329.0690224170685 secs

Epoch 6 Batch 0 Loss 3.2325 Accuracy 0.4480
Epoch 6 Batch 50 Loss 3.2148 Accuracy 0.4647
Epoch 6 Batch 100 Loss 3.1823 Accuracy 0.4689
Epoch 6 Batch 150 Loss 3.1713 Accuracy 0.4697
Epoch 6 Batch 200 Loss 3.1619 Accuracy 0.4717
Epoch 6 Batch 250 Loss 3.1487 Accuracy 0.4737
Epoch 6 Batch 300 Loss 3.1402 Accuracy 0.4750
Epoch 6 Batch 350 Loss 3.1383 Accuracy 0.4749
Epoch 6 Batch 400 Loss 3.1291 Accuracy 0.4761
Epoch 6 Batch 450 Loss 3.1239 Accuracy 0.4766
Epoch 6 Batch 500 Loss 3.1190 Accuracy 0.4773
Epoch 6 Batch 550 Loss 3.1155 Accuracy 0.4775
Epoch 6 Batch 600 Loss 3.1079 Accuracy 0.4786
Epoch 6 Batch 650 Loss 3.0996 Accuracy 0.4798
Epoch 6 Batch 700 Loss 3.0910 Accuracy 0.4809
Epoch 6 Loss 3.0901 Accuracy 0.4810
Time taken for 1 epoch: 329.31779623031616 secs

Epoch 7 Batch 0 Loss 2.6800 Accuracy 0.5217
Epoch 7 Batch 50 Loss 2.7518 Accuracy 0.5169
Epoch 7 Batch 100 Loss 2.7756 Accuracy 0.5133
Epoch 7 Batch 150 Loss 2.7687 Accuracy 0.5150
Epoch 7 Batch 200 Loss 2.7671 Accuracy 0.5156
Epoch 7 Batch 250 Loss 2.7552 Accuracy 0.5175
Epoch 7 Batch 300 Loss 2.7445 Accuracy 0.5193
Epoch 7 Batch 350 Loss 2.7383 Accuracy 0.5202
Epoch 7 Batch 400 Loss 2.7309 Accuracy 0.5213
Epoch 7 Batch 450 Loss 2.7273 Accuracy 0.5217
Epoch 7 Batch 500 Loss 2.7227 Accuracy 0.5224
Epoch 7 Batch 550 Loss 2.7165 Accuracy 0.5233
Epoch 7 Batch 600 Loss 2.7098 Accuracy 0.5242
Epoch 7 Batch 650 Loss 2.7059 Accuracy 0.5248
Epoch 7 Batch 700 Loss 2.7015 Accuracy 0.5255
Epoch 7 Loss 2.7009 Accuracy 0.5255
Time taken for 1 epoch: 328.4230201244354 secs

Epoch 8 Batch 0 Loss 2.4337 Accuracy 0.5464
Epoch 8 Batch 50 Loss 2.3840 Accuracy 0.5613
Epoch 8 Batch 100 Loss 2.3731 Accuracy 0.5639
Epoch 8 Batch 150 Loss 2.3805 Accuracy 0.5639
Epoch 8 Batch 200 Loss 2.3819 Accuracy 0.5637
Epoch 8 Batch 250 Loss 2.3868 Accuracy 0.5625
Epoch 8 Batch 300 Loss 2.3867 Accuracy 0.5625
Epoch 8 Batch 350 Loss 2.3840 Accuracy 0.5631
Epoch 8 Batch 400 Loss 2.3814 Accuracy 0.5637
Epoch 8 Batch 450 Loss 2.3806 Accuracy 0.5638
Epoch 8 Batch 500 Loss 2.3784 Accuracy 0.5640
Epoch 8 Batch 550 Loss 2.3755 Accuracy 0.5646
Epoch 8 Batch 600 Loss 2.3763 Accuracy 0.5646
Epoch 8 Batch 650 Loss 2.3762 Accuracy 0.5649
Epoch 8 Batch 700 Loss 2.3753 Accuracy 0.5652
Epoch 8 Loss 2.3750 Accuracy 0.5653
Time taken for 1 epoch: 328.0940201282501 secs

Epoch 9 Batch 0 Loss 2.1380 Accuracy 0.5907
Epoch 9 Batch 50 Loss 2.1332 Accuracy 0.5948
Epoch 9 Batch 100 Loss 2.1317 Accuracy 0.5945
Epoch 9 Batch 150 Loss 2.1257 Accuracy 0.5951
Epoch 9 Batch 200 Loss 2.1219 Accuracy 0.5961
Epoch 9 Batch 250 Loss 2.1259 Accuracy 0.5958
Epoch 9 Batch 300 Loss 2.1301 Accuracy 0.5954
Epoch 9 Batch 350 Loss 2.1341 Accuracy 0.5952
Epoch 9 Batch 400 Loss 2.1332 Accuracy 0.5954
Epoch 9 Batch 450 Loss 2.1361 Accuracy 0.5950
Epoch 9 Batch 500 Loss 2.1391 Accuracy 0.5945
Epoch 9 Batch 550 Loss 2.1402 Accuracy 0.5945
Epoch 9 Batch 600 Loss 2.1391 Accuracy 0.5948
Epoch 9 Batch 650 Loss 2.1411 Accuracy 0.5946
Epoch 9 Batch 700 Loss 2.1419 Accuracy 0.5947
Epoch 9 Loss 2.1421 Accuracy 0.5947
Time taken for 1 epoch: 328.67867064476013 secs

Epoch 10 Batch 0 Loss 2.1447 Accuracy 0.5654
Epoch 10 Batch 50 Loss 1.9188 Accuracy 0.6214
Epoch 10 Batch 100 Loss 1.9275 Accuracy 0.6203
Epoch 10 Batch 150 Loss 1.9255 Accuracy 0.6217
Epoch 10 Batch 200 Loss 1.9352 Accuracy 0.6202
Epoch 10 Batch 250 Loss 1.9391 Accuracy 0.6197
Epoch 10 Batch 300 Loss 1.9473 Accuracy 0.6189
Epoch 10 Batch 350 Loss 1.9503 Accuracy 0.6185
Epoch 10 Batch 400 Loss 1.9528 Accuracy 0.6179
Epoch 10 Batch 450 Loss 1.9583 Accuracy 0.6170
Epoch 10 Batch 500 Loss 1.9615 Accuracy 0.6168
Epoch 10 Batch 550 Loss 1.9638 Accuracy 0.6169
Epoch 10 Batch 600 Loss 1.9646 Accuracy 0.6169
Epoch 10 Batch 650 Loss 1.9658 Accuracy 0.6169
Epoch 10 Batch 700 Loss 1.9674 Accuracy 0.6167
Saving checkpoint for epoch 10 at ./checkpoints/train/ckpt-2
Epoch 10 Loss 1.9675 Accuracy 0.6167
Time taken for 1 epoch: 328.0170452594757 secs

Epoch 11 Batch 0 Loss 1.9735 Accuracy 0.6258
Epoch 11 Batch 50 Loss 1.7728 Accuracy 0.6423
Epoch 11 Batch 100 Loss 1.7802 Accuracy 0.6410
Epoch 11 Batch 150 Loss 1.7852 Accuracy 0.6403
Epoch 11 Batch 200 Loss 1.7945 Accuracy 0.6387
Epoch 11 Batch 250 Loss 1.7977 Accuracy 0.6379
Epoch 11 Batch 300 Loss 1.7942 Accuracy 0.6385
Epoch 11 Batch 350 Loss 1.8017 Accuracy 0.6374
Epoch 11 Batch 400 Loss 1.8012 Accuracy 0.6378
Epoch 11 Batch 450 Loss 1.8063 Accuracy 0.6371
Epoch 11 Batch 500 Loss 1.8089 Accuracy 0.6369
Epoch 11 Batch 550 Loss 1.8126 Accuracy 0.6365
Epoch 11 Batch 600 Loss 1.8170 Accuracy 0.6360
Epoch 11 Batch 650 Loss 1.8226 Accuracy 0.6354
Epoch 11 Batch 700 Loss 1.8250 Accuracy 0.6352
Epoch 11 Loss 1.8256 Accuracy 0.6351
Time taken for 1 epoch: 327.8001685142517 secs

Epoch 12 Batch 0 Loss 1.6151 Accuracy 0.6718
Epoch 12 Batch 50 Loss 1.6312 Accuracy 0.6627
Epoch 12 Batch 100 Loss 1.6509 Accuracy 0.6588
Epoch 12 Batch 150 Loss 1.6678 Accuracy 0.6562
Epoch 12 Batch 200 Loss 1.6765 Accuracy 0.6549
Epoch 12 Batch 250 Loss 1.6849 Accuracy 0.6536
Epoch 12 Batch 300 Loss 1.6853 Accuracy 0.6532
Epoch 12 Batch 350 Loss 1.6876 Accuracy 0.6529
Epoch 12 Batch 400 Loss 1.6880 Accuracy 0.6531
Epoch 12 Batch 450 Loss 1.6916 Accuracy 0.6527
Epoch 12 Batch 500 Loss 1.6950 Accuracy 0.6522
Epoch 12 Batch 550 Loss 1.6998 Accuracy 0.6516
Epoch 12 Batch 600 Loss 1.7023 Accuracy 0.6512
Epoch 12 Batch 650 Loss 1.7068 Accuracy 0.6507
Epoch 12 Batch 700 Loss 1.7100 Accuracy 0.6503
Epoch 12 Loss 1.7105 Accuracy 0.6502
Time taken for 1 epoch: 329.3431029319763 secs

Epoch 13 Batch 0 Loss 1.5410 Accuracy 0.6902
Epoch 13 Batch 50 Loss 1.5296 Accuracy 0.6748
Epoch 13 Batch 100 Loss 1.5512 Accuracy 0.6718
Epoch 13 Batch 150 Loss 1.5605 Accuracy 0.6706
Epoch 13 Batch 200 Loss 1.5643 Accuracy 0.6705
Epoch 13 Batch 250 Loss 1.5695 Accuracy 0.6700
Epoch 13 Batch 300 Loss 1.5756 Accuracy 0.6691
Epoch 13 Batch 350 Loss 1.5782 Accuracy 0.6689
Epoch 13 Batch 400 Loss 1.5816 Accuracy 0.6681
Epoch 13 Batch 450 Loss 1.5883 Accuracy 0.6668
Epoch 13 Batch 500 Loss 1.5922 Accuracy 0.6663
Epoch 13 Batch 550 Loss 1.5949 Accuracy 0.6661
Epoch 13 Batch 600 Loss 1.5997 Accuracy 0.6655
Epoch 13 Batch 650 Loss 1.6056 Accuracy 0.6645
Epoch 13 Batch 700 Loss 1.6087 Accuracy 0.6644
Epoch 13 Loss 1.6087 Accuracy 0.6644
Time taken for 1 epoch: 329.06107211112976 secs

Epoch 14 Batch 0 Loss 1.3944 Accuracy 0.6995
Epoch 14 Batch 50 Loss 1.4568 Accuracy 0.6866
Epoch 14 Batch 100 Loss 1.4724 Accuracy 0.6843
Epoch 14 Batch 150 Loss 1.4806 Accuracy 0.6835
Epoch 14 Batch 200 Loss 1.4826 Accuracy 0.6830
Epoch 14 Batch 250 Loss 1.4882 Accuracy 0.6812
Epoch 14 Batch 300 Loss 1.4904 Accuracy 0.6811
Epoch 14 Batch 350 Loss 1.4919 Accuracy 0.6809
Epoch 14 Batch 400 Loss 1.4956 Accuracy 0.6806
Epoch 14 Batch 450 Loss 1.5005 Accuracy 0.6799
Epoch 14 Batch 500 Loss 1.5039 Accuracy 0.6795
Epoch 14 Batch 550 Loss 1.5082 Accuracy 0.6788
Epoch 14 Batch 600 Loss 1.5137 Accuracy 0.6779
Epoch 14 Batch 650 Loss 1.5200 Accuracy 0.6770
Epoch 14 Batch 700 Loss 1.5251 Accuracy 0.6763
Epoch 14 Loss 1.5255 Accuracy 0.6762
Time taken for 1 epoch: 328.8318474292755 secs

Epoch 15 Batch 0 Loss 1.5352 Accuracy 0.6574
Epoch 15 Batch 50 Loss 1.3418 Accuracy 0.7061
Epoch 15 Batch 100 Loss 1.3626 Accuracy 0.7019
Epoch 15 Batch 150 Loss 1.3738 Accuracy 0.6999
Epoch 15 Batch 200 Loss 1.3871 Accuracy 0.6971
Epoch 15 Batch 250 Loss 1.3942 Accuracy 0.6962
Epoch 15 Batch 300 Loss 1.4037 Accuracy 0.6945
Epoch 15 Batch 350 Loss 1.4079 Accuracy 0.6940
Epoch 15 Batch 400 Loss 1.4144 Accuracy 0.6931
Epoch 15 Batch 450 Loss 1.4214 Accuracy 0.6919
Epoch 15 Batch 500 Loss 1.4255 Accuracy 0.6914
Epoch 15 Batch 550 Loss 1.4304 Accuracy 0.6906
Epoch 15 Batch 600 Loss 1.4359 Accuracy 0.6896
Epoch 15 Batch 650 Loss 1.4403 Accuracy 0.6889
Epoch 15 Batch 700 Loss 1.4479 Accuracy 0.6878
Saving checkpoint for epoch 15 at ./checkpoints/train/ckpt-3
Epoch 15 Loss 1.4481 Accuracy 0.6878
Time taken for 1 epoch: 329.04406332969666 secs

Epoch 16 Batch 0 Loss 1.3499 Accuracy 0.6903
Epoch 16 Batch 50 Loss 1.2827 Accuracy 0.7132
Epoch 16 Batch 100 Loss 1.2885 Accuracy 0.7136
Epoch 16 Batch 150 Loss 1.3014 Accuracy 0.7119
Epoch 16 Batch 200 Loss 1.3165 Accuracy 0.7090
Epoch 16 Batch 250 Loss 1.3267 Accuracy 0.7072
Epoch 16 Batch 300 Loss 1.3326 Accuracy 0.7057
Epoch 16 Batch 350 Loss 1.3404 Accuracy 0.7041
Epoch 16 Batch 400 Loss 1.3477 Accuracy 0.7027
Epoch 16 Batch 450 Loss 1.3547 Accuracy 0.7017
Epoch 16 Batch 500 Loss 1.3585 Accuracy 0.7012
Epoch 16 Batch 550 Loss 1.3641 Accuracy 0.7003
Epoch 16 Batch 600 Loss 1.3696 Accuracy 0.6994
Epoch 16 Batch 650 Loss 1.3757 Accuracy 0.6987
Epoch 16 Batch 700 Loss 1.3806 Accuracy 0.6979
Epoch 16 Loss 1.3809 Accuracy 0.6979
Time taken for 1 epoch: 329.09080934524536 secs

Epoch 17 Batch 0 Loss 1.2575 Accuracy 0.7162
Epoch 17 Batch 50 Loss 1.2620 Accuracy 0.7161
Epoch 17 Batch 100 Loss 1.2560 Accuracy 0.7187
Epoch 17 Batch 150 Loss 1.2598 Accuracy 0.7174
Epoch 17 Batch 200 Loss 1.2668 Accuracy 0.7160
Epoch 17 Batch 250 Loss 1.2760 Accuracy 0.7143
Epoch 17 Batch 300 Loss 1.2784 Accuracy 0.7143
Epoch 17 Batch 350 Loss 1.2853 Accuracy 0.7133
Epoch 17 Batch 400 Loss 1.2901 Accuracy 0.7123
Epoch 17 Batch 450 Loss 1.2948 Accuracy 0.7119
Epoch 17 Batch 500 Loss 1.3009 Accuracy 0.7108
Epoch 17 Batch 550 Loss 1.3065 Accuracy 0.7097
Epoch 17 Batch 600 Loss 1.3117 Accuracy 0.7089
Epoch 17 Batch 650 Loss 1.3161 Accuracy 0.7082
Epoch 17 Batch 700 Loss 1.3205 Accuracy 0.7075
Epoch 17 Loss 1.3213 Accuracy 0.7074
Time taken for 1 epoch: 330.29657912254333 secs

Epoch 18 Batch 0 Loss 0.9995 Accuracy 0.7531
Epoch 18 Batch 50 Loss 1.1740 Accuracy 0.7312
Epoch 18 Batch 100 Loss 1.1851 Accuracy 0.7302
Epoch 18 Batch 150 Loss 1.1958 Accuracy 0.7280
Epoch 18 Batch 200 Loss 1.2056 Accuracy 0.7266
Epoch 18 Batch 250 Loss 1.2129 Accuracy 0.7251
Epoch 18 Batch 300 Loss 1.2198 Accuracy 0.7237
Epoch 18 Batch 350 Loss 1.2270 Accuracy 0.7223
Epoch 18 Batch 400 Loss 1.2320 Accuracy 0.7216
Epoch 18 Batch 450 Loss 1.2352 Accuracy 0.7208
Epoch 18 Batch 500 Loss 1.2431 Accuracy 0.7193
Epoch 18 Batch 550 Loss 1.2500 Accuracy 0.7181
Epoch 18 Batch 600 Loss 1.2569 Accuracy 0.7169
Epoch 18 Batch 650 Loss 1.2627 Accuracy 0.7158
Epoch 18 Batch 700 Loss 1.2681 Accuracy 0.7152
Epoch 18 Loss 1.2685 Accuracy 0.7152
Time taken for 1 epoch: 329.7291362285614 secs

Epoch 19 Batch 0 Loss 1.0314 Accuracy 0.7674
Epoch 19 Batch 50 Loss 1.1367 Accuracy 0.7379
Epoch 19 Batch 100 Loss 1.1340 Accuracy 0.7396
Epoch 19 Batch 150 Loss 1.1466 Accuracy 0.7365
Epoch 19 Batch 200 Loss 1.1553 Accuracy 0.7347
Epoch 19 Batch 250 Loss 1.1660 Accuracy 0.7327
Epoch 19 Batch 300 Loss 1.1708 Accuracy 0.7319
Epoch 19 Batch 350 Loss 1.1795 Accuracy 0.7299
Epoch 19 Batch 400 Loss 1.1847 Accuracy 0.7291
Epoch 19 Batch 450 Loss 1.1905 Accuracy 0.7279
Epoch 19 Batch 500 Loss 1.1967 Accuracy 0.7267
Epoch 19 Batch 550 Loss 1.2032 Accuracy 0.7255
Epoch 19 Batch 600 Loss 1.2076 Accuracy 0.7249
Epoch 19 Batch 650 Loss 1.2144 Accuracy 0.7238
Epoch 19 Batch 700 Loss 1.2202 Accuracy 0.7228
Epoch 19 Loss 1.2207 Accuracy 0.7227
Time taken for 1 epoch: 330.0889730453491 secs

Epoch 20 Batch 0 Loss 0.9899 Accuracy 0.7659
Epoch 20 Batch 50 Loss 1.0937 Accuracy 0.7445
Epoch 20 Batch 100 Loss 1.1030 Accuracy 0.7432
Epoch 20 Batch 150 Loss 1.1085 Accuracy 0.7418
Epoch 20 Batch 200 Loss 1.1173 Accuracy 0.7404
Epoch 20 Batch 250 Loss 1.1255 Accuracy 0.7387
Epoch 20 Batch 300 Loss 1.1306 Accuracy 0.7377
Epoch 20 Batch 350 Loss 1.1335 Accuracy 0.7372
Epoch 20 Batch 400 Loss 1.1396 Accuracy 0.7362
Epoch 20 Batch 450 Loss 1.1451 Accuracy 0.7350
Epoch 20 Batch 500 Loss 1.1503 Accuracy 0.7342
Epoch 20 Batch 550 Loss 1.1561 Accuracy 0.7330
Epoch 20 Batch 600 Loss 1.1630 Accuracy 0.7317
Epoch 20 Batch 650 Loss 1.1674 Accuracy 0.7310
Epoch 20 Batch 700 Loss 1.1748 Accuracy 0.7298
Saving checkpoint for epoch 20 at ./checkpoints/train/ckpt-4
Epoch 20 Loss 1.1756 Accuracy 0.7297
Time taken for 1 epoch: 330.0001685619354 secs


Evaluate

The following steps are used for evaluation:

  • Encode the input sentence using the Portuguese tokenizer (tokenizer_pt). Moreover, add the start and end token so the input is equivalent to what the model is trained with. This is the encoder input.
  • The decoder input is the start token == tokenizer_en.vocab_size.
  • Calculate the padding masks and the look ahead masks.
  • The decoder then outputs the predictions by looking at the encoder output and its own output (self-attention).
  • Select the last word and calculate the argmax of that.
  • Concatenate the predicted word to the decoder input and pass it to the decoder.
  • In this approach, the decoder predicts the next word based on the previous words it predicted.
def evaluate(inp_sentence):
  start_token = [tokenizer_pt.vocab_size]
  end_token = [tokenizer_pt.vocab_size + 1]

  # inp sentence is portuguese, hence adding the start and end token
  inp_sentence = start_token + tokenizer_pt.encode(inp_sentence) + end_token
  encoder_input = tf.expand_dims(inp_sentence, 0)

  # as the target is english, the first word to the transformer should be the
  # english start token.
  decoder_input = [tokenizer_en.vocab_size]
  output = tf.expand_dims(decoder_input, 0)

  for i in range(MAX_LENGTH):
    enc_padding_mask, combined_mask, dec_padding_mask = create_masks(
        encoder_input, output)

    # predictions.shape == (batch_size, seq_len, vocab_size)
    predictions, attention_weights = transformer(encoder_input, 
                                                 output,
                                                 False,
                                                 enc_padding_mask,
                                                 combined_mask,
                                                 dec_padding_mask)

    # select the last word from the seq_len dimension
    predictions = predictions[: ,-1:, :]  # (batch_size, 1, vocab_size)

    predicted_id = tf.cast(tf.argmax(predictions, axis=-1), tf.int32)

    # return the result if the predicted_id is equal to the end token
    if predicted_id == tokenizer_en.vocab_size+1:
      return tf.squeeze(output, axis=0), attention_weights

    # concatentate the predicted_id to the output which is given to the decoder
    # as its input.
    output = tf.concat([output, predicted_id], axis=-1)

  return tf.squeeze(output, axis=0), attention_weights
def plot_attention_weights(attention, sentence, result, layer):
  fig = plt.figure(figsize=(16, 8))

  sentence = tokenizer_pt.encode(sentence)

  attention = tf.squeeze(attention[layer], axis=0)

  for head in range(attention.shape[0]):
    ax = fig.add_subplot(2, 4, head+1)

    # plot the attention weights
    ax.matshow(attention[head][:-1, :], cmap='viridis')

    fontdict = {'fontsize': 10}

    ax.set_xticks(range(len(sentence)+2))
    ax.set_yticks(range(len(result)))

    ax.set_ylim(len(result)-1.5, -0.5)

    ax.set_xticklabels(
        ['<start>']+[tokenizer_pt.decode([i]) for i in sentence]+['<end>'], 
        fontdict=fontdict, rotation=90)

    ax.set_yticklabels([tokenizer_en.decode([i]) for i in result 
                        if i < tokenizer_en.vocab_size], 
                       fontdict=fontdict)

    ax.set_xlabel('Head {}'.format(head+1))

  plt.tight_layout()
  plt.show()
def translate(sentence, plot=''):
  result, attention_weights = evaluate(sentence)

  predicted_sentence = tokenizer_en.decode([i for i in result 
                                            if i < tokenizer_en.vocab_size])  

  print('Input: {}'.format(sentence))
  print('Predicted translation: {}'.format(predicted_sentence))

  if plot:
    plot_attention_weights(attention_weights, sentence, result, plot)
translate("este é um problema que temos que resolver.")
print ("Real translation: this is a problem we have to solve .")
Input: este é um problema que temos que resolver.
Predicted translation: so this is a problem that we have to solve the global challenges .
Real translation: this is a problem we have to solve .

translate("os meus vizinhos ouviram sobre esta ideia.")
print ("Real translation: and my neighboring homes heard about this idea .")
Input: os meus vizinhos ouviram sobre esta ideia.
Predicted translation: my neighbors heard about this idea .
Real translation: and my neighboring homes heard about this idea .

translate("vou então muito rapidamente partilhar convosco algumas histórias de algumas coisas mágicas que aconteceram.")
print ("Real translation: so i 'll just share with you some stories very quickly of some magical things that have happened .")
Input: vou então muito rapidamente partilhar convosco algumas histórias de algumas coisas mágicas que aconteceram.
Predicted translation: so i 'm going to share with you a lot of magic stories that happened to be some very magical stuff that happened .
Real translation: so i 'll just share with you some stories very quickly of some magical things that have happened .

You can pass different layers and attention blocks of the decoder to the plot parameter.

translate("este é o primeiro livro que eu fiz.", plot='decoder_layer4_block2')
print ("Real translation: this is the first book i've ever done.")
Input: este é o primeiro livro que eu fiz.
Predicted translation: so this is the first book i made .

[Figure: attention weights for decoder_layer4_block2, one subplot per head]

Real translation: this is the first book i've ever done.

Summary

In this tutorial, you learned about positional encoding, multi-head attention, the importance of masking and how to create a transformer.

Try using a different dataset to train the transformer. You can also create the base transformer or Transformer XL by changing the hyperparameters above. You can also use the layers defined here to create BERT and train state-of-the-art models. Furthermore, you can implement beam search to get better predictions.
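
For reference, a configuration matching the base model from the paper would look roughly like the following sketch (the vocabulary sizes are reused from above; training it would take considerably longer than the reduced model in this notebook):

# Base-transformer hyperparameters from "Attention Is All You Need" (sketch).
num_layers = 6
d_model = 512
dff = 2048
num_heads = 8
dropout_rate = 0.1

base_transformer = Transformer(
    num_layers, d_model, num_heads, dff,
    input_vocab_size, target_vocab_size,
    pe_input=input_vocab_size,
    pe_target=target_vocab_size,
    rate=dropout_rate)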