This tutorial trains a Transformer model to translate Portuguese to English. This is an advanced example that assumes knowledge of text generation and attention.
The core idea behind the Transformer model is self-attention: the ability to attend to different positions of the input sequence to compute a representation of that sequence. A Transformer stacks self-attention layers; how they work is explained below in the sections Scaled dot product attention and Multi-head attention.
A transformer model handles variable-sized input using stacks of self-attention layers instead of RNNs or CNNs. This general architecture has a number of advantages:
- It makes no assumptions about the temporal/spatial relationships across the data. This is ideal for processing a set of objects (for example, StarCraft units).
- Layer outputs can be calculated in parallel, instead of a series like an RNN.
- Distant items can affect each other's output without passing through many RNN-steps, or convolution layers (see Scene Memory Transformer for example).
- It can learn long-range dependencies. This is a challenge in many sequence tasks.
The downsides of this architecture are:
- For a time-series, the output for a time-step is calculated from the entire history instead of only the inputs and current hidden-state. This may be less efficient.
- If the input does have a temporal/spatial relationship, like text, some positional encoding must be added or the model will effectively see a bag of words.
After training the model in this notebook, you will be able to input a Portuguese sentence and get back its English translation.
# Pin matplotlib version to 3.2.2 since in the latest version
# transformer.ipynb fails with the following error:
# https://stackoverflow.com/questions/62953704/valueerror-the-number-of-fixedlocator-locations-5-usually-from-a-call-to-set
pip install -q matplotlib==3.2.2
import tensorflow_datasets as tfds
import tensorflow as tf
import time
import numpy as np
import matplotlib.pyplot as plt
Set up the input pipeline
Use TFDS to load the Portuguese-English translation dataset from the TED Talks Open Translation Project.
This dataset contains approximately 50000 training examples, 1100 validation examples, and 2000 test examples.
examples, metadata = tfds.load('ted_hrlr_translate/pt_to_en', with_info=True,
as_supervised=True)
train_examples, val_examples = examples['train'], examples['validation']
Downloading and preparing dataset 124.94 MiB (download: 124.94 MiB, generated: Unknown size, total: 124.94 MiB) to /home/kbuilder/tensorflow_datasets/ted_hrlr_translate/pt_to_en/1.0.0... Dataset ted_hrlr_translate downloaded and prepared to /home/kbuilder/tensorflow_datasets/ted_hrlr_translate/pt_to_en/1.0.0. Subsequent calls will reuse this data.
Create a custom subwords tokenizer from the training dataset.
tokenizer_en = tfds.deprecated.text.SubwordTextEncoder.build_from_corpus(
(en.numpy() for pt, en in train_examples), target_vocab_size=2**13)
tokenizer_pt = tfds.deprecated.text.SubwordTextEncoder.build_from_corpus(
(pt.numpy() for pt, en in train_examples), target_vocab_size=2**13)
sample_string = 'Transformer is awesome.'
tokenized_string = tokenizer_en.encode(sample_string)
print ('Tokenized string is {}'.format(tokenized_string))
original_string = tokenizer_en.decode(tokenized_string)
print ('The original string: {}'.format(original_string))
assert original_string == sample_string
Tokenized string is [7915, 1248, 7946, 7194, 13, 2799, 7877]
The original string: Transformer is awesome.
The tokenizer encodes the string by breaking it into subwords if the word is not in its dictionary.
for ts in tokenized_string:
  print ('{} ----> {}'.format(ts, tokenizer_en.decode([ts])))
7915 ----> T
1248 ----> ran
7946 ----> s
7194 ----> former
13 ----> is
2799 ----> awesome
7877 ----> .
BUFFER_SIZE = 20000
BATCH_SIZE = 64
Add a start and end token to the input and target.
def encode(lang1, lang2):
  lang1 = [tokenizer_pt.vocab_size] + tokenizer_pt.encode(
      lang1.numpy()) + [tokenizer_pt.vocab_size+1]
  lang2 = [tokenizer_en.vocab_size] + tokenizer_en.encode(
      lang2.numpy()) + [tokenizer_en.vocab_size+1]
  return lang1, lang2
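As a quick illustration of why the embedding vocabularies later use vocab_size + 2, here is a plain-Python sketch of the start/end wrapping that encode() performs. It assumes, as encode() does, that the subword encoder never emits an id at or above vocab_size, so vocab_size and vocab_size + 1 are free to use. The value 8087 is the English start token visible in the batch printed below; the other ids come from the tokenized example above. add_start_end is an illustrative name, not part of TFDS.

```python
def add_start_end(token_ids, vocab_size):
    """Wrap a list of subword ids with start and end markers.

    Assumption (matching encode() in this tutorial): the tokenizer only emits
    ids below vocab_size, so the next two ids can mark start and end.
    """
    start_id, end_id = vocab_size, vocab_size + 1
    return [start_id] + list(token_ids) + [end_id]


ids = add_start_end([7915, 1248, 7946, 7194, 13, 2799, 7877], vocab_size=8087)
print(ids)  # [8087, 7915, 1248, 7946, 7194, 13, 2799, 7877, 8088]
```

Because the largest id is vocab_size + 1, an embedding table needs vocab_size + 2 rows, which matches how input_vocab_size and target_vocab_size are set in the hyperparameters section.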
You want to use Dataset.map to apply this function to each element of the dataset. Dataset.map runs in graph mode.
- Graph tensors do not have a value.
- In graph mode you can only use TensorFlow Ops and functions.
So you can't .map this function directly: you need to wrap it in a tf.py_function. The tf.py_function will pass regular tensors (with a value and a .numpy() method to access it) to the wrapped Python function.
def tf_encode(pt, en):
  result_pt, result_en = tf.py_function(encode, [pt, en], [tf.int64, tf.int64])
  result_pt.set_shape([None])
  result_en.set_shape([None])
  return result_pt, result_en
MAX_LENGTH = 40
def filter_max_length(x, y, max_length=MAX_LENGTH):
  return tf.logical_and(tf.size(x) <= max_length,
                        tf.size(y) <= max_length)
train_dataset = train_examples.map(tf_encode)
train_dataset = train_dataset.filter(filter_max_length)
# cache the dataset to memory to get a speedup while reading from it.
train_dataset = train_dataset.cache()
train_dataset = train_dataset.shuffle(BUFFER_SIZE).padded_batch(BATCH_SIZE)
train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)
val_dataset = val_examples.map(tf_encode)
val_dataset = val_dataset.filter(filter_max_length).padded_batch(BATCH_SIZE)
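padded_batch pads every sequence in a batch with 0 up to the length of the longest sequence in that batch, which is why the Portuguese and English batches printed below have different widths (38 and 40). The NumPy sketch below imitates the filter-then-pad steps for a single language. The helper name filter_and_pad is hypothetical, and unlike filter_max_length it checks only one sequence per example rather than the Portuguese/English pair.

```python
import numpy as np


def filter_and_pad(sequences, max_length=40, pad_id=0):
    """Drop sequences longer than max_length, then right-pad the survivors
    with pad_id to the length of the longest one (what padded_batch does
    within a single batch)."""
    kept = [seq for seq in sequences if len(seq) <= max_length]
    if not kept:
        return np.zeros((0, 0), dtype=np.int64)
    width = max(len(seq) for seq in kept)
    batch = np.full((len(kept), width), pad_id, dtype=np.int64)
    for row, seq in enumerate(kept):
        batch[row, :len(seq)] = seq
    return batch


batch = filter_and_pad([[5, 6, 7], [8, 9], [1] * 50])  # the 50-token example is dropped
print(batch)
```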
pt_batch, en_batch = next(iter(val_dataset))
pt_batch, en_batch
(<tf.Tensor: shape=(64, 38), dtype=int64, numpy= array([[8214, 342, 3032, ..., 0, 0, 0], [8214, 95, 198, ..., 0, 0, 0], [8214, 4479, 7990, ..., 0, 0, 0], ..., [8214, 584, 12, ..., 0, 0, 0], [8214, 59, 1548, ..., 0, 0, 0], [8214, 118, 34, ..., 0, 0, 0]])>, <tf.Tensor: shape=(64, 40), dtype=int64, numpy= array([[8087, 98, 25, ..., 0, 0, 0], [8087, 12, 20, ..., 0, 0, 0], [8087, 12, 5453, ..., 0, 0, 0], ..., [8087, 18, 2059, ..., 0, 0, 0], [8087, 16, 1436, ..., 0, 0, 0], [8087, 15, 57, ..., 0, 0, 0]])>)
Positional encoding
Since this model doesn't contain any recurrence or convolution, positional encoding is added to give the model some information about the relative position of the words in the sentence.
The positional encoding vector is added to the embedding vector. Embeddings represent a token in a d-dimensional space where tokens with similar meaning will be closer to each other. But the embeddings do not encode the relative position of words in a sentence. So after adding the positional encoding, words will be closer to each other based on the similarity of their meaning and their position in the sentence, in the d-dimensional space.
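See the notebook on positional encoding to learn more about it. The formula for calculating the positional encoding is as follows:

$$PE_{(pos, 2i)} = \sin\left(pos / 10000^{2i/d_{model}}\right)$$
$$PE_{(pos, 2i+1)} = \cos\left(pos / 10000^{2i/d_{model}}\right)$$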
def get_angles(pos, i, d_model):
  angle_rates = 1 / np.power(10000, (2 * (i//2)) / np.float32(d_model))
  return pos * angle_rates

def positional_encoding(position, d_model):
  angle_rads = get_angles(np.arange(position)[:, np.newaxis],
                          np.arange(d_model)[np.newaxis, :],
                          d_model)
  # apply sin to even indices in the array; 2i
  angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])
  # apply cos to odd indices in the array; 2i+1
  angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])
  pos_encoding = angle_rads[np.newaxis, ...]
  return tf.cast(pos_encoding, dtype=tf.float32)
pos_encoding = positional_encoding(50, 512)
print (pos_encoding.shape)
plt.pcolormesh(pos_encoding[0], cmap='RdBu')
plt.xlabel('Depth')
plt.xlim((0, 512))
plt.ylabel('Position')
plt.colorbar()
plt.show()
(1, 50, 512)
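One property that motivates the sinusoidal form: for each pair of dimensions (2i, 2i+1), the encoding at position pos + k is a fixed rotation of the encoding at pos, with the rotation angle depending only on the offset k, so a relative offset corresponds to a linear map of the encoding. The NumPy sketch below recomputes the same encoding without TensorFlow and checks that property numerically. positional_encoding_np and max_rotation_error are illustrative helpers, not part of the tutorial.

```python
import numpy as np


def positional_encoding_np(position, d_model):
    """NumPy version of positional_encoding() above, without the batch axis."""
    pos = np.arange(position)[:, np.newaxis]        # (position, 1)
    i = np.arange(d_model)[np.newaxis, :]           # (1, d_model)
    angles = pos / np.power(10000, (2 * (i // 2)) / d_model)
    angles[:, 0::2] = np.sin(angles[:, 0::2])       # even dims: sin
    angles[:, 1::2] = np.cos(angles[:, 1::2])       # odd dims: cos
    return angles                                   # (position, d_model)


def max_rotation_error(pe, k):
    """Largest gap between PE[pos + k] and a rotation of PE[pos], over all pos."""
    d_model = pe.shape[1]
    worst = 0.0
    for j in range(0, d_model, 2):
        w = 1.0 / 10000 ** (j / d_model)            # frequency of dims (j, j + 1)
        rot = np.array([[np.cos(w * k), np.sin(w * k)],
                        [-np.sin(w * k), np.cos(w * k)]])
        pairs = pe[:, j:j + 2]                      # columns: (sin, cos)
        predicted = pairs[:-k] @ rot.T              # rotate each row by w * k
        worst = max(worst, float(np.abs(predicted - pairs[k:]).max()))
    return worst


pe = positional_encoding_np(50, 16)
print(pe.shape, max_rotation_error(pe, k=3))
```

The reported error is at the level of floating-point round-off, for any offset k and any position.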
Masking
Mask all the pad tokens in the batch of sequences. This ensures that the model does not treat padding as input. The mask indicates where the pad value 0 is present: it outputs a 1 at those locations, and a 0 otherwise.
def create_padding_mask(seq):
  seq = tf.cast(tf.math.equal(seq, 0), tf.float32)
  # add extra dimensions to add the padding
  # to the attention logits.
  return seq[:, tf.newaxis, tf.newaxis, :]  # (batch_size, 1, 1, seq_len)
x = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
create_padding_mask(x)
<tf.Tensor: shape=(3, 1, 1, 5), dtype=float32, numpy= array([[[[0., 0., 1., 1., 0.]]], [[[0., 0., 0., 1., 1.]]], [[[1., 1., 1., 0., 0.]]]], dtype=float32)>
The look-ahead mask is used to mask the future tokens in a sequence. In other words, the mask indicates which entries should not be used.
This means that to predict the third word, only the first and second words will be used. Similarly, to predict the fourth word, only the first, second, and third words will be used, and so on.
def create_look_ahead_mask(size):
  mask = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)
  return mask  # (seq_len, seq_len)
x = tf.random.uniform((1, 3))
temp = create_look_ahead_mask(x.shape[1])
temp
<tf.Tensor: shape=(3, 3), dtype=float32, numpy= array([[0., 1., 1.], [0., 0., 1.], [0., 0., 0.]], dtype=float32)>
Scaled dot product attention
The attention function used by the transformer takes three inputs: Q (query), K (key), V (value). The equation used to calculate attention is:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$
The dot-product attention is scaled by a factor of the square root of the depth, $d_k$. This is done because for large depths the dot products grow large in magnitude, which pushes the softmax into a region where it is nearly one-hot and its gradients are very small.
For example, suppose the entries of Q and K are independent with a mean of 0 and a variance of 1. Each entry of $QK^T$ is a sum of $d_k$ such products, so it has a mean of 0 and a variance of $d_k$. Dividing by $\sqrt{d_k}$ (and not any other number) brings the variance back to 1, which gives a gentler softmax.
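The variance argument can be checked empirically. This NumPy sketch draws independent standard-normal query and key vectors (an assumption made purely for illustration; learned Q and K need not look like this) and compares the variance of the raw and scaled dot products.

```python
import numpy as np

rng = np.random.default_rng(0)
d_k, n = 512, 4000                         # depth and number of sampled (q, k) pairs

q = rng.standard_normal((n, d_k))          # entries: mean 0, variance 1
k = rng.standard_normal((n, d_k))
dots = np.einsum('nd,nd->n', q, k)         # one q . k dot product per row

raw_var = dots.var()                       # expected to be close to d_k
scaled_var = (dots / np.sqrt(d_k)).var()   # expected to be close to 1
print(f"raw variance ~ {raw_var:.1f}, scaled variance ~ {scaled_var:.3f}")
```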
The mask is multiplied with -1e9 (close to negative infinity). This is done because the mask is summed with the scaled matrix multiplication of Q and K and is applied immediately before a softmax. The goal is to zero out these cells, and large negative inputs to softmax are near zero in the output.
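A small NumPy sketch of that masking trick: adding mask * -1e9 to the logits before the softmax leaves the unmasked weights equal to a softmax over the remaining positions and drives the masked weights to (numerically) zero. masked_softmax is an illustrative helper; as with create_padding_mask, the mask uses 1 for positions to hide.

```python
import numpy as np


def masked_softmax(logits, mask):
    """Softmax over the last axis after adding mask * -1e9 (mask: 1 = hide)."""
    logits = logits + mask * -1e9
    logits = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(logits)
    return weights / weights.sum(axis=-1, keepdims=True)


logits = np.array([[2.0, 1.0, 0.5, 3.0]])
mask = np.array([[0.0, 0.0, 1.0, 1.0]])   # hide the last two positions
weights = masked_softmax(logits, mask)
print(weights)  # roughly [[0.731 0.269 0. 0.]]
```

Note that the largest raw logit (3.0) gets no weight at all once it is masked.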
def scaled_dot_product_attention(q, k, v, mask):
  """Calculate the attention weights.
  q, k, v must have matching leading dimensions.
  k, v must have matching penultimate dimension, i.e.: seq_len_k = seq_len_v.
  The mask has different shapes depending on its type (padding or look ahead)
  but it must be broadcastable for addition.

  Args:
    q: query shape == (..., seq_len_q, depth)
    k: key shape == (..., seq_len_k, depth)
    v: value shape == (..., seq_len_v, depth_v)
    mask: Float tensor with shape broadcastable
          to (..., seq_len_q, seq_len_k), or None.

  Returns:
    output, attention_weights
  """
  matmul_qk = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)

  # scale matmul_qk
  dk = tf.cast(tf.shape(k)[-1], tf.float32)
  scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

  # add the mask to the scaled tensor.
  if mask is not None:
    scaled_attention_logits += (mask * -1e9)

  # softmax is normalized on the last axis (seq_len_k) so that the scores
  # add up to 1.
  attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)  # (..., seq_len_q, seq_len_k)

  output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)

  return output, attention_weights
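The softmax is normalized over the key axis (seq_len_k), so for each query the weights decide how much importance each key, and therefore its value, receives.
The output is the product of the attention weights and the V (value) matrix, so each output row is a weighted average of the values. Values whose keys match the query closely dominate that average, while values for irrelevant words receive weights near zero and are effectively flushed out.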
def print_out(q, k, v):
  temp_out, temp_attn = scaled_dot_product_attention(
      q, k, v, None)
  print ('Attention weights are:')
  print (temp_attn)
  print ('Output is:')
  print (temp_out)
np.set_printoptions(suppress=True)
temp_k = tf.constant([[10,0,0],
[0,10,0],
[0,0,10],
[0,0,10]], dtype=tf.float32) # (4, 3)
temp_v = tf.constant([[ 1,0],
[ 10,0],
[ 100,5],
[1000,6]], dtype=tf.float32) # (4, 2)
# This `query` aligns with the second `key`,
# so the second `value` is returned.
temp_q = tf.constant([[0, 10, 0]], dtype=tf.float32) # (1, 3)
print_out(temp_q, temp_k, temp_v)
Attention weights are:
tf.Tensor([[0. 1. 0. 0.]], shape=(1, 4), dtype=float32)
Output is:
tf.Tensor([[10. 0.]], shape=(1, 2), dtype=float32)
# This query aligns with a repeated key (third and fourth),
# so all associated values get averaged.
temp_q = tf.constant([[0, 0, 10]], dtype=tf.float32) # (1, 3)
print_out(temp_q, temp_k, temp_v)
Attention weights are:
tf.Tensor([[0. 0. 0.5 0.5]], shape=(1, 4), dtype=float32)
Output is:
tf.Tensor([[550. 5.5]], shape=(1, 2), dtype=float32)
# This query aligns equally with the first and second key,
# so their values get averaged.
temp_q = tf.constant([[10, 10, 0]], dtype=tf.float32) # (1, 3)
print_out(temp_q, temp_k, temp_v)
Attention weights are:
tf.Tensor([[0.5 0.5 0. 0. ]], shape=(1, 4), dtype=float32)
Output is:
tf.Tensor([[5.5 0. ]], shape=(1, 2), dtype=float32)
Pass all the queries together.
temp_q = tf.constant([[0, 0, 10], [0, 10, 0], [10, 10, 0]], dtype=tf.float32) # (3, 3)
print_out(temp_q, temp_k, temp_v)
Attention weights are:
tf.Tensor(
[[0. 0. 0.5 0.5]
 [0. 1. 0. 0. ]
 [0.5 0.5 0. 0. ]], shape=(3, 4), dtype=float32)
Output is:
tf.Tensor(
[[550. 5.5]
 [ 10. 0. ]
 [ 5.5 0. ]], shape=(3, 2), dtype=float32)
Multi-head attention
Multi-head attention consists of four parts:
- Linear layers and split into heads.
- Scaled dot-product attention.
- Concatenation of heads.
- Final linear layer.
Each multi-head attention block gets three inputs: Q (query), K (key), V (value). These are put through linear (Dense) layers and split up into multiple heads.
The scaled_dot_product_attention defined above is applied to each head (broadcast for efficiency). An appropriate mask must be used in the attention step. The attention output for each head is then concatenated (using tf.transpose and tf.reshape) and put through a final Dense layer.
Instead of one single attention head, Q, K, and V are split into multiple heads because it allows the model to jointly attend to information at different positions from different representational spaces. After the split each head has a reduced dimensionality, so the total computation cost is the same as a single head attention with full dimensionality.
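The split and the later concatenation are just a reshape plus a transpose, and they undo each other exactly. This NumPy sketch mirrors split_heads and the transpose/reshape at the end of call() in the class below, using small made-up sizes; after the split, each head works with a depth of d_model / num_heads.

```python
import numpy as np


def split_heads(x, num_heads):
    """(batch, seq_len, d_model) -> (batch, num_heads, seq_len, depth)."""
    batch, seq_len, d_model = x.shape
    depth = d_model // num_heads
    return x.reshape(batch, seq_len, num_heads, depth).transpose(0, 2, 1, 3)


def merge_heads(x):
    """(batch, num_heads, seq_len, depth) -> (batch, seq_len, num_heads * depth)."""
    batch, num_heads, seq_len, depth = x.shape
    return x.transpose(0, 2, 1, 3).reshape(batch, seq_len, num_heads * depth)


x = np.arange(2 * 5 * 8, dtype=np.float32).reshape(2, 5, 8)  # batch=2, seq_len=5, d_model=8
heads = split_heads(x, num_heads=4)
print(heads.shape)                            # (2, 4, 5, 2)
print(np.array_equal(merge_heads(heads), x))  # True
```

Head h at position t holds the h-th contiguous slice of that position's d_model features.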
class MultiHeadAttention(tf.keras.layers.Layer):
  def __init__(self, d_model, num_heads):
    super(MultiHeadAttention, self).__init__()
    self.num_heads = num_heads
    self.d_model = d_model
    assert d_model % self.num_heads == 0
    self.depth = d_model // self.num_heads
    self.wq = tf.keras.layers.Dense(d_model)
    self.wk = tf.keras.layers.Dense(d_model)
    self.wv = tf.keras.layers.Dense(d_model)
    self.dense = tf.keras.layers.Dense(d_model)

  def split_heads(self, x, batch_size):
    """Split the last dimension into (num_heads, depth).
    Transpose the result such that the shape is (batch_size, num_heads, seq_len, depth)
    """
    x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
    return tf.transpose(x, perm=[0, 2, 1, 3])

  def call(self, v, k, q, mask):
    batch_size = tf.shape(q)[0]

    q = self.wq(q)  # (batch_size, seq_len, d_model)
    k = self.wk(k)  # (batch_size, seq_len, d_model)
    v = self.wv(v)  # (batch_size, seq_len, d_model)

    q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
    k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
    v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth)

    # scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth)
    # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)
    scaled_attention, attention_weights = scaled_dot_product_attention(
        q, k, v, mask)

    scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])  # (batch_size, seq_len_q, num_heads, depth)

    concat_attention = tf.reshape(scaled_attention,
                                  (batch_size, -1, self.d_model))  # (batch_size, seq_len_q, d_model)

    output = self.dense(concat_attention)  # (batch_size, seq_len_q, d_model)

    return output, attention_weights
Create a MultiHeadAttention layer to try out. At each location in the sequence, y, the MultiHeadAttention runs all 8 attention heads across all other locations in the sequence, returning a new vector of the same length at each location.
temp_mha = MultiHeadAttention(d_model=512, num_heads=8)
y = tf.random.uniform((1, 60, 512)) # (batch_size, encoder_sequence, d_model)
out, attn = temp_mha(y, k=y, q=y, mask=None)
out.shape, attn.shape
(TensorShape([1, 60, 512]), TensorShape([1, 8, 60, 60]))
Point wise feed forward network
Point wise feed forward network consists of two fully-connected layers with a ReLU activation in between.
def point_wise_feed_forward_network(d_model, dff):
  return tf.keras.Sequential([
      tf.keras.layers.Dense(dff, activation='relu'),  # (batch_size, seq_len, dff)
      tf.keras.layers.Dense(d_model)  # (batch_size, seq_len, d_model)
  ])
sample_ffn = point_wise_feed_forward_network(512, 2048)
sample_ffn(tf.random.uniform((64, 50, 512))).shape
TensorShape([64, 50, 512])
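"Point wise" means the same two Dense layers are applied to every position independently; no information moves between positions inside this sublayer. The NumPy sketch below checks that with random weights (the sizes are arbitrary, chosen for illustration) by comparing the whole-sequence result with a per-position loop.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, dff, seq_len = 8, 32, 5
w1, b1 = rng.standard_normal((d_model, dff)), np.zeros(dff)
w2, b2 = rng.standard_normal((dff, d_model)), np.zeros(d_model)


def ffn(x):
    # Dense(dff, relu) followed by Dense(d_model), both acting on the last axis.
    return np.maximum(x @ w1 + b1, 0.0) @ w2 + b2


x = rng.standard_normal((seq_len, d_model))
whole = ffn(x)                                                # all positions at once
per_position = np.stack([ffn(x[t]) for t in range(seq_len)])  # one position at a time
print(np.allclose(whole, per_position))  # True
```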
Encoder and decoder
The transformer model follows the same general pattern as a standard sequence to sequence with attention model.
- The input sentence is passed through N encoder layers that generate an output for each word/token in the sequence.
- The decoder attends to the encoder's output and its own input (self-attention) to predict the next word.
Encoder layer
Each encoder layer consists of sublayers:
- Multi-head attention (with padding mask)
- Point wise feed forward networks.
Each of these sublayers has a residual connection around it followed by a layer normalization. Residual connections help in avoiding the vanishing gradient problem in deep networks.
The output of each sublayer is LayerNorm(x + Sublayer(x)). The normalization is done on the d_model (last) axis. There are N encoder layers in the transformer.
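A NumPy sketch of LayerNorm(x + Sublayer(x)) with the normalization over the last (d_model) axis. It leaves out the learned scale and offset of tf.keras.layers.LayerNormalization (equivalent to their initial values of 1 and 0), and the sublayer here is a hypothetical stand-in function rather than attention or the feed forward network.

```python
import numpy as np


def layer_norm(x, epsilon=1e-6):
    # Normalize each position's feature vector over the last axis.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + epsilon)


def residual_block(x, sublayer):
    # Residual connection around the sublayer, followed by layer normalization.
    return layer_norm(x + sublayer(x))


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 8))           # (batch, seq_len, d_model)
out = residual_block(x, lambda t: 0.5 * t)   # hypothetical stand-in sublayer
print(out.mean(axis=-1).round(6))            # each position: mean ~ 0
print(out.std(axis=-1).round(3))             # each position: std ~ 1
```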
class EncoderLayer(tf.keras.layers.Layer):
  def __init__(self, d_model, num_heads, dff, rate=0.1):
    super(EncoderLayer, self).__init__()
    self.mha = MultiHeadAttention(d_model, num_heads)
    self.ffn = point_wise_feed_forward_network(d_model, dff)
    self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    self.dropout1 = tf.keras.layers.Dropout(rate)
    self.dropout2 = tf.keras.layers.Dropout(rate)

  def call(self, x, training, mask):
    attn_output, _ = self.mha(x, x, x, mask)  # (batch_size, input_seq_len, d_model)
    attn_output = self.dropout1(attn_output, training=training)
    out1 = self.layernorm1(x + attn_output)  # (batch_size, input_seq_len, d_model)

    ffn_output = self.ffn(out1)  # (batch_size, input_seq_len, d_model)
    ffn_output = self.dropout2(ffn_output, training=training)
    out2 = self.layernorm2(out1 + ffn_output)  # (batch_size, input_seq_len, d_model)

    return out2
sample_encoder_layer = EncoderLayer(512, 8, 2048)
sample_encoder_layer_output = sample_encoder_layer(
tf.random.uniform((64, 43, 512)), False, None)
sample_encoder_layer_output.shape # (batch_size, input_seq_len, d_model)
TensorShape([64, 43, 512])
Decoder layer
Each decoder layer consists of sublayers:
- Masked multi-head attention (with look ahead mask and padding mask)
- Multi-head attention (with padding mask). V (value) and K (key) receive the encoder output as inputs. Q (query) receives the output from the masked multi-head attention sublayer.
- Point wise feed forward networks
Each of these sublayers has a residual connection around it followed by a layer normalization. The output of each sublayer is LayerNorm(x + Sublayer(x)). The normalization is done on the d_model (last) axis.
There are N decoder layers in the transformer.
As Q receives the output from the decoder's first attention block, and K receives the encoder output, the attention weights represent how much importance each encoder position receives for each position of the decoder's input. In other words, the decoder predicts the next word by looking at the encoder output and self-attending to its own output. See the demonstration above in the scaled dot product attention section.
class DecoderLayer(tf.keras.layers.Layer):
  def __init__(self, d_model, num_heads, dff, rate=0.1):
    super(DecoderLayer, self).__init__()
    self.mha1 = MultiHeadAttention(d_model, num_heads)
    self.mha2 = MultiHeadAttention(d_model, num_heads)
    self.ffn = point_wise_feed_forward_network(d_model, dff)
    self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    self.layernorm3 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    self.dropout1 = tf.keras.layers.Dropout(rate)
    self.dropout2 = tf.keras.layers.Dropout(rate)
    self.dropout3 = tf.keras.layers.Dropout(rate)

  def call(self, x, enc_output, training,
           look_ahead_mask, padding_mask):
    # enc_output.shape == (batch_size, input_seq_len, d_model)
    attn1, attn_weights_block1 = self.mha1(x, x, x, look_ahead_mask)  # (batch_size, target_seq_len, d_model)
    attn1 = self.dropout1(attn1, training=training)
    out1 = self.layernorm1(attn1 + x)

    attn2, attn_weights_block2 = self.mha2(
        enc_output, enc_output, out1, padding_mask)  # (batch_size, target_seq_len, d_model)
    attn2 = self.dropout2(attn2, training=training)
    out2 = self.layernorm2(attn2 + out1)  # (batch_size, target_seq_len, d_model)

    ffn_output = self.ffn(out2)  # (batch_size, target_seq_len, d_model)
    ffn_output = self.dropout3(ffn_output, training=training)
    out3 = self.layernorm3(ffn_output + out2)  # (batch_size, target_seq_len, d_model)

    return out3, attn_weights_block1, attn_weights_block2
sample_decoder_layer = DecoderLayer(512, 8, 2048)
sample_decoder_layer_output, _, _ = sample_decoder_layer(
tf.random.uniform((64, 50, 512)), sample_encoder_layer_output,
False, None, None)
sample_decoder_layer_output.shape # (batch_size, target_seq_len, d_model)
TensorShape([64, 50, 512])
Encoder
The Encoder consists of:
- Input Embedding
- Positional Encoding
- N encoder layers
The input is put through an embedding which is summed with the positional encoding. The output of this summation is the input to the encoder layers. The output of the encoder is the input to the decoder.
class Encoder(tf.keras.layers.Layer):
  def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size,
               maximum_position_encoding, rate=0.1):
    super(Encoder, self).__init__()
    self.d_model = d_model
    self.num_layers = num_layers
    self.embedding = tf.keras.layers.Embedding(input_vocab_size, d_model)
    self.pos_encoding = positional_encoding(maximum_position_encoding,
                                            self.d_model)
    self.enc_layers = [EncoderLayer(d_model, num_heads, dff, rate)
                       for _ in range(num_layers)]
    self.dropout = tf.keras.layers.Dropout(rate)

  def call(self, x, training, mask):
    seq_len = tf.shape(x)[1]

    # adding embedding and position encoding.
    x = self.embedding(x)  # (batch_size, input_seq_len, d_model)
    x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
    x += self.pos_encoding[:, :seq_len, :]

    x = self.dropout(x, training=training)

    for i in range(self.num_layers):
      x = self.enc_layers[i](x, training, mask)

    return x  # (batch_size, input_seq_len, d_model)
sample_encoder = Encoder(num_layers=2, d_model=512, num_heads=8,
dff=2048, input_vocab_size=8500,
maximum_position_encoding=10000)
temp_input = tf.random.uniform((64, 62), dtype=tf.int64, minval=0, maxval=200)
sample_encoder_output = sample_encoder(temp_input, training=False, mask=None)
print (sample_encoder_output.shape) # (batch_size, input_seq_len, d_model)
(64, 62, 512)
Decoder
The Decoder consists of:
- Output Embedding
- Positional Encoding
- N decoder layers
The target is put through an embedding which is summed with the positional encoding. The output of this summation is the input to the decoder layers. The output of the decoder is the input to the final linear layer.
class Decoder(tf.keras.layers.Layer):
  def __init__(self, num_layers, d_model, num_heads, dff, target_vocab_size,
               maximum_position_encoding, rate=0.1):
    super(Decoder, self).__init__()
    self.d_model = d_model
    self.num_layers = num_layers
    self.embedding = tf.keras.layers.Embedding(target_vocab_size, d_model)
    self.pos_encoding = positional_encoding(maximum_position_encoding, d_model)
    self.dec_layers = [DecoderLayer(d_model, num_heads, dff, rate)
                       for _ in range(num_layers)]
    self.dropout = tf.keras.layers.Dropout(rate)

  def call(self, x, enc_output, training,
           look_ahead_mask, padding_mask):
    seq_len = tf.shape(x)[1]
    attention_weights = {}

    x = self.embedding(x)  # (batch_size, target_seq_len, d_model)
    x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
    x += self.pos_encoding[:, :seq_len, :]

    x = self.dropout(x, training=training)

    for i in range(self.num_layers):
      x, block1, block2 = self.dec_layers[i](x, enc_output, training,
                                             look_ahead_mask, padding_mask)
      attention_weights['decoder_layer{}_block1'.format(i+1)] = block1
      attention_weights['decoder_layer{}_block2'.format(i+1)] = block2

    # x.shape == (batch_size, target_seq_len, d_model)
    return x, attention_weights
sample_decoder = Decoder(num_layers=2, d_model=512, num_heads=8,
dff=2048, target_vocab_size=8000,
maximum_position_encoding=5000)
temp_input = tf.random.uniform((64, 26), dtype=tf.int64, minval=0, maxval=200)
output, attn = sample_decoder(temp_input,
enc_output=sample_encoder_output,
training=False,
look_ahead_mask=None,
padding_mask=None)
output.shape, attn['decoder_layer2_block2'].shape
(TensorShape([64, 26, 512]), TensorShape([64, 8, 26, 62]))
Create the Transformer
Transformer consists of the encoder, decoder and a final linear layer. The output of the decoder is the input to the linear layer and its output is returned.
class Transformer(tf.keras.Model):
  def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size,
               target_vocab_size, pe_input, pe_target, rate=0.1):
    super(Transformer, self).__init__()
    self.encoder = Encoder(num_layers, d_model, num_heads, dff,
                           input_vocab_size, pe_input, rate)
    self.decoder = Decoder(num_layers, d_model, num_heads, dff,
                           target_vocab_size, pe_target, rate)
    self.final_layer = tf.keras.layers.Dense(target_vocab_size)

  def call(self, inp, tar, training, enc_padding_mask,
           look_ahead_mask, dec_padding_mask):
    enc_output = self.encoder(inp, training, enc_padding_mask)  # (batch_size, inp_seq_len, d_model)

    # dec_output.shape == (batch_size, tar_seq_len, d_model)
    dec_output, attention_weights = self.decoder(
        tar, enc_output, training, look_ahead_mask, dec_padding_mask)

    final_output = self.final_layer(dec_output)  # (batch_size, tar_seq_len, target_vocab_size)

    return final_output, attention_weights
sample_transformer = Transformer(
num_layers=2, d_model=512, num_heads=8, dff=2048,
input_vocab_size=8500, target_vocab_size=8000,
pe_input=10000, pe_target=6000)
temp_input = tf.random.uniform((64, 38), dtype=tf.int64, minval=0, maxval=200)
temp_target = tf.random.uniform((64, 36), dtype=tf.int64, minval=0, maxval=200)
fn_out, _ = sample_transformer(temp_input, temp_target, training=False,
enc_padding_mask=None,
look_ahead_mask=None,
dec_padding_mask=None)
fn_out.shape # (batch_size, tar_seq_len, target_vocab_size)
TensorShape([64, 36, 8000])
Set hyperparameters
To keep this example small and relatively fast, the values for num_layers, d_model, and dff have been reduced.
The values used in the base model of the transformer were num_layers=6, d_model=512, and dff=2048. See the paper for all the other versions of the transformer.
num_layers = 4
d_model = 128
dff = 512
num_heads = 8
input_vocab_size = tokenizer_pt.vocab_size + 2
target_vocab_size = tokenizer_en.vocab_size + 2
dropout_rate = 0.1
Optimizer
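Use the Adam optimizer with a custom learning rate scheduler according to the formula in the paper:

$$lrate = d_{model}^{-0.5} \cdot \min\left(step\_num^{-0.5},\ step\_num \cdot warmup\_steps^{-1.5}\right)$$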
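A plain-Python sketch of the same formula, useful for sanity-checking the schedule: the rate grows linearly for the first warmup_steps steps and then decays with the inverse square root of the step, so it peaks at step == warmup_steps with value (d_model * warmup_steps) ** -0.5. The defaults below reuse this tutorial's d_model = 128 and the warmup_steps = 4000 default of CustomSchedule; transformer_lr is an illustrative name.

```python
def transformer_lr(step, d_model=128, warmup_steps=4000):
    """d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5), for step >= 1."""
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)


peak_step = max(range(1, 40001), key=transformer_lr)
print(peak_step, transformer_lr(peak_step))  # peak at step 4000, about 1.4e-3
```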
class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
  def __init__(self, d_model, warmup_steps=4000):
    super(CustomSchedule, self).__init__()
    self.d_model = d_model
    self.d_model = tf.cast(self.d_model, tf.float32)
    self.warmup_steps = warmup_steps

  def __call__(self, step):
    # The optimizer may pass an integer step; rsqrt needs a float tensor.
    step = tf.cast(step, tf.float32)
    arg1 = tf.math.rsqrt(step)
    arg2 = step * (self.warmup_steps ** -1.5)
    return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)
learning_rate = CustomSchedule(d_model)
optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98,
epsilon=1e-9)
temp_learning_rate_schedule = CustomSchedule(d_model)
plt.plot(temp_learning_rate_schedule(tf.range(40000, dtype=tf.float32)))
plt.ylabel("Learning Rate")
plt.xlabel("Train Step")
Loss and metrics
Since the target sequences are padded, it is important to apply a padding mask when calculating the loss.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True, reduction='none')
def loss_function(real, pred):
  mask = tf.math.logical_not(tf.math.equal(real, 0))
  loss_ = loss_object(real, pred)
  mask = tf.cast(mask, dtype=loss_.dtype)
  loss_ *= mask
  return tf.reduce_sum(loss_)/tf.reduce_sum(mask)

def accuracy_function(real, pred):
  accuracies = tf.equal(real, tf.argmax(pred, axis=2))
  mask = tf.math.logical_not(tf.math.equal(real, 0))
  accuracies = tf.math.logical_and(mask, accuracies)
  accuracies = tf.cast(accuracies, dtype=tf.float32)
  mask = tf.cast(mask, dtype=tf.float32)
  return tf.reduce_sum(accuracies)/tf.reduce_sum(mask)
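To see what the mask buys, here is a NumPy sketch of the same masked cross-entropy. The toy batch (made-up values) gives a uniform guess over 5 classes at the two real positions and a confidently wrong guess at the two padded positions. With the mask, the padded positions are ignored and the loss is exactly log 5; a plain mean over all positions would be dominated by the padding.

```python
import numpy as np


def masked_sparse_ce(real, logits, pad_id=0):
    """Mean cross-entropy over non-padding positions, like loss_function above."""
    shifted = logits - logits.max(axis=-1, keepdims=True)            # stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Log-probability assigned to the true token at every position.
    token_log_probs = np.take_along_axis(log_probs, real[..., None], axis=-1)[..., 0]
    mask = (real != pad_id).astype(log_probs.dtype)
    return -(token_log_probs * mask).sum() / mask.sum()


real = np.array([[3, 1, 0, 0]])       # (batch, seq_len); the last two are padding
logits = np.zeros((1, 4, 5))          # uniform over 5 classes everywhere...
logits[0, 2:, 4] = 50.0               # ...except padding, which confidently predicts 4

masked = masked_sparse_ce(real, logits)
unmasked = masked_sparse_ce(real, logits, pad_id=-1)  # no id equals -1, so nothing is masked
print(masked, np.log(5), unmasked)
```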
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.Mean(name='train_accuracy')
Training and checkpointing
transformer = Transformer(num_layers, d_model, num_heads, dff,
input_vocab_size, target_vocab_size,
pe_input=input_vocab_size,
pe_target=target_vocab_size,
rate=dropout_rate)
def create_masks(inp, tar):
  # Encoder padding mask
  enc_padding_mask = create_padding_mask(inp)

  # Used in the 2nd attention block in the decoder.
  # This padding mask is used to mask the encoder outputs.
  dec_padding_mask = create_padding_mask(inp)

  # Used in the 1st attention block in the decoder.
  # It is used to pad and mask future tokens in the input received by
  # the decoder.
  look_ahead_mask = create_look_ahead_mask(tf.shape(tar)[1])
  dec_target_padding_mask = create_padding_mask(tar)
  combined_mask = tf.maximum(dec_target_padding_mask, look_ahead_mask)

  return enc_padding_mask, combined_mask, dec_padding_mask
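A NumPy sketch of how the look-ahead mask and the target padding mask combine into combined_mask. Both masks use 1 for "hide", so taking the elementwise maximum keeps a position hidden if either mask hides it. np.triu(..., k=1) produces the same strictly upper-triangular matrix as 1 - tf.linalg.band_part(ones, -1, 0); the helper names and the toy target are illustrative.

```python
import numpy as np


def look_ahead_mask(size):
    # 1 strictly above the diagonal = future positions that must be hidden.
    return np.triu(np.ones((size, size)), k=1)


def padding_mask(seq, pad_id=0):
    # 1 where the token is padding; shaped (batch, 1, 1, seq_len) for broadcasting.
    return (seq == pad_id).astype(np.float64)[:, np.newaxis, np.newaxis, :]


tar = np.array([[5, 9, 0]])   # one target sequence; the last token is padding
combined = np.maximum(padding_mask(tar), look_ahead_mask(tar.shape[1]))
print(combined[0, 0])
# [[0. 1. 1.]
#  [0. 0. 1.]
#  [0. 0. 1.]]
```

The last row differs from the plain look-ahead mask: even the final query may not attend to the padded position.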
Create the checkpoint path and the checkpoint manager. This will be used to save checkpoints every 5 epochs.
checkpoint_path = "./checkpoints/train"
ckpt = tf.train.Checkpoint(transformer=transformer,
optimizer=optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)
# if a checkpoint exists, restore the latest checkpoint.
if ckpt_manager.latest_checkpoint:
  ckpt.restore(ckpt_manager.latest_checkpoint)
  print ('Latest checkpoint restored!!')
The target is divided into tar_inp and tar_real. tar_inp is passed as an input to the decoder. tar_real is that same input shifted by 1: at each location in tar_inp, tar_real contains the next token that should be predicted.
For example:
- sentence = "SOS A lion in the jungle is sleeping EOS"
- tar_inp = "SOS A lion in the jungle is sleeping"
- tar_real = "A lion in the jungle is sleeping EOS"
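The shift is plain slicing, the same tar[:, :-1] and tar[:, 1:] used in train_step below. A minimal sketch on the example sentence, with word strings standing in for token ids:

```python
def shift_for_teacher_forcing(tokens):
    """Split a target sequence into decoder input and expected output."""
    tar_inp = tokens[:-1]   # everything except the last token
    tar_real = tokens[1:]   # everything except the first token
    return tar_inp, tar_real


sentence = "SOS A lion in the jungle is sleeping EOS".split()
tar_inp, tar_real = shift_for_teacher_forcing(sentence)
for seen, target in zip(tar_inp, tar_real):
    print(f"after {seen!r}, predict {target!r}")
```

Each printed line pairs the last token the decoder has seen at a position with the token it should predict there.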
The transformer is an auto-regressive model: it makes predictions one part at a time, and uses its output so far to decide what to do next.
During training this example uses teacher-forcing (like in the text generation tutorial). Teacher forcing is passing the true output to the next time step regardless of what the model predicts at the current time step.
As the transformer predicts each word, self-attention allows it to look at the previous words in the input sequence to better predict the next word.
To prevent the model from peeking at the expected output the model uses a look-ahead mask.
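Teacher forcing applies only during training. At inference time there is no target to feed in, so an auto-regressive model appends each prediction to its input and runs again until it emits the end token. The sketch below shows that loop in plain Python. predict_next is a hypothetical stand-in for running the transformer on the current prefix and taking the argmax at the last position, and the toy lookup table replaces a trained model.

```python
def greedy_decode(predict_next, start_id, end_id, max_length=40):
    """Generic greedy auto-regressive loop.

    predict_next(prefix) returns the id of the most likely next token given
    the tokens generated so far (a hypothetical stand-in for the model).
    """
    output = [start_id]
    for _ in range(max_length):
        next_id = predict_next(output)
        output.append(next_id)
        if next_id == end_id:
            break
    return output


# Toy "model": a lookup from the last token to the next one (0 = start, 1 = end).
script = {0: 5, 5: 7, 7: 9, 9: 1}
print(greedy_decode(lambda prefix: script[prefix[-1]], start_id=0, end_id=1))
# [0, 5, 7, 9, 1]
```

The max_length cap stops the loop even if the model never produces the end token.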
EPOCHS = 20
# The @tf.function trace-compiles train_step into a TF graph for faster
# execution. The function specializes to the precise shape of the argument
# tensors. To avoid re-tracing due to the variable sequence lengths or variable
# batch sizes (the last batch is smaller), use input_signature to specify
# more generic shapes.
train_step_signature = [
tf.TensorSpec(shape=(None, None), dtype=tf.int64),
tf.TensorSpec(shape=(None, None), dtype=tf.int64),
]
@tf.function(input_signature=train_step_signature)
def train_step(inp, tar):
  tar_inp = tar[:, :-1]
  tar_real = tar[:, 1:]

  enc_padding_mask, combined_mask, dec_padding_mask = create_masks(inp, tar_inp)

  with tf.GradientTape() as tape:
    predictions, _ = transformer(inp, tar_inp,
                                 True,
                                 enc_padding_mask,
                                 combined_mask,
                                 dec_padding_mask)
    loss = loss_function(tar_real, predictions)

  gradients = tape.gradient(loss, transformer.trainable_variables)
  optimizer.apply_gradients(zip(gradients, transformer.trainable_variables))

  train_loss(loss)
  train_accuracy(accuracy_function(tar_real, predictions))
Portuguese is used as the input language and English is the target language.
for epoch in range(EPOCHS):
  start = time.time()
  train_loss.reset_states()
  train_accuracy.reset_states()
  # inp -> portuguese, tar -> english
  for (batch, (inp, tar)) in enumerate(train_dataset):
    train_step(inp, tar)
    if batch % 50 == 0:
      print('Epoch {} Batch {} Loss {:.4f} Accuracy {:.4f}'.format(
          epoch + 1, batch, train_loss.result(), train_accuracy.result()))
  if (epoch + 1) % 5 == 0:
    ckpt_save_path = ckpt_manager.save()
    print('Saving checkpoint for epoch {} at {}'.format(epoch + 1,
                                                         ckpt_save_path))
  print('Epoch {} Loss {:.4f} Accuracy {:.4f}'.format(epoch + 1,
                                                train_loss.result(),
                                                train_accuracy.result()))
  print('Time taken for 1 epoch: {} secs\n'.format(time.time() - start))
Epoch 1 Batch 0 Loss 9.0325 Accuracy 0.0000
Epoch 1 Batch 50 Loss 8.9631 Accuracy 0.0029
Epoch 1 Batch 100 Loss 8.8625 Accuracy 0.0285
Epoch 1 Batch 150 Loss 8.7535 Accuracy 0.0396
Epoch 1 Batch 200 Loss 8.6247 Accuracy 0.0465
Epoch 1 Batch 250 Loss 8.4712 Accuracy 0.0555
Epoch 1 Batch 300 Loss 8.2956 Accuracy 0.0660
Epoch 1 Batch 350 Loss 8.1095 Accuracy 0.0753
Epoch 1 Batch 400 Loss 7.9295 Accuracy 0.0823
Epoch 1 Batch 450 Loss 7.7644 Accuracy 0.0879
Epoch 1 Batch 500 Loss 7.6171 Accuracy 0.0927
Epoch 1 Batch 550 Loss 7.4840 Accuracy 0.0983
Epoch 1 Batch 600 Loss 7.3578 Accuracy 0.1048
Epoch 1 Batch 650 Loss 7.2364 Accuracy 0.1117
Epoch 1 Batch 700 Loss 7.1232 Accuracy 0.1185
Epoch 1 Loss 7.1188 Accuracy 0.1187
Time taken for 1 epoch: 54.0901415348053 secs

Epoch 2 Batch 0 Loss 5.5828 Accuracy 0.1969
Epoch 2 Batch 50 Loss 5.4646 Accuracy 0.2218
Epoch 2 Batch 100 Loss 5.4122 Accuracy 0.2249
Epoch 2 Batch 150 Loss 5.3780 Accuracy 0.2278
Epoch 2 Batch 200 Loss 5.3342 Accuracy 0.2315
Epoch 2 Batch 250 Loss 5.2924 Accuracy 0.2352
Epoch 2 Batch 300 Loss 5.2508 Accuracy 0.2393
Epoch 2 Batch 350 Loss 5.2157 Accuracy 0.2424
Epoch 2 Batch 400 Loss 5.1816 Accuracy 0.2459
Epoch 2 Batch 450 Loss 5.1515 Accuracy 0.2492
Epoch 2 Batch 500 Loss 5.1203 Accuracy 0.2523
Epoch 2 Batch 550 Loss 5.0920 Accuracy 0.2552
Epoch 2 Batch 600 Loss 5.0661 Accuracy 0.2579
Epoch 2 Batch 650 Loss 5.0418 Accuracy 0.2604
Epoch 2 Batch 700 Loss 5.0186 Accuracy 0.2627
Epoch 2 Loss 5.0179 Accuracy 0.2628
Time taken for 1 epoch: 30.69964909553528 secs

Epoch 3 Batch 0 Loss 4.6354 Accuracy 0.3023
Epoch 3 Batch 50 Loss 4.6234 Accuracy 0.3012
Epoch 3 Batch 100 Loss 4.6082 Accuracy 0.3021
Epoch 3 Batch 150 Loss 4.6073 Accuracy 0.3030
Epoch 3 Batch 200 Loss 4.5980 Accuracy 0.3043
Epoch 3 Batch 250 Loss 4.5887 Accuracy 0.3054
Epoch 3 Batch 300 Loss 4.5753 Accuracy 0.3069
Epoch 3 Batch 350 Loss 4.5617 Accuracy 0.3084
Epoch 3 Batch 400 Loss 4.5476 Accuracy 0.3099
Epoch 3 Batch 450 Loss 4.5342 Accuracy 0.3112
Epoch 3 Batch 500 Loss 4.5243 Accuracy 0.3122
Epoch 3 Batch 550 Loss 4.5141 Accuracy 0.3134
Epoch 3 Batch 600 Loss 4.5010 Accuracy 0.3148
Epoch 3 Batch 650 Loss 4.4917 Accuracy 0.3158
Epoch 3 Batch 700 Loss 4.4797 Accuracy 0.3170
Epoch 3 Loss 4.4795 Accuracy 0.3170
Time taken for 1 epoch: 30.95038342475891 secs

Epoch 4 Batch 0 Loss 3.9974 Accuracy 0.3544
Epoch 4 Batch 50 Loss 4.2098 Accuracy 0.3431
Epoch 4 Batch 100 Loss 4.1827 Accuracy 0.3466
Epoch 4 Batch 150 Loss 4.1836 Accuracy 0.3471
Epoch 4 Batch 200 Loss 4.1799 Accuracy 0.3473
Epoch 4 Batch 250 Loss 4.1675 Accuracy 0.3486
Epoch 4 Batch 300 Loss 4.1581 Accuracy 0.3501
Epoch 4 Batch 350 Loss 4.1447 Accuracy 0.3516
Epoch 4 Batch 400 Loss 4.1289 Accuracy 0.3538
Epoch 4 Batch 450 Loss 4.1132 Accuracy 0.3559
Epoch 4 Batch 500 Loss 4.0977 Accuracy 0.3581
Epoch 4 Batch 550 Loss 4.0811 Accuracy 0.3601
Epoch 4 Batch 600 Loss 4.0648 Accuracy 0.3623
Epoch 4 Batch 650 Loss 4.0508 Accuracy 0.3644
Epoch 4 Batch 700 Loss 4.0378 Accuracy 0.3662
Epoch 4 Loss 4.0372 Accuracy 0.3663
Time taken for 1 epoch: 30.76306676864624 secs

Epoch 5 Batch 0 Loss 3.7628 Accuracy 0.3862
Epoch 5 Batch 50 Loss 3.7014 Accuracy 0.4030
Epoch 5 Batch 100 Loss 3.6951 Accuracy 0.4051
Epoch 5 Batch 150 Loss 3.6715 Accuracy 0.4084
Epoch 5 Batch 200 Loss 3.6619 Accuracy 0.4105
Epoch 5 Batch 250 Loss 3.6462 Accuracy 0.4131
Epoch 5 Batch 300 Loss 3.6403 Accuracy 0.4145
Epoch 5 Batch 350 Loss 3.6238 Accuracy 0.4168
Epoch 5 Batch 400 Loss 3.6115 Accuracy 0.4186
Epoch 5 Batch 450 Loss 3.5990 Accuracy 0.4200
Epoch 5 Batch 500 Loss 3.5902 Accuracy 0.4215
Epoch 5 Batch 550 Loss 3.5806 Accuracy 0.4228
Epoch 5 Batch 600 Loss 3.5705 Accuracy 0.4240
Epoch 5 Batch 650 Loss 3.5612 Accuracy 0.4252
Epoch 5 Batch 700 Loss 3.5494 Accuracy 0.4267
Saving checkpoint for epoch 5 at ./checkpoints/train/ckpt-1
Epoch 5 Loss 3.5493 Accuracy 0.4267
Time taken for 1 epoch: 31.12636709213257 secs

Epoch 6 Batch 0 Loss 3.2444 Accuracy 0.4605
Epoch 6 Batch 50 Loss 3.2155 Accuracy 0.4619
Epoch 6 Batch 100 Loss 3.2090 Accuracy 0.4631
Epoch 6 Batch 150 Loss 3.2024 Accuracy 0.4633
Epoch 6 Batch 200 Loss 3.1966 Accuracy 0.4648
Epoch 6 Batch 250 Loss 3.1929 Accuracy 0.4655
Epoch 6 Batch 300 Loss 3.1881 Accuracy 0.4664
Epoch 6 Batch 350 Loss 3.1832 Accuracy 0.4674
Epoch 6 Batch 400 Loss 3.1761 Accuracy 0.4685
Epoch 6 Batch 450 Loss 3.1699 Accuracy 0.4688
Epoch 6 Batch 500 Loss 3.1650 Accuracy 0.4696
Epoch 6 Batch 550 Loss 3.1592 Accuracy 0.4705
Epoch 6 Batch 600 Loss 3.1530 Accuracy 0.4717
Epoch 6 Batch 650 Loss 3.1471 Accuracy 0.4723
Epoch 6 Batch 700 Loss 3.1409 Accuracy 0.4731
Epoch 6 Loss 3.1406 Accuracy 0.4731
Time taken for 1 epoch: 30.89075517654419 secs

Epoch 7 Batch 0 Loss 2.7938 Accuracy 0.5259
Epoch 7 Batch 50 Loss 2.8191 Accuracy 0.5065
Epoch 7 Batch 100 Loss 2.8050 Accuracy 0.5093
Epoch 7 Batch 150 Loss 2.8151 Accuracy 0.5085
Epoch 7 Batch 200 Loss 2.8113 Accuracy 0.5093
Epoch 7 Batch 250 Loss 2.8064 Accuracy 0.5101
Epoch 7 Batch 300 Loss 2.8001 Accuracy 0.5109
Epoch 7 Batch 350 Loss 2.7923 Accuracy 0.5123
Epoch 7 Batch 400 Loss 2.7791 Accuracy 0.5138
Epoch 7 Batch 450 Loss 2.7733 Accuracy 0.5146
Epoch 7 Batch 500 Loss 2.7688 Accuracy 0.5154
Epoch 7 Batch 550 Loss 2.7602 Accuracy 0.5166
Epoch 7 Batch 600 Loss 2.7542 Accuracy 0.5175
Epoch 7 Batch 650 Loss 2.7497 Accuracy 0.5184
Epoch 7 Batch 700 Loss 2.7439 Accuracy 0.5193
Epoch 7 Loss 2.7436 Accuracy 0.5194
Time taken for 1 epoch: 31.79271173477173 secs

Epoch 8 Batch 0 Loss 2.5171 Accuracy 0.5501
Epoch 8 Batch 50 Loss 2.4147 Accuracy 0.5562
Epoch 8 Batch 100 Loss 2.4221 Accuracy 0.5555
Epoch 8 Batch 150 Loss 2.4131 Accuracy 0.5580
Epoch 8 Batch 200 Loss 2.4173 Accuracy 0.5570
Epoch 8 Batch 250 Loss 2.4214 Accuracy 0.5569
Epoch 8 Batch 300 Loss 2.4305 Accuracy 0.5555
Epoch 8 Batch 350 Loss 2.4305 Accuracy 0.5555
Epoch 8 Batch 400 Loss 2.4273 Accuracy 0.5560
Epoch 8 Batch 450 Loss 2.4261 Accuracy 0.5563
Epoch 8 Batch 500 Loss 2.4245 Accuracy 0.5567
Epoch 8 Batch 550 Loss 2.4233 Accuracy 0.5572
Epoch 8 Batch 600 Loss 2.4182 Accuracy 0.5582
Epoch 8 Batch 650 Loss 2.4162 Accuracy 0.5585
Epoch 8 Batch 700 Loss 2.4147 Accuracy 0.5589
Epoch 8 Loss 2.4139 Accuracy 0.5590
Time taken for 1 epoch: 31.114221572875977 secs

Epoch 9 Batch 0 Loss 2.0796 Accuracy 0.6034
Epoch 9 Batch 50 Loss 2.1062 Accuracy 0.5978
Epoch 9 Batch 100 Loss 2.1268 Accuracy 0.5934
Epoch 9 Batch 150 Loss 2.1465 Accuracy 0.5914
Epoch 9 Batch 200 Loss 2.1525 Accuracy 0.5908
Epoch 9 Batch 250 Loss 2.1578 Accuracy 0.5902
Epoch 9 Batch 300 Loss 2.1590 Accuracy 0.5899
Epoch 9 Batch 350 Loss 2.1592 Accuracy 0.5898
Epoch 9 Batch 400 Loss 2.1598 Accuracy 0.5899
Epoch 9 Batch 450 Loss 2.1614 Accuracy 0.5898
Epoch 9 Batch 500 Loss 2.1632 Accuracy 0.5897
Epoch 9 Batch 550 Loss 2.1625 Accuracy 0.5899
Epoch 9 Batch 600 Loss 2.1655 Accuracy 0.5897
Epoch 9 Batch 650 Loss 2.1659 Accuracy 0.5896
Epoch 9 Batch 700 Loss 2.1671 Accuracy 0.5896
Epoch 9 Loss 2.1671 Accuracy 0.5896
Time taken for 1 epoch: 31.18158221244812 secs

Epoch 10 Batch 0 Loss 1.8438 Accuracy 0.6431
Epoch 10 Batch 50 Loss 1.9301 Accuracy 0.6210
Epoch 10 Batch 100 Loss 1.9432 Accuracy 0.6182
Epoch 10 Batch 150 Loss 1.9503 Accuracy 0.6175
Epoch 10 Batch 200 Loss 1.9521 Accuracy 0.6180
Epoch 10 Batch 250 Loss 1.9540 Accuracy 0.6173
Epoch 10 Batch 300 Loss 1.9551 Accuracy 0.6174
Epoch 10 Batch 350 Loss 1.9615 Accuracy 0.6167
Epoch 10 Batch 400 Loss 1.9619 Accuracy 0.6165
Epoch 10 Batch 450 Loss 1.9636 Accuracy 0.6167
Epoch 10 Batch 500 Loss 1.9685 Accuracy 0.6160
Epoch 10 Batch 550 Loss 1.9700 Accuracy 0.6159
Epoch 10 Batch 600 Loss 1.9750 Accuracy 0.6153
Epoch 10 Batch 650 Loss 1.9787 Accuracy 0.6150
Epoch 10 Batch 700 Loss 1.9803 Accuracy 0.6147
Saving checkpoint for epoch 10 at ./checkpoints/train/ckpt-2
Epoch 10 Loss 1.9804 Accuracy 0.6147
Time taken for 1 epoch: 31.05565071105957 secs

Epoch 11 Batch 0 Loss 1.6855 Accuracy 0.6649
Epoch 11 Batch 50 Loss 1.7798 Accuracy 0.6409
Epoch 11 Batch 100 Loss 1.7849 Accuracy 0.6394
Epoch 11 Batch 150 Loss 1.7887 Accuracy 0.6387
Epoch 11 Batch 200 Loss 1.7986 Accuracy 0.6376
Epoch 11 Batch 250 Loss 1.8039 Accuracy 0.6371
Epoch 11 Batch 300 Loss 1.8087 Accuracy 0.6361
Epoch 11 Batch 350 Loss 1.8122 Accuracy 0.6354
Epoch 11 Batch 400 Loss 1.8157 Accuracy 0.6352
Epoch 11 Batch 450 Loss 1.8157 Accuracy 0.6352
Epoch 11 Batch 500 Loss 1.8214 Accuracy 0.6344
Epoch 11 Batch 550 Loss 1.8249 Accuracy 0.6340
Epoch 11 Batch 600 Loss 1.8275 Accuracy 0.6339
Epoch 11 Batch 650 Loss 1.8327 Accuracy 0.6333
Epoch 11 Batch 700 Loss 1.8357 Accuracy 0.6329
Epoch 11 Loss 1.8357 Accuracy 0.6329
Time taken for 1 epoch: 31.027227878570557 secs

Epoch 12 Batch 0 Loss 1.6143 Accuracy 0.6664
Epoch 12 Batch 50 Loss 1.6444 Accuracy 0.6597
Epoch 12 Batch 100 Loss 1.6483 Accuracy 0.6591
Epoch 12 Batch 150 Loss 1.6596 Accuracy 0.6571
Epoch 12 Batch 200 Loss 1.6646 Accuracy 0.6567
Epoch 12 Batch 250 Loss 1.6728 Accuracy 0.6555
Epoch 12 Batch 300 Loss 1.6742 Accuracy 0.6554
Epoch 12 Batch 350 Loss 1.6771 Accuracy 0.6547
Epoch 12 Batch 400 Loss 1.6812 Accuracy 0.6543
Epoch 12 Batch 450 Loss 1.6868 Accuracy 0.6532
Epoch 12 Batch 500 Loss 1.6895 Accuracy 0.6529
Epoch 12 Batch 550 Loss 1.6949 Accuracy 0.6522
Epoch 12 Batch 600 Loss 1.7010 Accuracy 0.6515
Epoch 12 Batch 650 Loss 1.7086 Accuracy 0.6504
Epoch 12 Batch 700 Loss 1.7139 Accuracy 0.6499
Epoch 12 Loss 1.7141 Accuracy 0.6499
Time taken for 1 epoch: 30.62126898765564 secs

Epoch 13 Batch 0 Loss 1.4501 Accuracy 0.6824
Epoch 13 Batch 50 Loss 1.5419 Accuracy 0.6747
Epoch 13 Batch 100 Loss 1.5441 Accuracy 0.6730
Epoch 13 Batch 150 Loss 1.5583 Accuracy 0.6700
Epoch 13 Batch 200 Loss 1.5641 Accuracy 0.6695
Epoch 13 Batch 250 Loss 1.5682 Accuracy 0.6690
Epoch 13 Batch 300 Loss 1.5721 Accuracy 0.6683
Epoch 13 Batch 350 Loss 1.5766 Accuracy 0.6680
Epoch 13 Batch 400 Loss 1.5813 Accuracy 0.6676
Epoch 13 Batch 450 Loss 1.5896 Accuracy 0.6664
Epoch 13 Batch 500 Loss 1.5922 Accuracy 0.6661
Epoch 13 Batch 550 Loss 1.5978 Accuracy 0.6654
Epoch 13 Batch 600 Loss 1.6030 Accuracy 0.6646
Epoch 13 Batch 650 Loss 1.6082 Accuracy 0.6639
Epoch 13 Batch 700 Loss 1.6134 Accuracy 0.6633
Epoch 13 Loss 1.6134 Accuracy 0.6634
Time taken for 1 epoch: 30.817842721939087 secs

Epoch 14 Batch 0 Loss 1.3038 Accuracy 0.7251
Epoch 14 Batch 50 Loss 1.4288 Accuracy 0.6897
Epoch 14 Batch 100 Loss 1.4507 Accuracy 0.6881
Epoch 14 Batch 150 Loss 1.4641 Accuracy 0.6851
Epoch 14 Batch 200 Loss 1.4735 Accuracy 0.6847
Epoch 14 Batch 250 Loss 1.4773 Accuracy 0.6839
Epoch 14 Batch 300 Loss 1.4825 Accuracy 0.6828
Epoch 14 Batch 350 Loss 1.4890 Accuracy 0.6819
Epoch 14 Batch 400 Loss 1.4937 Accuracy 0.6808
Epoch 14 Batch 450 Loss 1.4998 Accuracy 0.6798
Epoch 14 Batch 500 Loss 1.5044 Accuracy 0.6792
Epoch 14 Batch 550 Loss 1.5087 Accuracy 0.6785
Epoch 14 Batch 600 Loss 1.5150 Accuracy 0.6773
Epoch 14 Batch 650 Loss 1.5209 Accuracy 0.6766
Epoch 14 Batch 700 Loss 1.5267 Accuracy 0.6759
Epoch 14 Loss 1.5264 Accuracy 0.6760
Time taken for 1 epoch: 30.76487922668457 secs

Epoch 15 Batch 0 Loss 1.3602 Accuracy 0.7020
Epoch 15 Batch 50 Loss 1.3786 Accuracy 0.6978
Epoch 15 Batch 100 Loss 1.3675 Accuracy 0.6994
Epoch 15 Batch 150 Loss 1.3881 Accuracy 0.6961
Epoch 15 Batch 200 Loss 1.3986 Accuracy 0.6947
Epoch 15 Batch 250 Loss 1.4042 Accuracy 0.6937
Epoch 15 Batch 300 Loss 1.4095 Accuracy 0.6929
Epoch 15 Batch 350 Loss 1.4175 Accuracy 0.6917
Epoch 15 Batch 400 Loss 1.4202 Accuracy 0.6913
Epoch 15 Batch 450 Loss 1.4256 Accuracy 0.6906
Epoch 15 Batch 500 Loss 1.4271 Accuracy 0.6904
Epoch 15 Batch 550 Loss 1.4337 Accuracy 0.6895
Epoch 15 Batch 600 Loss 1.4401 Accuracy 0.6885
Epoch 15 Batch 650 Loss 1.4460 Accuracy 0.6876
Epoch 15 Batch 700 Loss 1.4515 Accuracy 0.6868
Saving checkpoint for epoch 15 at ./checkpoints/train/ckpt-3
Epoch 15 Loss 1.4519 Accuracy 0.6868
Time taken for 1 epoch: 30.835017681121826 secs

Epoch 16 Batch 0 Loss 1.1559 Accuracy 0.7455
Epoch 16 Batch 50 Loss 1.2853 Accuracy 0.7131
Epoch 16 Batch 100 Loss 1.3070 Accuracy 0.7094
Epoch 16 Batch 150 Loss 1.3140 Accuracy 0.7080
Epoch 16 Batch 200 Loss 1.3281 Accuracy 0.7059
Epoch 16 Batch 250 Loss 1.3378 Accuracy 0.7040
Epoch 16 Batch 300 Loss 1.3457 Accuracy 0.7029
Epoch 16 Batch 350 Loss 1.3470 Accuracy 0.7029
Epoch 16 Batch 400 Loss 1.3514 Accuracy 0.7023
Epoch 16 Batch 450 Loss 1.3575 Accuracy 0.7011
Epoch 16 Batch 500 Loss 1.3623 Accuracy 0.7004
Epoch 16 Batch 550 Loss 1.3688 Accuracy 0.6993
Epoch 16 Batch 600 Loss 1.3738 Accuracy 0.6985
Epoch 16 Batch 650 Loss 1.3789 Accuracy 0.6978
Epoch 16 Batch 700 Loss 1.3835 Accuracy 0.6972
Epoch 16 Loss 1.3842 Accuracy 0.6970
Time taken for 1 epoch: 30.86835026741028 secs

Epoch 17 Batch 0 Loss 1.1976 Accuracy 0.7252
Epoch 17 Batch 50 Loss 1.2320 Accuracy 0.7229
Epoch 17 Batch 100 Loss 1.2459 Accuracy 0.7198
Epoch 17 Batch 150 Loss 1.2615 Accuracy 0.7170
Epoch 17 Batch 200 Loss 1.2671 Accuracy 0.7155
Epoch 17 Batch 250 Loss 1.2781 Accuracy 0.7137
Epoch 17 Batch 300 Loss 1.2837 Accuracy 0.7124
Epoch 17 Batch 350 Loss 1.2856 Accuracy 0.7121
Epoch 17 Batch 400 Loss 1.2926 Accuracy 0.7109
Epoch 17 Batch 450 Loss 1.2976 Accuracy 0.7103
Epoch 17 Batch 500 Loss 1.3044 Accuracy 0.7092
Epoch 17 Batch 550 Loss 1.3103 Accuracy 0.7080
Epoch 17 Batch 600 Loss 1.3158 Accuracy 0.7072
Epoch 17 Batch 650 Loss 1.3188 Accuracy 0.7069
Epoch 17 Batch 700 Loss 1.3237 Accuracy 0.7062
Epoch 17 Loss 1.3240 Accuracy 0.7062
Time taken for 1 epoch: 31.70491647720337 secs

Epoch 18 Batch 0 Loss 1.3016 Accuracy 0.7197
Epoch 18 Batch 50 Loss 1.1798 Accuracy 0.7305
Epoch 18 Batch 100 Loss 1.1901 Accuracy 0.7291
Epoch 18 Batch 150 Loss 1.1978 Accuracy 0.7277
Epoch 18 Batch 200 Loss 1.2066 Accuracy 0.7257
Epoch 18 Batch 250 Loss 1.2141 Accuracy 0.7243
Epoch 18 Batch 300 Loss 1.2225 Accuracy 0.7227
Epoch 18 Batch 350 Loss 1.2288 Accuracy 0.7215
Epoch 18 Batch 400 Loss 1.2353 Accuracy 0.7203
Epoch 18 Batch 450 Loss 1.2398 Accuracy 0.7195
Epoch 18 Batch 500 Loss 1.2441 Accuracy 0.7186
Epoch 18 Batch 550 Loss 1.2516 Accuracy 0.7174
Epoch 18 Batch 600 Loss 1.2581 Accuracy 0.7164
Epoch 18 Batch 650 Loss 1.2648 Accuracy 0.7154
Epoch 18 Batch 700 Loss 1.2693 Accuracy 0.7148
Epoch 18 Loss 1.2694 Accuracy 0.7148
Time taken for 1 epoch: 30.78923749923706 secs

Epoch 19 Batch 0 Loss 1.0897 Accuracy 0.7513
Epoch 19 Batch 50 Loss 1.1161 Accuracy 0.7400
Epoch 19 Batch 100 Loss 1.1371 Accuracy 0.7369
Epoch 19 Batch 150 Loss 1.1537 Accuracy 0.7339
Epoch 19 Batch 200 Loss 1.1645 Accuracy 0.7322
Epoch 19 Batch 250 Loss 1.1728 Accuracy 0.7307
Epoch 19 Batch 300 Loss 1.1812 Accuracy 0.7293
Epoch 19 Batch 350 Loss 1.1851 Accuracy 0.7283
Epoch 19 Batch 400 Loss 1.1890 Accuracy 0.7277
Epoch 19 Batch 450 Loss 1.1922 Accuracy 0.7273
Epoch 19 Batch 500 Loss 1.1964 Accuracy 0.7265
Epoch 19 Batch 550 Loss 1.2038 Accuracy 0.7251
Epoch 19 Batch 600 Loss 1.2107 Accuracy 0.7238
Epoch 19 Batch 650 Loss 1.2152 Accuracy 0.7232
Epoch 19 Batch 700 Loss 1.2193 Accuracy 0.7224
Epoch 19 Loss 1.2199 Accuracy 0.7223
Time taken for 1 epoch: 31.050918340682983 secs

Epoch 20 Batch 0 Loss 1.0248 Accuracy 0.7491
Epoch 20 Batch 50 Loss 1.1020 Accuracy 0.7435
Epoch 20 Batch 100 Loss 1.1042 Accuracy 0.7420
Epoch 20 Batch 150 Loss 1.1102 Accuracy 0.7407
Epoch 20 Batch 200 Loss 1.1199 Accuracy 0.7385
Epoch 20 Batch 250 Loss 1.1271 Accuracy 0.7371
Epoch 20 Batch 300 Loss 1.1352 Accuracy 0.7358
Epoch 20 Batch 350 Loss 1.1401 Accuracy 0.7349
Epoch 20 Batch 400 Loss 1.1419 Accuracy 0.7348
Epoch 20 Batch 450 Loss 1.1476 Accuracy 0.7338
Epoch 20 Batch 500 Loss 1.1544 Accuracy 0.7325
Epoch 20 Batch 550 Loss 1.1585 Accuracy 0.7320
Epoch 20 Batch 600 Loss 1.1650 Accuracy 0.7309
Epoch 20 Batch 650 Loss 1.1699 Accuracy 0.7301
Epoch 20 Batch 700 Loss 1.1764 Accuracy 0.7291
Saving checkpoint for epoch 20 at ./checkpoints/train/ckpt-4
Epoch 20 Loss 1.1765 Accuracy 0.7290
Time taken for 1 epoch: 30.766892671585083 secs
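The Loss and Accuracy values printed at each batch are running averages over the batches seen so far in the current epoch, not the value for that batch alone. This assumes train_loss and train_accuracy are tf.keras.metrics.Mean objects, as defined earlier in the tutorial, and reset_states() clears them at the start of every epoch. As a result, the Batch 0 figure reflects a single batch and can sit noticeably apart from the later ones (for example, Epoch 4 Batch 0 Loss 3.9974 versus Batch 50 Loss 4.2098). A hypothetical pure-Python stand-in shows the arithmetic:

```python
class RunningMean:
    """Hypothetical stand-in for tf.keras.metrics.Mean: accumulates values
    until reset and reports their average."""

    def __init__(self):
        self.reset_states()

    def reset_states(self):
        self.total = 0.0
        self.count = 0

    def __call__(self, value):
        self.total += value
        self.count += 1

    def result(self):
        return self.total / self.count if self.count else 0.0


train_loss = RunningMean()
batch_losses = [5.6, 5.4, 5.2, 5.0]  # made-up per-batch losses
for batch, loss in enumerate(batch_losses):
    train_loss(loss)
    print("Batch {} Loss {:.4f}".format(batch, train_loss.result()))
# last line printed: Batch 3 Loss 5.3000
```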
Evaluate
The following steps are used for evaluation:
- Encode the input sentence using the Portuguese tokenizer (tokenizer_pt). Also add the start and end tokens so the input matches what the model was trained with. This is the encoder input.
- The decoder input is the start token, whose id is tokenizer_en.vocab_size.
- Calculate the padding masks and the look-ahead masks.
- The decoder then outputs the predictions by looking at the encoder output and its own output (self-attention).
- Select the last word and calculate its argmax.
- Concatenate the predicted word to the decoder input and pass it to the decoder.
- In this approach, the decoder predicts the next word based on the previous words it predicted.
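The steps above amount to greedy decoding. The framework-free sketch below shows only the control flow. `toy_next_token_scores` is a hypothetical stand-in for running the transformer and keeping the last position (`predictions[:, -1:, :]` in evaluate). The START and END ids are made-up numbers that follow the same pattern as vocab_size and vocab_size + 1.

```python
START, END = 100, 101   # hypothetical ids, playing the role of vocab_size and vocab_size + 1
MAX_LENGTH = 10


def toy_next_token_scores(output_ids):
    """Hypothetical model: one score per vocabulary id for the next token."""
    script = [7, 3, 42, END]               # the "translation" it will emit
    step = len(output_ids) - 1             # tokens generated so far
    target = script[min(step, len(script) - 1)]
    scores = [0.0] * (END + 1)
    scores[target] = 1.0
    return scores


def greedy_decode(max_length=MAX_LENGTH):
    output = [START]                       # decoder input starts with the start token
    for _ in range(max_length):
        scores = toy_next_token_scores(output)
        predicted_id = max(range(len(scores)), key=scores.__getitem__)  # argmax
        if predicted_id == END:
            return output                  # stop; END itself is not appended
        output.append(predicted_id)        # feed the prediction back in
    return output


print(greedy_decode())  # [100, 7, 3, 42]
```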
def evaluate(inp_sentence):
  start_token = [tokenizer_pt.vocab_size]
  end_token = [tokenizer_pt.vocab_size + 1]
  # inp sentence is portuguese, hence adding the start and end token
  inp_sentence = start_token + tokenizer_pt.encode(inp_sentence) + end_token
  encoder_input = tf.expand_dims(inp_sentence, 0)
  # as the target is english, the first word to the transformer should be the
  # english start token.
  decoder_input = [tokenizer_en.vocab_size]
  output = tf.expand_dims(decoder_input, 0)
  for i in range(MAX_LENGTH):
    enc_padding_mask, combined_mask, dec_padding_mask = create_masks(
        encoder_input, output)
    # predictions.shape == (batch_size, seq_len, vocab_size)
    predictions, attention_weights = transformer(encoder_input,
                                                 output,
                                                 False,
                                                 enc_padding_mask,
                                                 combined_mask,
                                                 dec_padding_mask)
    # select the last word from the seq_len dimension
    predictions = predictions[:, -1:, :]  # (batch_size, 1, vocab_size)
    predicted_id = tf.cast(tf.argmax(predictions, axis=-1), tf.int32)
    # return the result if the predicted_id is equal to the end token
    if predicted_id == tokenizer_en.vocab_size + 1:
      return tf.squeeze(output, axis=0), attention_weights
    # concatenate the predicted_id to the output which is given to the decoder
    # as its input.
    output = tf.concat([output, predicted_id], axis=-1)
  return tf.squeeze(output, axis=0), attention_weights
def plot_attention_weights(attention, sentence, result, layer):
  fig = plt.figure(figsize=(16, 8))
  sentence = tokenizer_pt.encode(sentence)
  attention = tf.squeeze(attention[layer], axis=0)
  for head in range(attention.shape[0]):
    ax = fig.add_subplot(2, 4, head+1)
    # plot the attention weights
    ax.matshow(attention[head][:-1, :], cmap='viridis')
    fontdict = {'fontsize': 10}
    ax.set_xticks(range(len(sentence)+2))
    ax.set_yticks(range(len(result)))
    ax.set_ylim(len(result)-1.5, -0.5)
    ax.set_xticklabels(
        ['<start>']+[tokenizer_pt.decode([i]) for i in sentence]+['<end>'],
        fontdict=fontdict, rotation=90)
    ax.set_yticklabels([tokenizer_en.decode([i]) for i in result
                        if i < tokenizer_en.vocab_size],
                       fontdict=fontdict)
    ax.set_xlabel('Head {}'.format(head+1))
  plt.tight_layout()
  plt.show()
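The y-axis of this plot relies on an off-by-one alignment. When evaluate() stops at the end token, its last forward pass ran on all len(result) decoder positions, and row k of the attention is the query at position k, which predicted result[k + 1]. Dropping the final row (`[:-1, :]`, whose prediction was the end token) and dropping the start id from the labels (`i < tokenizer_en.vocab_size`) therefore labels each row with the token it produced. The sketch below checks this on made-up ids. VOCAB_SIZE and the result list are hypothetical, and the alignment shown assumes decoding ended at the end token rather than at MAX_LENGTH.

```python
VOCAB_SIZE = 8                              # hypothetical English vocabulary size
START, END = VOCAB_SIZE, VOCAB_SIZE + 1     # same id pattern evaluate() uses

result = [START, 3, 5, 1]                   # evaluate() output: start id + 3 predictions
# Decoding stopped on END, so the last forward pass saw len(result) positions.
attention_rows = list(range(len(result)))   # row k = query at decoder position k

plotted_rows = attention_rows[:-1]          # attention[head][:-1, :]
y_labels = [i for i in result if i < VOCAB_SIZE]   # drops the start id

# Each plotted row is labelled with the token its position predicted.
for row, label in zip(plotted_rows, y_labels):
    assert label == result[row + 1]
print(list(zip(plotted_rows, y_labels)))    # [(0, 3), (1, 5), (2, 1)]
```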
def translate(sentence, plot=''):
  result, attention_weights = evaluate(sentence)
  predicted_sentence = tokenizer_en.decode([i for i in result
                                            if i < tokenizer_en.vocab_size])
  print('Input: {}'.format(sentence))
  print('Predicted translation: {}'.format(predicted_sentence))
  if plot:
    plot_attention_weights(attention_weights, sentence, result, plot)
translate("este é um problema que temos que resolver.")
print("Real translation: this is a problem we have to solve .")
Input: este é um problema que temos que resolver.
Predicted translation: this is a problem that we have to solve one .
Real translation: this is a problem we have to solve .
translate("os meus vizinhos ouviram sobre esta ideia.")
print("Real translation: and my neighboring homes heard about this idea .")
Input: os meus vizinhos ouviram sobre esta ideia.
Predicted translation: my neighbors heard about this idea .
Real translation: and my neighboring homes heard about this idea .
translate("vou então muito rapidamente partilhar convosco algumas histórias de algumas coisas mágicas que aconteceram.")
print("Real translation: so i 'll just share with you some stories very quickly of some magical things that have happened .")
Input: vou então muito rapidamente partilhar convosco algumas histórias de algumas coisas mágicas que aconteceram.
Predicted translation: so i 'm going to really quickly share with you some magic stories that happened to happen .
Real translation: so i 'll just share with you some stories very quickly of some magical things that have happened .
You can pass different layers and attention blocks of the decoder to the plot parameter.
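The accepted values are the keys of the attention_weights dictionary returned by the decoder. The sketch below enumerates them under two assumptions: the Decoder defined earlier stores its weights as `decoder_layer{i}_block1` (masked self-attention) and `decoder_layer{i}_block2` (attention over the encoder output), which matches the `'decoder_layer4_block2'` used in the example that follows, and num_layers is 4. Because plot_attention_weights labels the x-axis with the Portuguese input tokens, the block2 keys appear to be the ones whose plots line up with those labels.

```python
NUM_LAYERS = 4  # assumption: the num_layers used to build the transformer


def decoder_attention_keys(num_layers):
    """Hypothetical helper listing the attention-weight keys per decoder layer."""
    keys = []
    for layer in range(1, num_layers + 1):
        keys.append("decoder_layer{}_block1".format(layer))  # masked self-attention
        keys.append("decoder_layer{}_block2".format(layer))  # attends to encoder output
    return keys


print(decoder_attention_keys(NUM_LAYERS))
# 8 keys, from 'decoder_layer1_block1' to 'decoder_layer4_block2'
```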
translate("este é o primeiro livro que eu fiz.", plot='decoder_layer4_block2')
print("Real translation: this is the first book i've ever done.")
Input: este é o primeiro livro que eu fiz.
Predicted translation: this is the first book i made .
Real translation: this is the first book i've ever done.
Summary
In this tutorial, you learned about positional encoding, multi-head attention, the importance of masking and how to create a transformer.
Try using a different dataset to train the transformer. You can also create the base transformer or transformer XL by changing the hyperparameters above. You can also use the layers defined here to create BERT and train state-of-the-art models. Furthermore, you can implement beam search to get better predictions.
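As a starting point for the beam search suggestion, here is a minimal, framework-free sketch. It is not tied to the model above. `next_log_probs` is a hypothetical callback; with this tutorial's model, it would wrap a forward pass and take the log-softmax of the last position. The sketch sums log-probabilities without length normalization, so it tends to favour shorter outputs. The toy table is constructed so that greedy decoding (`beam_width=1`) misses the most probable sequence.

```python
import math


def beam_search(next_log_probs, start_id, end_id, beam_width=3, max_length=10):
    """Return (best_sequence, log_prob). The sequence keeps start_id and
    drops end_id. next_log_probs(prefix) -> {token_id: log_prob}."""
    beams = [([start_id], 0.0)]  # live hypotheses: (prefix, summed log-prob)
    finished = []                # hypotheses that emitted end_id
    for _ in range(max_length):
        candidates = []
        for prefix, score in beams:
            for token, logp in next_log_probs(prefix).items():
                candidates.append((prefix + [token], score + logp))
        # Keep only the beam_width best extensions across all hypotheses.
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for prefix, score in candidates[:beam_width]:
            if prefix[-1] == end_id:
                finished.append((prefix[:-1], score))
            else:
                beams.append((prefix, score))
        if not beams:
            break
    finished.extend(beams)  # hypotheses cut off by max_length still compete
    return max(finished, key=lambda c: c[1])


# Hypothetical next-token distributions; 0 is the start id, 9 the end id.
TOY_PROBS = {
    (0,): {1: 0.6, 2: 0.4},
    (0, 1): {3: 0.5, 4: 0.5},
    (0, 1, 3): {9: 1.0},
    (0, 1, 4): {9: 1.0},
    (0, 2): {9: 0.9, 5: 0.1},
    (0, 2, 5): {9: 1.0},
}


def toy_log_probs(prefix):
    return {tok: math.log(p) for tok, p in TOY_PROBS[tuple(prefix)].items()}


greedy = beam_search(toy_log_probs, start_id=0, end_id=9, beam_width=1)
beam = beam_search(toy_log_probs, start_id=0, end_id=9, beam_width=2)
print(greedy[0], round(math.exp(greedy[1]), 2))  # [0, 1, 3] 0.3
print(beam[0], round(math.exp(beam[1]), 2))      # [0, 2] 0.36
```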