Text generation using an RNN with eager execution


This tutorial demonstrates how to generate text using a character-based RNN. We will work with a dataset of Shakespeare's writing from Andrej Karpathy's The Unreasonable Effectiveness of Recurrent Neural Networks. Given a sequence of characters from this data ("Shakespear"), train a model to predict the next character in the sequence ("e"). Longer sequences of text can be generated by calling the model repeatedly.

This tutorial includes runnable code implemented using tf.keras and eager execution. The following is sample output when this tutorial is run with the default settings:

I had thought thou hadst a Roman; for the oracle,
Thus by All bids the man against the word,
Which are so weak of care, by old care done;
Your children were in your holy love,
And the precipitation through the bleeding throne.

Marry, and will, my lord, to weep in such a one were prettiest;
Yet now I was adopted heir
Of the world's lamentable day,
To watch the next way with his father with his face?

The cause why then we are all resolved more sons.

O, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, it is no sin it should be dead,
And love and pale as any will to that word.

But how long have I heard the soul for this world,
And show his hands of life be proved to stand.

I say he look'd on, if I must be content
To stay him from the fatal of our country's bliss.
His lordship pluck'd from this sentence then for prey,
And then let us twain, being the moon,
were she such a case as fills m

While some of the sentences are grammatical, most do not make sense. The model has not learned the meaning of words, but consider:

  • The model is character-based. When training started, the model did not know how to spell an English word, or that words were even a unit of text.

  • The structure of the output resembles a play—blocks of text generally begin with a speaker name, in all capital letters similar to the dataset.

  • As demonstrated below, the model is trained on small batches of text (100 characters each), and is still able to generate a longer sequence of text with coherent structure.


Import TensorFlow and other libraries

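import tensorflow as tf
# Eager execution must be enabled explicitly on TensorFlow 1.x,
# which the APIs used in this tutorial assume
tf.enable_eager_execution()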

import numpy as np
import os
import time

Download the Shakespeare dataset

Change the following line to run this code on your own data.

path_to_file = tf.keras.utils.get_file('shakespeare.txt', 'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt')
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt
1122304/1115394 [==============================] - 0s 0us/step

Read the data

First, take a look at the text.

text = open(path_to_file).read()
# length of text is the number of characters in it
print ('Length of text: {} characters'.format(len(text)))
Length of text: 1115394 characters
# Take a look at the first 1000 characters in text
print(text[:1000])
First Citizen:
Before we proceed any further, hear me speak.

Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.

# The unique characters in the file
vocab = sorted(set(text))
print ('{} unique characters'.format(len(vocab)))
65 unique characters

Process the text

Vectorize the text

Before training, we need to map strings to a numerical representation. Create two lookup tables: one mapping characters to numbers, and another for numbers to characters.

# Creating a mapping from unique characters to indices
char2idx = {u:i for i, u in enumerate(vocab)}
idx2char = np.array(vocab)

text_as_int = np.array([char2idx[c] for c in text])

Now we have an integer representation for each character. Notice that we mapped the characters to indexes from 0 to len(vocab) - 1.

for char,_ in zip(char2idx, range(20)):
    print('{:6s} ---> {:4d}'.format(repr(char), char2idx[char]))
'O'    --->   27
'Q'    --->   29
'I'    --->   21
'.'    --->    8
"'"    --->    5
'N'    --->   26
'J'    --->   22
'v'    --->   60
':'    --->   10
'G'    --->   19
'm'    --->   51
't'    --->   58
'S'    --->   31
'x'    --->   62
'E'    --->   17
'V'    --->   34
'i'    --->   47
'H'    --->   20
'&'    --->    4
'C'    --->   15
# Show how the first 13 characters from the text are mapped to integers
print ('{} ---- characters mapped to int ---- > {}'.format(text[:13], text_as_int[:13]))
First Citizen ---- characters mapped to int ---- > [18 47 56 57 58  1 15 47 58 47 64 43 52]
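
As a quick sanity check, the two lookup tables should be inverses of each other. The following is a minimal sketch on a toy string, using a plain list in place of the NumPy array. The `sample` text, the `toy_` names, and the `encode`/`decode` helpers are illustrative assumptions, not part of the tutorial code.

```python
# Toy stand-in for the Shakespeare text (assumption for illustration).
sample = "First Citizen"

# Same construction as above: the sorted unique characters define the vocabulary.
toy_vocab = sorted(set(sample))
toy_char2idx = {u: i for i, u in enumerate(toy_vocab)}
toy_idx2char = list(toy_vocab)  # index -> character


def encode(s):
    """Map a string to a list of integer ids (hypothetical helper)."""
    return [toy_char2idx[c] for c in s]


def decode(ids):
    """Map a list of integer ids back to a string (hypothetical helper)."""
    return ''.join(toy_idx2char[i] for i in ids)


ids = encode(sample)
print(ids)
print(decode(ids))  # round-trips to the original string
```

Because the vocabulary is built from the text itself, every character is guaranteed to have an id. Text containing unseen characters would raise a KeyError in this sketch.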

The prediction task

Given a character, or a sequence of characters, what is the most probable next character? This is the task we're training the model to perform. The input to the model will be a sequence of characters, and we train the model to predict the output—the following character at each time step.

Because RNNs maintain an internal state that depends on the previously seen elements, the question at each step becomes: given all the characters seen so far, what is the next character?

Create training examples and targets

Divide the text into training examples and targets. Each training example will contain seq_length characters from the text. The corresponding targets contain the same length of text, except shifted one character to the right. For example, say seq_length is 4 and our text is "Hello", create one training example "Hell", and one target "ello".
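
The "Hello" example can be reproduced with plain string slicing. This is only a sketch of the idea: `split_example`, `seq_length_demo` and `text_demo` are hypothetical names chosen for illustration, and the value 4 is the one from the example above rather than the tutorial's setting.

```python
def split_example(chunk):
    """Return (input, target), where target is the input shifted by one character."""
    return chunk[:-1], chunk[1:]


seq_length_demo = 4       # illustrative value from the example above
text_demo = "Hello"       # one chunk of seq_length_demo + 1 characters

input_text, target_text = split_example(text_demo)
print(input_text)   # Hell
print(target_text)  # ello

# At each time step the target is simply the character that follows the input.
for inp_char, tgt_char in zip(input_text, target_text):
    print(repr(inp_char), '->', repr(tgt_char))
```

This is why the text is cut into chunks of seq_length+1 characters: one extra character is needed so that input and target both have seq_length characters.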

Break the text into chunks of seq_length+1:

# The maximum length sentence we want for a single input in characters
seq_length = 100

# Create training examples / targets
chunks = tf.data.Dataset.from_tensor_slices(text_as_int).batch(seq_length+1, drop_remainder=True)

for item in chunks.take(5):
  print(repr(''.join(idx2char[item.numpy()])))
'First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou '
'are all resolved rather to die than to famish?\n\nAll:\nResolved. resolved.\n\nFirst Citizen:\nFirst, you k'
"now Caius Marcius is chief enemy to the people.\n\nAll:\nWe know't, we know't.\n\nFirst Citizen:\nLet us ki"
"ll him, and we'll have corn at our own price.\nIs't a verdict?\n\nAll:\nNo more talking on't; let it be d"
'one: away, away!\n\nSecond Citizen:\nOne word, good citizens.\n\nFirst Citizen:\nWe are accounted poor citi'

Next, create the input and target texts from this chunk:

def split_input_target(chunk):
    input_text = chunk[:-1]
    target_text = chunk[1:]
    return input_text, target_text

dataset = chunks.map(split_input_target)

Let's print the input and target text of the first example:

for input_example, target_example in dataset.take(1):
  print ('Input data: ', repr(''.join(idx2char[input_example.numpy()])))
  print ('Target data:', repr(''.join(idx2char[target_example.numpy()])))
Input data:  'First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou'
Target data: 'irst Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou '

Each index of these vectors is processed as one time step. For the input at time step 0, the model receives the character mapped to the number 18 and tries to predict the character mapped to the number 47. At time step 1, it does the same thing, but the RNN considers the context from the previous step in addition to the current character.

for i, (input_idx, target_idx) in enumerate(zip(input_example[:5], target_example[:5])):
    print("Step {:4d}".format(i))
    print("  input: {} ({:s})".format(input_idx, repr(idx2char[input_idx])))
    print("  expected output: {} ({:s})".format(target_idx, repr(idx2char[target_idx])))
Step    0
  input: 18 ('F')
  expected output: 47 ('i')
Step    1
  input: 47 ('i')
  expected output: 56 ('r')
Step    2
  input: 56 ('r')
  expected output: 57 ('s')
Step    3
  input: 57 ('s')
  expected output: 58 ('t')
Step    4
  input: 58 ('t')
  expected output: 1 (' ')

Creating batches and shuffling them using tf.data

We used tf.data to split the text into sequences. Before feeding this data into the model, we need to shuffle the data and pack it into batches.

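# Batch size
# (assumed value: 64 is consistent with the training log below,
# which shows fewer than 200 batches per epoch)
BATCH_SIZE = 64

# Buffer size to shuffle the dataset
# (TF data is designed to work with possibly infinite sequences,
# so it doesn't attempt to shuffle the entire sequence in memory. Instead,
# it maintains a buffer in which it shuffles elements).
# (assumed value)
BUFFER_SIZE = 10000
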
dataset = dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)
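
To make the buffer comment above more concrete, here is a rough pure-Python sketch of buffer-based shuffling. It illustrates the general idea only and is not tf.data's actual implementation. The function name `buffered_shuffle` and its parameters are assumptions made for this example, and it assumes buffer_size is at least 1.

```python
import random


def buffered_shuffle(items, buffer_size, seed=None):
    """Yield items in a partially shuffled order using a fixed-size buffer."""
    rng = random.Random(seed)
    buffer = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)        # still filling the buffer
            continue
        # Buffer is full: emit a random element and put the new item in its slot.
        idx = rng.randrange(buffer_size)
        yield buffer[idx]
        buffer[idx] = item
    # Input exhausted: emit whatever is left in random order.
    rng.shuffle(buffer)
    yield from buffer


shuffled = list(buffered_shuffle(range(10), buffer_size=3, seed=0))
print(shuffled)
```

With this scheme an element can only move a limited distance toward the front: the first element emitted always comes from the first buffer_size inputs. A buffer at least as large as the dataset gives a full shuffle, while a buffer of size 1 leaves the order unchanged.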

The Model

Implement the model

Use the tf.keras model subclassing API to create the model and change it however we like. Three layers define our model:

  • Embedding layer: a trainable lookup table that maps the number of each character to a vector with embedding_dim dimensions;
  • GRU layer: a type of RNN with layer size = units. (You can also use an LSTM layer here.)
  • Dense layer: the output layer, with vocab_size units (one logit per character in the vocabulary).

class Model(tf.keras.Model):
  def __init__(self, vocab_size, embedding_dim, units):
    super(Model, self).__init__()
    self.units = units

    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)

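    # return_sequences=True gives an output at every time step;
    # stateful=True keeps the hidden state between batches
    if tf.test.is_gpu_available():
      self.gru = tf.keras.layers.CuDNNGRU(self.units,
                                          return_sequences=True,
                                          stateful=True)
    else:
      self.gru = tf.keras.layers.GRU(self.units,
                                     return_sequences=True,
                                     stateful=True)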

    self.fc = tf.keras.layers.Dense(vocab_size)
  def call(self, x):
    embedding = self.embedding(x)
    # output at every time step
    # output shape == (batch_size, seq_length, hidden_size) 
    output = self.gru(embedding)
    # The dense layer will output predictions for every time step (seq_length)
    # output shape after the dense layer == (batch_size, seq_length, vocab_size)
    prediction = self.fc(output)
    # The GRU layer keeps its hidden state between calls, so only the predictions are returned
    return prediction

Instantiate the model, optimizer, and the loss function

# Length of the vocabulary in chars
vocab_size = len(vocab)

# The embedding dimension 
embedding_dim = 256

# Number of RNN units
units = 1024

model = Model(vocab_size, embedding_dim, units)

We'll use the Adam optimizer with default arguments and softmax cross-entropy as the loss function. This loss fits the task because predicting the next character is a classification problem over a discrete set of characters.

# Using adam optimizer with default arguments
optimizer = tf.train.AdamOptimizer()

# Using sparse_softmax_cross_entropy so that we don't have to create one-hot vectors
def loss_function(real, preds):
    return tf.losses.sparse_softmax_cross_entropy(labels=real, logits=preds)
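
For intuition about what this loss measures, the sketch below computes sparse softmax cross-entropy for a single position in plain Python. It is not the TensorFlow implementation, which also reduces over every position in the batch. The logits are made-up values, and the function name is reused here only for readability.

```python
import math


def sparse_softmax_cross_entropy(label, logits):
    """Cross-entropy for one position: -log(softmax(logits)[label])."""
    # Subtracting the max logit keeps exp() numerically stable
    # and does not change the result.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


# Hypothetical logits over a 3-character vocabulary; the true next character has id 2.
print(sparse_softmax_cross_entropy(2, [1.0, 2.0, 3.0]))

# With uniform logits the loss equals log(vocab_size).
uniform_loss = sparse_softmax_cross_entropy(0, [0.0] * 65)
print(uniform_loss, math.log(65))
```

The label is passed as an integer id rather than a one-hot vector, which is what "sparse" refers to. A model that spreads probability evenly over the 65 characters has a loss of log(65), about 4.17, which is close to the loss reported for the first batch in the training log.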

Checkpoints (Object-based saving)

Use tf.train.Checkpoint to save the weights of the model periodically during training.

# Directory where the checkpoints will be saved
checkpoint_dir = './training_checkpoints'
# Name of the checkpoint files
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
# Checkpoint instance
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)

Train the model

Here, use a custom training loop with GradientTape. You can learn more about this approach by reading the eager execution guide.

  • First, reset the hidden state of the model. The GRU layer is stateful, so calling model.reset_states() clears its state, which then starts from zeros with shape == (batch_size, number of rnn units).

  • Next, iterate over the dataset (batch by batch) and calculate the predictions and the hidden states associated with that input.

  • There are a lot of interesting things happening during training:

    • The model starts with a hidden state (initialized with 0), let's call that H0, and the first batch of input, let's call that I0.
    • The model then returns the predictions P1 and H1.
    • For the next batch of input, the model receives I1 and H1.
    • The interesting thing here is that H1 is passed to the model together with I1, so the context carried from batch to batch is contained in the hidden state.
    • Continue doing this until the dataset is exhausted, then start a new epoch and repeat the process.

  • After calculating the predictions, calculate the loss using the loss function defined above. Then calculate the gradients of the loss with respect to the model variables.

  • Finally, take a step in the direction that reduces the loss, with the help of the optimizer and its apply_gradients function.
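
The way the hidden state carries context from one batch to the next can be sketched with a toy scalar recurrent cell. This is an illustration only: it is not a GRU, nothing is trained, and the weights, inputs, and function names (`toy_rnn_step`, `run_batch`) are made up for the example.

```python
import math


def toy_rnn_step(x, h, w_x=0.5, w_h=0.9):
    """One step of a scalar recurrent cell: h_new = tanh(w_x * x + w_h * h)."""
    return math.tanh(w_x * x + w_h * h)


def run_batch(inputs, h):
    """Process one batch (a list of scalars); return the outputs and the final state."""
    outputs = []
    for x in inputs:
        h = toy_rnn_step(x, h)
        outputs.append(h)
    return outputs, h


batches = [[1.0, 0.0, 0.0], [0.0, 0.0, 0.0]]   # I0 and I1 (made-up inputs)

h = 0.0                                  # H0: state initialized with zeros
for i, batch in enumerate(batches):
    preds, h = run_batch(batch, h)       # returns the outputs and the next state
    print('batch', i, 'final state', round(h, 4))

# Resetting the state before I1 instead would discard the context from I0.
_, h_reset = run_batch(batches[1], 0.0)
print('I1 with reset state:', h_reset)
```

The second batch contains only zeros, so any non-zero state after it can only come from the first batch. With the state carried over, the final state is still positive. With a reset, it stays at zero.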

[Figure: diagram of the training process described above]

model.build(tf.TensorShape([BATCH_SIZE, seq_length]))
model.summary()
Layer (type)                 Output Shape              Param #   
embedding (Embedding)        multiple                  16640     
gru (GRU)                    multiple                  3935232   
dense (Dense)                multiple                  66625     
Total params: 4,018,497
Trainable params: 4,018,497
Non-trainable params: 0
# Training step
# 30 epochs matches the training log below
EPOCHS = 30

for epoch in range(EPOCHS):
    start = time.time()
    # resetting the hidden state at the start of every epoch
    model.reset_states()
    for (batch, (inp, target)) in enumerate(dataset):
        with tf.GradientTape() as tape:
            # the stateful GRU layer carries its hidden state over
            # from the previous batch; this is the interesting step
            predictions = model(inp)
            loss = loss_function(target, predictions)
        grads = tape.gradient(loss, model.variables)
        optimizer.apply_gradients(zip(grads, model.variables))

        if batch % 100 == 0:
            print ('Epoch {} Batch {} Loss {:.4f}'.format(epoch+1,
                                                          batch,
                                                          loss))
    # saving (checkpoint) the model every 5 epochs
    if (epoch + 1) % 5 == 0:
        checkpoint.save(file_prefix = checkpoint_prefix)

    print ('Epoch {} Loss {:.4f}'.format(epoch+1, loss))
    print ('Time taken for 1 epoch {} sec\n'.format(time.time() - start))
Epoch 1 Batch 0 Loss 4.1750
Epoch 1 Batch 100 Loss 2.3517
Epoch 1 Loss 2.1793
Time taken for 1 epoch 774.1733202934265 sec

Epoch 2 Batch 0 Loss 2.1502
Epoch 2 Batch 100 Loss 1.9351
Epoch 2 Loss 1.8191
Time taken for 1 epoch 766.1733911037445 sec

Epoch 3 Batch 0 Loss 1.8099
Epoch 3 Batch 100 Loss 1.6994
Epoch 3 Loss 1.6409
Time taken for 1 epoch 762.7873966693878 sec

Epoch 4 Batch 0 Loss 1.6119
Epoch 4 Batch 100 Loss 1.5967
Epoch 4 Loss 1.5099
Time taken for 1 epoch 765.6999187469482 sec

Epoch 5 Batch 0 Loss 1.4498
Epoch 5 Batch 100 Loss 1.4723
Epoch 5 Loss 1.4599
Time taken for 1 epoch 752.7974932193756 sec

Epoch 6 Batch 0 Loss 1.4366
Epoch 6 Batch 100 Loss 1.4062
Epoch 6 Loss 1.3669
Time taken for 1 epoch 754.976331949234 sec

Epoch 7 Batch 0 Loss 1.3398
Epoch 7 Batch 100 Loss 1.4046
Epoch 7 Loss 1.3731
Time taken for 1 epoch 758.0927352905273 sec

Epoch 8 Batch 0 Loss 1.2517
Epoch 8 Batch 100 Loss 1.3586
Epoch 8 Loss 1.3172
Time taken for 1 epoch 754.760913848877 sec

Epoch 9 Batch 0 Loss 1.2488
Epoch 9 Batch 100 Loss 1.3058
Epoch 9 Loss 1.2892
Time taken for 1 epoch 777.5504937171936 sec

Epoch 10 Batch 0 Loss 1.1952
Epoch 10 Batch 100 Loss 1.2515
Epoch 10 Loss 1.2856
Time taken for 1 epoch 788.7329194545746 sec

Epoch 11 Batch 0 Loss 1.1727
Epoch 11 Batch 100 Loss 1.2412
Epoch 11 Loss 1.2362
Time taken for 1 epoch 784.0863444805145 sec

Epoch 12 Batch 0 Loss 1.1200
Epoch 12 Batch 100 Loss 1.1896
Epoch 12 Loss 1.2063
Time taken for 1 epoch 778.597975730896 sec

Epoch 13 Batch 0 Loss 1.1069
Epoch 13 Batch 100 Loss 1.1846
Epoch 13 Loss 1.1936
Time taken for 1 epoch 766.9924449920654 sec

Epoch 14 Batch 0 Loss 1.0375
Epoch 14 Batch 100 Loss 1.1656
Epoch 14 Loss 1.1369
Time taken for 1 epoch 765.9739792346954 sec

Epoch 15 Batch 0 Loss 1.0096
Epoch 15 Batch 100 Loss 1.1050
Epoch 15 Loss 1.1035
Time taken for 1 epoch 759.5643520355225 sec

Epoch 16 Batch 0 Loss 0.9848
Epoch 16 Batch 100 Loss 1.0595
Epoch 16 Loss 1.0582
Time taken for 1 epoch 756.1685025691986 sec

Epoch 17 Batch 0 Loss 0.9424
Epoch 17 Batch 100 Loss 0.9984
Epoch 17 Loss 1.0181
Time taken for 1 epoch 780.668380022049 sec

Epoch 18 Batch 0 Loss 0.8714
Epoch 18 Batch 100 Loss 0.9500
Epoch 18 Loss 1.0319
Time taken for 1 epoch 775.13188123703 sec

Epoch 19 Batch 0 Loss 0.8481
Epoch 19 Batch 100 Loss 0.9382
Epoch 19 Loss 0.9522
Time taken for 1 epoch 764.6832077503204 sec

Epoch 20 Batch 0 Loss 0.8036
Epoch 20 Batch 100 Loss 0.9106
Epoch 20 Loss 0.9336
Time taken for 1 epoch 761.5604221820831 sec

Epoch 21 Batch 0 Loss 0.7523
Epoch 21 Batch 100 Loss 0.8460
Epoch 21 Loss 0.9178
Time taken for 1 epoch 913.5510909557343 sec

Epoch 22 Batch 0 Loss 0.7142
Epoch 22 Batch 100 Loss 0.8129
Epoch 22 Loss 0.8635
Time taken for 1 epoch 792.1691782474518 sec

Epoch 23 Batch 0 Loss 0.6787
Epoch 23 Batch 100 Loss 0.7924
Epoch 23 Loss 0.8688
Time taken for 1 epoch 1799.7459905147552 sec

Epoch 24 Batch 0 Loss 0.6637
Epoch 24 Batch 100 Loss 0.7601
Epoch 24 Loss 0.8224
Time taken for 1 epoch 1945.9425377845764 sec

Epoch 25 Batch 0 Loss 0.6215
Epoch 25 Batch 100 Loss 0.7543
Epoch 25 Loss 0.7729
Time taken for 1 epoch 753.2542717456818 sec

Epoch 26 Batch 0 Loss 0.6098
Epoch 26 Batch 100 Loss 0.7236
Epoch 26 Loss 0.7626
Time taken for 1 epoch 750.8798377513885 sec

Epoch 27 Batch 0 Loss 0.5824
Epoch 27 Batch 100 Loss 0.6841
Epoch 27 Loss 0.7852
Time taken for 1 epoch 754.4956197738647 sec

Epoch 28 Batch 0 Loss 0.5737
Epoch 28 Batch 100 Loss 0.6675
Epoch 28 Loss 0.7347
Time taken for 1 epoch 1589.552562713623 sec

Epoch 29 Batch 0 Loss 0.5481
Epoch 29 Batch 100 Loss 0.7012
Epoch 29 Loss 0.7248
Time taken for 1 epoch 2138.2978603839874 sec

Epoch 30 Batch 0 Loss 0.5630
Epoch 30 Batch 100 Loss 0.6865
Epoch 30 Loss 0.7099
Time taken for 1 epoch 1378.361043214798 sec

checkpoint.save(file_prefix = checkpoint_prefix)

Restore the latest checkpoint

Once built, the model only accepts a fixed batch size. To use the same weights with a different batch size, we need to rebuild the model and restore the weights from the checkpoint.

!ls {checkpoint_dir}
checkpoint          ckpt-4.index
ckpt-1.data-00000-of-00001  ckpt-5.data-00000-of-00001
ckpt-1.index            ckpt-5.index
ckpt-2.data-00000-of-00001  ckpt-6.data-00000-of-00001
ckpt-2.index            ckpt-6.index
ckpt-3.data-00000-of-00001  ckpt-7.data-00000-of-00001
ckpt-3.index            ckpt-7.index
model = Model(vocab_size, embedding_dim, units)

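checkpoint = tf.train.Checkpoint(model=model)
# Load the weights saved by the most recent checkpoint
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))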

model.build(tf.TensorShape([1, None]))

Generate text using our trained model

The following code block generates the text:

  • Start by choosing a start string, initializing the hidden state and setting the number of characters to generate.

  • Get the predictions using the start string and the hidden state.

  • Then, sample from a multinomial distribution over the predictions to get the index of the predicted character, and use this predicted character as the next input to the model.

  • The hidden state returned by the model is fed back into the model so that it now has more context, rather than only one character. After predicting the next character, the modified hidden state is again fed back into the model, so each prediction draws on the context of the previously predicted characters.

Looking at the generated text, you'll see the model knows when to capitalize, makes paragraphs, and imitates a Shakespeare-like writing style.

# Evaluation step (generating text using the learned model)

# Number of characters to generate
num_generate = 1000

# You can change the start string to experiment
start_string = 'Q'

# Converting our start string to numbers (vectorizing) 
input_eval = [char2idx[s] for s in start_string]
input_eval = tf.expand_dims(input_eval, 0)

# Empty list to store our results
text_generated = []

# Low temperatures result in more predictable text.
# Higher temperatures result in more surprising text.
# Experiment to find the best setting.
temperature = 1.0

# Here batch size == 1
for i in range(num_generate):
    predictions = model(input_eval)
    # remove the batch dimension
    predictions = tf.squeeze(predictions, 0)

    # using a multinomial distribution to sample the character returned by the model
    predictions = predictions / temperature
    predicted_id = tf.multinomial(predictions, num_samples=1)[-1,0].numpy()
    # We pass the predicted character as the next input to the model;
    # the stateful GRU layer keeps the previous hidden state
    input_eval = tf.expand_dims([predicted_id], 0)
    # Collect the predicted character so it can be printed at the end
    text_generated.append(idx2char[predicted_id])

print (start_string + ''.join(text_generated))
Which one rosts and rotten, courtesy,
And satisfy froot madamity of Verona's; n your grace.

Some servenes of a kinderness of his proce.
This tiggh twelve long ago.

O discharger, and it join'd
With sovereigning Richmond, and am I lend.

Ay, if I know the better doth he start all talk of her observe,
And prefer you for my mishes, call them back.

Gaunt'st as sweet as a cockle,
Repair to the name unto the world,
And see horrors; seain her your honour!

Give me thy hand.
The wit is but a head for 't.

I do love to her hence,
And that's mine own,
That, if the least of Maria?
May of Camillo tackless cater?

Well, then, I thank you.

My best and like a press'd truer brother!'
With six on blood to thee,
And he's dead, here are they despair:
See thee utto rise and pluck'd thee by the
hangman. My entertainment,
In pestoly right:
Thy fortune I have many fear'd of mortals
How sall hope is in duty the poor I'll 

You can also experiment with a different start string, add another RNN layer to improve the model's accuracy, or adjust the temperature parameter to generate more or less random predictions.
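
To see what the temperature parameter does, the sketch below applies it to a small set of made-up logits in plain Python. It mirrors the idea of dividing the predictions by the temperature before sampling, but it does not use the tutorial's model or TensorFlow. The helper names `softmax_with_temperature` and `sample_id` are assumptions made for this example.

```python
import math
import random


def softmax_with_temperature(logits, temperature=1.0):
    """Convert logits to probabilities after dividing by the temperature."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)                                # for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_id(logits, temperature=1.0, rng=random):
    """Draw one index from the categorical distribution defined by the logits."""
    probs = softmax_with_temperature(logits, temperature)
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]


logits = [2.0, 1.0, 0.1]                           # made-up scores for 3 characters
for t in (0.5, 1.0, 2.0):
    probs = softmax_with_temperature(logits, t)
    print(t, [round(p, 3) for p in probs])

print(sample_id(logits, temperature=1.0, rng=random.Random(0)))
```

Lower temperatures concentrate probability on the highest-scoring character, so the samples become more predictable. Higher temperatures flatten the distribution, so less likely characters are drawn more often.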