
Text generation with an RNN


This tutorial demonstrates how to generate text using a character-based RNN. You will work with a dataset of Shakespeare's writing from Andrej Karpathy's The Unreasonable Effectiveness of Recurrent Neural Networks. Given a sequence of characters from this data ("Shakespear"), you will train a model to predict the next character in the sequence ("e"). Longer sequences of text can be generated by calling the model repeatedly.

This tutorial includes runnable code implemented using tf.keras and eager execution. The following is sample output from the model in this tutorial after it was trained for 30 epochs and started with the prompt "Q":

QUEENE:
I had thought thou hadst a Roman; for the oracle,
Thus by All bids the man against the word,
Which are so weak of care, by old care done;
Your children were in your holy love,
And the precipitation through the bleeding throne.

BISHOP OF ELY:
Marry, and will, my lord, to weep in such a one were prettiest;
Yet now I was adopted heir
Of the world's lamentable day,
To watch the next way with his father with his face?

ESCALUS:
The cause why then we are all resolved more sons.

VOLUMNIA:
O, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, it is no sin it should be dead,
And love and pale as any will to that word.

QUEEN ELIZABETH:
But how long have I heard the soul for this world,
And show his hands of life be proved to stand.

PETRUCHIO:
I say he look'd on, if I must be content
To stay him from the fatal of our country's bliss.
His lordship pluck'd from this sentence then for prey,
And then let us twain, being the moon,
were she such a case as fills m

While some of the sentences are grammatical, most do not make sense. The model has not learned the meaning of words, but consider:

  • The model is character-based. When training started, the model did not know how to spell an English word, or that words were even a unit of text.

  • The structure of the output resembles a play: blocks of text generally begin with a speaker name, in all capital letters, as in the dataset.

  • As demonstrated below, the model is trained on small batches of text (100 characters each), and is still able to generate a longer sequence of text with coherent structure.

Setup

Import TensorFlow and other libraries

import tensorflow as tf

import numpy as np
import os
import time

Download the Shakespeare dataset

Change the following line to run this code on your own data.

path_to_file = tf.keras.utils.get_file('shakespeare.txt', 'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt')
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt
1122304/1115394 [==============================] - 0s 0us/step
1130496/1115394 [==============================] - 0s 0us/step

Read the data

First, take a look at the text:

# Read, then decode for py2 compat.
text = open(path_to_file, 'rb').read().decode(encoding='utf-8')
# length of text is the number of characters in it
print(f'Length of text: {len(text)} characters')
Length of text: 1115394 characters
# Take a look at the first 250 characters in text
print(text[:250])
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.
# The unique characters in the file
vocab = sorted(set(text))
print(f'{len(vocab)} unique characters')
65 unique characters

Process the text

Vectorize the text

Before training, you need to convert the strings to a numerical representation.

The tf.keras.layers.StringLookup layer can convert each character into a numeric ID. It just needs the text to be split into tokens first.

example_texts = ['abcdefg', 'xyz']

chars = tf.strings.unicode_split(example_texts, input_encoding='UTF-8')
chars
<tf.RaggedTensor [[b'a', b'b', b'c', b'd', b'e', b'f', b'g'], [b'x', b'y', b'z']]>

Now create the tf.keras.layers.StringLookup layer:

ids_from_chars = tf.keras.layers.StringLookup(
    vocabulary=list(vocab), mask_token=None)

It converts from tokens to character IDs:

ids = ids_from_chars(chars)
ids
<tf.RaggedTensor [[40, 41, 42, 43, 44, 45, 46], [63, 64, 65]]>

Since the goal of this tutorial is to generate text, it will also be important to invert this representation and recover human-readable strings from it. For this you can use tf.keras.layers.StringLookup(..., invert=True).

chars_from_ids = tf.keras.layers.StringLookup(
    vocabulary=ids_from_chars.get_vocabulary(), invert=True, mask_token=None)

This layer recovers the characters from the vectors of IDs, and returns them as a tf.RaggedTensor of characters:

chars = chars_from_ids(ids)
chars
<tf.RaggedTensor [[b'a', b'b', b'c', b'd', b'e', b'f', b'g'], [b'x', b'y', b'z']]>

You can use tf.strings.reduce_join to join the characters back into strings.

tf.strings.reduce_join(chars, axis=-1).numpy()
array([b'abcdefg', b'xyz'], dtype=object)
def text_from_ids(ids):
  return tf.strings.reduce_join(chars_from_ids(ids), axis=-1)
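To see the bookkeeping without TensorFlow, here is a minimal plain-Python sketch of the same round trip. It is an illustration, not how StringLookup is implemented. The helper names (`build_lookup`, `encode`, `decode`) are hypothetical. Following the behavior of `mask_token=None` above, the sketch reserves ID 0 for the out-of-vocabulary token `[UNK]` and numbers the sorted vocabulary from 1.

```python
UNK = "[UNK]"

def build_lookup(text):
    """Return (char_to_id, id_to_char), with ID 0 reserved for UNK."""
    vocab = sorted(set(text))
    id_to_char = [UNK] + vocab
    char_to_id = {ch: i for i, ch in enumerate(id_to_char)}
    return char_to_id, id_to_char

def encode(s, char_to_id):
    # Characters missing from the vocabulary map to 0, the "[UNK]" ID.
    return [char_to_id.get(ch, 0) for ch in s]

def decode(ids, id_to_char):
    # Invert the mapping and join the characters back into one string.
    return "".join(id_to_char[i] for i in ids)

char_to_id, id_to_char = build_lookup("abcdefgxyz")
ids = encode("abcQ", char_to_id)
print(ids)                      # [1, 2, 3, 0]
print(decode(ids, id_to_char))  # abc[UNK]
```

With this numbering, the 65-character Shakespeare vocabulary produces a 66-entry lookup table. That is consistent with the IDs printed above, where `a` maps to 40 and `z` maps to 65.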

The prediction task

Given a character, or a sequence of characters, what is the most probable next character? This is the task you're training the model to perform. The input to the model will be a sequence of characters, and you train the model to predict the output: the following character at each time step.

Because an RNN maintains an internal state that depends on the elements it has already seen, the question at each step is: given all the characters seen up to this moment, what is the next character?

Create training examples and targets

Next divide the text into example sequences. Each input sequence will contain seq_length characters from the text.

For each input sequence, the corresponding targets contain the same length of text, except shifted one character to the right.

So break the text into chunks of seq_length+1. For example, say seq_length is 4 and our text is "Hello". The input sequence would be "Hell", and the target sequence "ello".
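You can check the chunking rule on the "Hello" example with a short plain-Python sketch. `make_examples` is a hypothetical helper. It is meant to mimic what `batch(seq_length+1, drop_remainder=True)` and the input/target split do later in this section: cut the text into non-overlapping chunks, discard any incomplete final chunk, and split each chunk into an input and a target.

```python
def make_examples(text, seq_length):
    """Split text into (input, target) pairs of length seq_length."""
    chunk = seq_length + 1
    examples = []
    # Walk the text in non-overlapping chunks of seq_length+1 characters;
    # an incomplete chunk at the end is dropped.
    for start in range(0, len(text) - chunk + 1, chunk):
        seq = text[start:start + chunk]
        # The input drops the last character; the target drops the first.
        examples.append((seq[:-1], seq[1:]))
    return examples

print(make_examples("Hello", 4))        # [('Hell', 'ello')]
print(make_examples("Hello world", 4))  # [('Hell', 'ello'), (' wor', 'worl')]
```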

To do this in TensorFlow, first use the tf.data.Dataset.from_tensor_slices function to convert the text vector into a stream of character indices.

all_ids = ids_from_chars(tf.strings.unicode_split(text, 'UTF-8'))
all_ids
<tf.Tensor: shape=(1115394,), dtype=int64, numpy=array([19, 48, 57, ..., 46,  9,  1])>
ids_dataset = tf.data.Dataset.from_tensor_slices(all_ids)
for ids in ids_dataset.take(10):
    print(chars_from_ids(ids).numpy().decode('utf-8'))
F
i
r
s
t
 
C
i
t
i
seq_length = 100
examples_per_epoch = len(text)//(seq_length+1)
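As a sanity check on `examples_per_epoch`, you can work the arithmetic out by hand. The sketch below uses the text length printed earlier and the batch size of 64 that the tutorial sets further down. Because `drop_remainder=True` is used for both sequences and batches, the leftover characters and the final partial batch are discarded.

```python
text_length = 1115394  # from the "Length of text" output above
seq_length = 100
batch_size = 64        # BATCH_SIZE, set later in the tutorial

# Each example consumes seq_length + 1 characters (input plus one shift).
examples_per_epoch = text_length // (seq_length + 1)
leftover_chars = text_length % (seq_length + 1)
# drop_remainder=True also discards the last partial batch.
steps_per_epoch = examples_per_epoch // batch_size

print(examples_per_epoch, leftover_chars, steps_per_epoch)  # 11043 51 172
```

The 172 steps per epoch match the `172/172` progress bars in the training logs.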

The batch method lets you easily convert these individual characters to sequences of the desired size.

sequences = ids_dataset.batch(seq_length+1, drop_remainder=True)

for seq in sequences.take(1):
  print(chars_from_ids(seq))
tf.Tensor(
[b'F' b'i' b'r' b's' b't' b' ' b'C' b'i' b't' b'i' b'z' b'e' b'n' b':'
 b'\n' b'B' b'e' b'f' b'o' b'r' b'e' b' ' b'w' b'e' b' ' b'p' b'r' b'o'
 b'c' b'e' b'e' b'd' b' ' b'a' b'n' b'y' b' ' b'f' b'u' b'r' b't' b'h'
 b'e' b'r' b',' b' ' b'h' b'e' b'a' b'r' b' ' b'm' b'e' b' ' b's' b'p'
 b'e' b'a' b'k' b'.' b'\n' b'\n' b'A' b'l' b'l' b':' b'\n' b'S' b'p' b'e'
 b'a' b'k' b',' b' ' b's' b'p' b'e' b'a' b'k' b'.' b'\n' b'\n' b'F' b'i'
 b'r' b's' b't' b' ' b'C' b'i' b't' b'i' b'z' b'e' b'n' b':' b'\n' b'Y'
 b'o' b'u' b' '], shape=(101,), dtype=string)

It's easier to see what this is doing if you join the tokens back into strings:

for seq in sequences.take(5):
  print(text_from_ids(seq).numpy())
b'First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou '
b'are all resolved rather to die than to famish?\n\nAll:\nResolved. resolved.\n\nFirst Citizen:\nFirst, you k'
b"now Caius Marcius is chief enemy to the people.\n\nAll:\nWe know't, we know't.\n\nFirst Citizen:\nLet us ki"
b"ll him, and we'll have corn at our own price.\nIs't a verdict?\n\nAll:\nNo more talking on't; let it be d"
b'one: away, away!\n\nSecond Citizen:\nOne word, good citizens.\n\nFirst Citizen:\nWe are accounted poor citi'

For training you'll need a dataset of (input, label) pairs, where both input and label are sequences. At each time step the input is the current character and the label is the next character.

Here's a function that takes a sequence as input, duplicates it, and shifts it to align the input and label for each timestep:

def split_input_target(sequence):
    input_text = sequence[:-1]
    target_text = sequence[1:]
    return input_text, target_text
split_input_target(list("Tensorflow"))
(['T', 'e', 'n', 's', 'o', 'r', 'f', 'l', 'o'],
 ['e', 'n', 's', 'o', 'r', 'f', 'l', 'o', 'w'])
dataset = sequences.map(split_input_target)
for input_example, target_example in dataset.take(1):
    print("Input :", text_from_ids(input_example).numpy())
    print("Target:", text_from_ids(target_example).numpy())
Input : b'First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou'
Target: b'irst Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou '

Create training batches

You used tf.data to split the text into manageable sequences. But before feeding this data into the model, you need to shuffle the data and pack it into batches.

# Batch size
BATCH_SIZE = 64

# Buffer size to shuffle the dataset
# (TF data is designed to work with possibly infinite sequences,
# so it doesn't attempt to shuffle the entire sequence in memory. Instead,
# it maintains a buffer in which it shuffles elements).
BUFFER_SIZE = 10000

dataset = (
    dataset
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(tf.data.experimental.AUTOTUNE))

dataset
<PrefetchDataset shapes: ((64, 100), (64, 100)), types: (tf.int64, tf.int64)>
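The comment about the shuffle buffer describes a bounded-buffer shuffle. The plain-Python sketch below illustrates that idea. It is not tf.data's implementation, and `buffered_shuffle` is a hypothetical name. At most `buffer_size` elements are held in memory at a time. Each new element replaces a randomly chosen buffered element, and the replaced element is emitted.

```python
import random

def buffered_shuffle(iterable, buffer_size, seed=None):
    """Yield the items of iterable in a partly shuffled order."""
    if buffer_size < 1:
        raise ValueError("buffer_size must be at least 1")
    rng = random.Random(seed)
    buffer = []
    for item in iterable:
        if len(buffer) < buffer_size:
            buffer.append(item)  # fill the buffer first
            continue
        # Emit a random buffered element and store the new one in its slot.
        i = rng.randrange(buffer_size)
        yield buffer[i]
        buffer[i] = item
    # Once the input is exhausted, drain the buffer in random order.
    rng.shuffle(buffer)
    yield from buffer

out = list(buffered_shuffle(range(20), buffer_size=5, seed=0))
print(sorted(out) == list(range(20)))  # True: same elements, new order
```

With `buffer_size=1` the sketch returns the input in its original order. This shows why a buffer that is small relative to the dataset mixes poorly. Here the dataset has about 11,000 sequences, so BUFFER_SIZE = 10000 covers most of it.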

Build the model

This section defines the model as a keras.Model subclass (for details, see Making new Layers and Models via subclassing).

This model has three layers:

  • tf.keras.layers.Embedding: The input layer. A trainable lookup table that maps each character ID to a vector with embedding_dim dimensions.
  • tf.keras.layers.GRU: A type of RNN with size units=rnn_units. (You can also use an LSTM layer here.)
  • tf.keras.layers.Dense: The output layer, with vocab_size outputs. It outputs one logit for each character in the vocabulary. These logits are the model's unnormalized log-likelihoods for each character.
# Length of the vocabulary in chars
vocab_size = len(vocab)

# The embedding dimension
embedding_dim = 256

# Number of RNN units
rnn_units = 1024
class MyModel(tf.keras.Model):
  def __init__(self, vocab_size, embedding_dim, rnn_units):
    super().__init__()
    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
    self.gru = tf.keras.layers.GRU(rnn_units,
                                   return_sequences=True,
                                   return_state=True)
    self.dense = tf.keras.layers.Dense(vocab_size)

  def call(self, inputs, states=None, return_state=False, training=False):
    x = inputs
    x = self.embedding(x, training=training)
    if states is None:
      states = self.gru.get_initial_state(x)
    x, states = self.gru(x, initial_state=states, training=training)
    x = self.dense(x, training=training)

    if return_state:
      return x, states
    else:
      return x
model = MyModel(
    # Be sure the vocabulary size matches the `StringLookup` layers.
    vocab_size=len(ids_from_chars.get_vocabulary()),
    embedding_dim=embedding_dim,
    rnn_units=rnn_units)

For each character the model looks up the embedding, runs the GRU one timestep with the embedding as input, and applies the dense layer to generate logits predicting the log-likelihood of the next character:

A drawing of the data passing through the model

Try the model

Now run the model to see that it behaves as expected.

First check the shape of the output:

for input_example_batch, target_example_batch in dataset.take(1):
    example_batch_predictions = model(input_example_batch)
    print(example_batch_predictions.shape, "# (batch_size, sequence_length, vocab_size)")
(64, 100, 66) # (batch_size, sequence_length, vocab_size)

In the above example the sequence length of the input is 100, but the model can be run on inputs of any length:

model.summary()
Model: "my_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 embedding (Embedding)       multiple                  16896     
                                                                 
 gru (GRU)                   multiple                  3938304   
                                                                 
 dense (Dense)               multiple                  67650     
                                                                 
=================================================================
Total params: 4,022,850
Trainable params: 4,022,850
Non-trainable params: 0
_________________________________________________________________
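You can reproduce the parameter counts in the summary by hand. The sketch below assumes the Keras GRU default `reset_after=True`. Under that setting, each of the three gate blocks (update, reset, and candidate) has an input kernel, a recurrent kernel, and two bias vectors.

```python
vocab_size = 66      # len(ids_from_chars.get_vocabulary()), including "[UNK]"
embedding_dim = 256
rnn_units = 1024

# Embedding: one embedding_dim-sized vector per vocabulary entry.
embedding_params = vocab_size * embedding_dim
# GRU (assumes reset_after=True): 3 gate blocks, each with an input kernel,
# a recurrent kernel, and separate input and recurrent bias vectors.
gru_params = 3 * (embedding_dim * rnn_units
                  + rnn_units * rnn_units
                  + 2 * rnn_units)
# Dense: a weight matrix plus one bias per output logit.
dense_params = rnn_units * vocab_size + vocab_size

print(embedding_params, gru_params, dense_params)   # 16896 3938304 67650
print(embedding_params + gru_params + dense_params)  # 4022850
```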

To get actual predictions from the model, you need to sample from the output distribution to get character indices. This distribution is defined by the logits over the character vocabulary.

Try it for the first example in the batch:

sampled_indices = tf.random.categorical(example_batch_predictions[0], num_samples=1)
sampled_indices = tf.squeeze(sampled_indices, axis=-1).numpy()

This gives you, at each timestep, a prediction of the next character index:

sampled_indices
array([ 6, 52, 24, 12, 28, 62,  4, 36, 29,  3, 50, 14, 49, 26, 58, 40,  0,
       36,  5, 31, 55, 60, 47, 53, 64, 24, 46,  2,  4,  7, 40, 12, 40, 39,
       57, 21, 46, 38, 44, 26, 45, 53, 25, 39,  3, 24, 59, 44, 26, 60, 45,
       24,  6, 61,  8,  9,  3, 19, 25, 15, 19, 38, 24, 16, 24, 64, 21,  0,
       18, 65, 52,  7,  9, 49, 17, 37, 23, 62, 33, 43, 19, 24, 25, 37,  9,
       38, 56, 33,  8, 46,  9, 18, 21, 11, 11, 27,  7, 51, 14, 50])

Decode these to see the text predicted by this untrained model:

print("Input:\n", text_from_ids(input_example_batch[0]).numpy())
print()
print("Next Char Predictions:\n", text_from_ids(sampled_indices).numpy())
Input:
 b" kinsman! O sweet Juliet,\nThy beauty hath made me effeminate\nAnd in my temper soften'd valour's stee"

Next Char Predictions:
 b"'mK;Ow\\(WP!kAjMsa[UNK]W&RpuhnyKg \\),a;aZrHgYeMfnLZ!KteMufK'v-.!FLBFYKCKyH[UNK]Ezm,.jDXJwTdFKLX.YqT-g.EH::N,lAk"
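You can also sketch the sampling step without TensorFlow. `tf.random.categorical` treats each row of logits as unnormalized log-probabilities. Sampling from it is therefore equivalent to sampling from the softmax of those logits. The plain-Python sketch below uses hypothetical helpers (`softmax`, `sample_categorical`). It draws many samples from three made-up logits and tallies the results.

```python
import math
import random

def softmax(logits):
    # Subtract the max before exponentiating for numerical stability.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sample_categorical(logits, rng):
    # Draw one index i with probability softmax(logits)[i].
    probs = softmax(logits)
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]

rng = random.Random(42)
logits = [2.0, 1.0, 0.1]
print(softmax(logits))
counts = [0, 0, 0]
for _ in range(10000):
    counts[sample_categorical(logits, rng)] += 1
print(counts)  # roughly proportional to the probabilities above
```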

Train the model

At this point the problem can be treated as a standard classification problem: given the previous RNN state and the input at this time step, predict the class of the next character.

Attach an optimizer, and a loss function

The standard tf.keras.losses.sparse_categorical_crossentropy loss function works in this case because it is applied across the last dimension of the predictions.

Because your model returns logits, you need to set the from_logits flag.

loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True)
example_batch_loss = loss(target_example_batch, example_batch_predictions)
mean_loss = example_batch_loss.numpy().mean()
print("Prediction shape: ", example_batch_predictions.shape, " # (batch_size, sequence_length, vocab_size)")
print("Mean loss:        ", mean_loss)
Prediction shape:  (64, 100, 66)  # (batch_size, sequence_length, vocab_size)
Mean loss:         4.190459

A newly initialized model shouldn't be too sure of itself: the output logits should all have similar magnitudes. To confirm this, you can check that the exponential of the mean loss is approximately equal to the vocabulary size. A much higher loss means the model is sure of its wrong answers and is badly initialized:

tf.exp(mean_loss).numpy()
66.053085
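This rule of thumb follows directly from the loss. If all logits are equal, the softmax is uniform over the V vocabulary entries. The cross-entropy for any label is then ln V, and its exponential is V. The plain-Python sketch below uses a hypothetical helper, `sparse_ce_from_logits`. It computes the same per-example quantity as `SparseCategoricalCrossentropy(from_logits=True)`.

```python
import math

def sparse_ce_from_logits(logits, label):
    # -log(softmax(logits)[label]), computed with the log-sum-exp trick.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[label]

vocab_size = 66
uniform_logits = [0.0] * vocab_size  # a model with no preference at all
loss = sparse_ce_from_logits(uniform_logits, label=7)
print(loss, math.log(vocab_size))    # both about 4.1897
print(math.exp(loss))                # about 66, the vocabulary size
```

Since ln 66 is about 4.1897, the measured mean loss of 4.190459 is very close to the score of a completely uncommitted model.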

Configure the training procedure using the tf.keras.Model.compile method. Use tf.keras.optimizers.Adam with default arguments and the loss function.

model.compile(optimizer='adam', loss=loss)

Configure checkpoints

Use a tf.keras.callbacks.ModelCheckpoint to ensure that checkpoints are saved during training:

# Directory where the checkpoints will be saved
checkpoint_dir = './training_checkpoints'
# Name of the checkpoint files
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")

checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_prefix,
    save_weights_only=True)

Execute the training

To keep training time reasonable, use 20 epochs to train the model. In Colab, set the runtime to GPU for faster training.

EPOCHS = 20
history = model.fit(dataset, epochs=EPOCHS, callbacks=[checkpoint_callback])
Epoch 1/20
172/172 [==============================] - 7s 25ms/step - loss: 2.6892
Epoch 2/20
172/172 [==============================] - 5s 25ms/step - loss: 1.9741
Epoch 3/20
172/172 [==============================] - 5s 24ms/step - loss: 1.6954
Epoch 4/20
172/172 [==============================] - 5s 25ms/step - loss: 1.5369
Epoch 5/20
172/172 [==============================] - 5s 25ms/step - loss: 1.4420
Epoch 6/20
172/172 [==============================] - 5s 25ms/step - loss: 1.3770
Epoch 7/20
172/172 [==============================] - 6s 24ms/step - loss: 1.3244
Epoch 8/20
172/172 [==============================] - 5s 24ms/step - loss: 1.2801
Epoch 9/20
172/172 [==============================] - 5s 24ms/step - loss: 1.2400
Epoch 10/20
172/172 [==============================] - 5s 24ms/step - loss: 1.2015
Epoch 11/20
172/172 [==============================] - 5s 24ms/step - loss: 1.1627
Epoch 12/20
172/172 [==============================] - 6s 25ms/step - loss: 1.1211
Epoch 13/20
172/172 [==============================] - 5s 25ms/step - loss: 1.0782
Epoch 14/20
172/172 [==============================] - 5s 25ms/step - loss: 1.0337
Epoch 15/20
172/172 [==============================] - 5s 25ms/step - loss: 0.9863
Epoch 16/20
172/172 [==============================] - 5s 24ms/step - loss: 0.9360
Epoch 17/20
172/172 [==============================] - 5s 24ms/step - loss: 0.8843
Epoch 18/20
172/172 [==============================] - 5s 24ms/step - loss: 0.8319
Epoch 19/20
172/172 [==============================] - 5s 24ms/step - loss: 0.7807
Epoch 20/20
172/172 [==============================] - 5s 24ms/step - loss: 0.7334

Generate text

The simplest way to generate text with this model is to run it in a loop and keep track of the model's internal state as you execute it.

To generate text the model's output is fed back to the input

Each time you call the model you pass in some text and an internal state. The model returns a prediction for the next character and its new state. Pass the prediction and state back in to continue generating text.

The following makes a single-step prediction:

class OneStep(tf.keras.Model):
  def __init__(self, model, chars_from_ids, ids_from_chars, temperature=1.0):
    super().__init__()
    self.temperature = temperature
    self.model = model
    self.chars_from_ids = chars_from_ids
    self.ids_from_chars = ids_from_chars

    # Create a mask to prevent "[UNK]" from being generated.
    skip_ids = self.ids_from_chars(['[UNK]'])[:, None]
    sparse_mask = tf.SparseTensor(
        # Put a -inf at each bad index.
        values=[-float('inf')]*len(skip_ids),
        indices=skip_ids,
        # Match the shape to the vocabulary
        dense_shape=[len(ids_from_chars.get_vocabulary())])
    self.prediction_mask = tf.sparse.to_dense(sparse_mask)

  @tf.function
  def generate_one_step(self, inputs, states=None):
    # Convert strings to token IDs.
    input_chars = tf.strings.unicode_split(inputs, 'UTF-8')
    input_ids = self.ids_from_chars(input_chars).to_tensor()

    # Run the model.
    # predicted_logits.shape is [batch, char, next_char_logits]
    predicted_logits, states = self.model(inputs=input_ids, states=states,
                                          return_state=True)
    # Only use the last prediction.
    predicted_logits = predicted_logits[:, -1, :]
    predicted_logits = predicted_logits/self.temperature
    # Apply the prediction mask: prevent "[UNK]" from being generated.
    predicted_logits = predicted_logits + self.prediction_mask

    # Sample the output logits to generate token IDs.
    predicted_ids = tf.random.categorical(predicted_logits, num_samples=1)
    predicted_ids = tf.squeeze(predicted_ids, axis=-1)

    # Convert from token ids to characters
    predicted_chars = self.chars_from_ids(predicted_ids)

    # Return the characters and model state.
    return predicted_chars, states
one_step_model = OneStep(model, chars_from_ids, ids_from_chars)
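The `-inf` mask works because exp(-inf) is 0. A masked logit therefore receives zero probability after the softmax, whatever its original value. The minimal plain-Python sketch below uses hypothetical helper names and made-up logits. It treats index 0 as `[UNK]`, which is that token's ID in this tutorial's lookup table.

```python
import math

def apply_mask(logits, skip_ids):
    # Add -inf at each skipped index, like OneStep's prediction_mask.
    mask = [0.0] * len(logits)
    for i in skip_ids:
        mask[i] = -math.inf
    return [x + m for x, m in zip(logits, mask)]

def softmax(logits):
    # exp(-inf - m) evaluates to 0.0, so masked entries get no probability.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

logits = [3.0, 1.0, 0.5, 0.2]  # index 0 plays the role of "[UNK]"
probs = softmax(apply_mask(logits, skip_ids=[0]))
print(probs[0])  # 0.0: "[UNK]" can never be sampled
```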

Run one_step_model in a loop to generate some text. Looking at the generated text, you'll see that the model knows when to capitalize, makes paragraphs, and imitates a Shakespeare-like writing vocabulary. With the small number of training epochs, it has not yet learned to form coherent sentences.

start = time.time()
states = None
next_char = tf.constant(['ROMEO:'])
result = [next_char]

for n in range(1000):
  next_char, states = one_step_model.generate_one_step(next_char, states=states)
  result.append(next_char)

result = tf.strings.join(result)
end = time.time()
print(result[0].numpy().decode('utf-8'), '\n\n' + '_'*80)
print('\nRun time:', end - start)
ROMEO:
I am comfort; and it must retter you must fled.

WARWICK:
O excell the great gold times of kings,
The other strive thou seest am I, how does! Is remain
A' reckon'd in his breason should be hurt
To hide his country are man's, who stand allow.

SEBASTIAN:
His womb:
Most likewit, canst curn me to the duke?

CAMILLO:
Nicanne! can her he is common from my shield. Which raiment, when,
Manglence their voices that have fled from him,
Tell out the fire o' the strict people glass'd
My will be aspared and health of woe;
Please your scorns and young affections,
And for frost but kill watch his child.

AUTOLYCUS:
But I believe you? leven, and drip and bring it tree.
Is Richard no more of thee: their hearts as she were,
This neighbour airy conscience and notes,
Always that not in my throat, and have all very slain;
And every tree his life, so second children,
I heard their weap; when we it cannot choose
Come bohemios that must make me should
To end them. Who wanted saint--
Being cambred Barnardine  

________________________________________________________________________________

Run time: 2.4680824279785156

The easiest thing you can do to improve the results is to train it for longer (try EPOCHS = 30).

You can also experiment with a different start string, try adding another RNN layer to improve the model's accuracy, or adjust the temperature parameter to generate more or less random predictions.
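To see what the temperature parameter does, the plain-Python sketch below applies the same scaling as `generate_one_step`: it divides the logits by the temperature before turning them into probabilities. `softmax_with_temperature` is a hypothetical helper, and the logits are made up.

```python
import math

def softmax_with_temperature(logits, temperature=1.0):
    # Divide by the temperature first: T < 1 sharpens, T > 1 flattens.
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.1]
for t in (0.5, 1.0, 2.0):
    probs = softmax_with_temperature(logits, t)
    print(t, [round(p, 3) for p in probs])
```

Lower temperatures concentrate probability on the most likely character and give more predictable text. Higher temperatures spread probability out and give more random text. With the `OneStep` class above, you pass the value through its constructor, for example `OneStep(model, chars_from_ids, ids_from_chars, temperature=0.5)`.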

If you want the model to generate text faster, the easiest thing you can do is batch the text generation. In the example below the model generates 5 outputs in about the same time it took to generate 1 above.

start = time.time()
states = None
next_char = tf.constant(['ROMEO:', 'ROMEO:', 'ROMEO:', 'ROMEO:', 'ROMEO:'])
result = [next_char]

for n in range(1000):
  next_char, states = one_step_model.generate_one_step(next_char, states=states)
  result.append(next_char)

result = tf.strings.join(result)
end = time.time()
print(result, '\n\n' + '_'*80)
print('\nRun time:', end - start)
tf.Tensor(
[b"ROMEO:\nTake her before, sir! The glass learned will serve: for a word,\nIs praise hath been by marriage, and these\nwounds that know thy will not certain of fastiny:\nGo, poor soul: I will be here to be gone.\n\nKING RICHARD III:\nConfess, reason! the languages of my loving liege!\nThou bloody quarther of thy back! so, farewell, frear\nApother's flesh and freence: so that ever\nHeeps he use if the people's eyes at heaven.\nCoursh thyself ha? little hours lie.\nThere speaks before some behold infection as I cannither\nTo strift the right of blood and senten'd fury; and be it should\nnot trial of it and rotten on thy winds.\nDeath, that same him, it strange to the Romans longing.\nNow, Abouts and mannor,\nEmpanish'd and spurn up wat,'t mad,\nDaughter Grey, come.\n\nDUKE VINCENTIO:\nNor I.\n\nSecond Senator:\nThe next way to him and Master Catesby,\nStops! may, their awe,--who noise is grown sor: 'she hours,\nCan this be kept of any infirm the canst thou canst wepper with the city\nIf, from my heart wepting, tent thee h"
 b"ROMEO:\nThe queen on hers! Clarence, till the king\nShall have said with the Lady Romeo's gianing.\n\nQUEEN:\nThen show she wanders to reproan.\n\nMIRANDA:\nBecause sit way but I fear, sir?\n\nPROSPERO:\nThou dost show me; and thou art not bitter out of fount:\nCome, but affare his plaints, for better there,\nOr waith, and they say, he spreak of socking Chidator,\nAnd he she speak of dock, knee where.'\nThis out for Kenarchment of this fair queen,\nNo, noble uncle. What, surstitualed than one file! Why dost not look us suret\nTo hard and kindly heart my queen of very sir,\nDo is scarce curses no further to be married;\nNot reverse the apply frown and cover all revolt;\nNor need and charm thou that again, mine in't is shield. Were I not, sir:\nI shall perfume it as a pratess; and\nshortly strive more of thine. He should not let me hear\nIt soon I come at the open of the inefall.\nAnd then I'll vold the silent crown\nThat we will also beat them, now arms must ease.\nWhat, hast thou finest at noble steed, and is\nThe nob"
 b"ROMEO:\n\nJULIET:\n'Tis time thou thy arm'd lord, that good itself,\nOur cold to minister of force and we behold the king is,\nNow appear it shall be more coy the heavens, and thy ancient gentleman\nWhich is no planet, nothing Hastings have\nWast smined a holpy of jest of a fire;\nOne thing, by thee, if not put the last,\nHer whipe as heaven fears me not he's good;\nWho lips, with thoughts can yield me seemer's store;\nFor no henge bragling forth and bashful king,\nAnd make his face sourness growing herein,\nAnd fled before the wall. Give me thy year\nIs look'd upon my kindness calleth hand: O raved\nWhere we'll descend no violent\nHath discredited that the most mal of straight: corrupt oursill me\nFriar Penerous art and unrest.\n\nCORIOLANUS:\nPray, fair son, and happy be thy name;\nIn chamber, Kate, colmed their abuse, which should\nbe placed in his friends than the albshate confound;\nThe knave in absence with your night.\n\nPRINCE EDWARD:\nA cuck, brave men all at once!\n\nNurse:\nThis answer to love thee! never lea"
 b"ROMEO:\nThe news with you? Return of all, I seek to us,\nAnd tell false Edward, King Edward's lips again\nTo Rome's discontentood: for I may live in stamp,\nWithout with a sign, a month about you!\nConsent by, nay, if thou dost, I Pereit that:\nGood queen, played into a slaguabour, are it is.\n\nAUTOLYCUS:\nO cause! why dost thou show me go along,\nFor I have made thy Edward had she known to France.\n\nKING RICHARD II:\nWe will be truched, shall we heard\nfrom you the day to thee a woman's face.\n\nEMILIA:\nMasters, hap the lies!\n\nSecond Watchman:\nHere's no as foolish term is the meal. Take happy sense.\nCousin of Buckingham, and Saint George Stanley's in hell!\nThe reason where he hath seen them move,\nAnd cowardly ones two of their houses, holding twine,\nAnd pluck'd on us women on the meater that gieds,\nYou shall tridul in the chamber, only in thy back!\nWilt thou respect the senseless vengeance of my father:\nThou has the case of these of his sivery.\n\nDUKE VINCENTIO:\nI'll hate; Richard infanctime their infyirm"
 b"ROMEO:\nThe loss of your imperiment spider, make whate'er stopping\nAttend an old tale man of sovereignty.\nNow Sister Barnardine! what a shame to king?\nWere he shall silen in my mildier maiderly as he\nshould Roman's faith in this, for it enchange\nMust him allow it. Speak; for I am sure, then I'll do you for\nthe world report the state of spirit, and go\nWitch! for she let that make that scarce choose\nComes reason all the due of all.\nThe immoran properly embraced, and\nat our enmmity to thine antium!\nThis blows appear, comfort: when my thoughts shall poison, years\nMuming him our souls must elder you: he is,\nWhen inducts might be asleeping for his life:\nMore strange daughters will help you to yourself.\nThe celvatain stayf in heavenly lance\nI cannot be behells, revive my standing he\nwhere I but craft of store pair'd with lie.\n\nProvost:\nWithin this fair converse is much better expence.\n\nANGELO:\nPlantagenet impeason with his grace flamed majesty\nAids, by the match, just die I like this hisble: we\nbeho"], shape=(5,), dtype=string) 

________________________________________________________________________________

Run time: 2.4250411987304688

Export the generator

This single-step model can easily be saved and restored, allowing you to use it anywhere a tf.saved_model is accepted.

tf.saved_model.save(one_step_model, 'one_step')
one_step_reloaded = tf.saved_model.load('one_step')
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.OneStep object at 0x7f785c577a90>, because it is not built.
2021-11-30 12:39:22.505061: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
WARNING:absl:Found untraced functions such as gru_cell_layer_call_fn, gru_cell_layer_call_and_return_conditional_losses, gru_cell_layer_call_fn, gru_cell_layer_call_and_return_conditional_losses, gru_cell_layer_call_and_return_conditional_losses while saving (showing 5 of 5). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: one_step/assets
states = None
next_char = tf.constant(['ROMEO:'])
result = [next_char]

for n in range(100):
  next_char, states = one_step_reloaded.generate_one_step(next_char, states=states)
  result.append(next_char)

print(tf.strings.join(result)[0].numpy().decode("utf-8"))
ROMEO:
Happiness he comes! by awe? I would not let-
Mutus meebly bear them for the free destroying thee:
I

Advanced: Customized Training

The above training procedure is simple, but it does not give you much control. It uses teacher forcing, which prevents bad predictions from being fed back to the model, so the model never learns to recover from mistakes.

Now that you've seen how to run the model manually, you'll implement the training loop yourself. This gives you a starting point if, for example, you want to implement curriculum learning to help stabilize the model's open-loop output.
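Teacher forcing and open-loop generation differ only in which character is fed in at each step. The plain-Python sketch below illustrates this and nothing more. `rollout_inputs` and `toy_model` are hypothetical, and `toy_model` stands in for one model step. With `p_model=0.0`, every input is the true previous character, which is teacher forcing as used by `fit` above. With `p_model=1.0`, the model sees only its own guesses. Raising `p_model` gradually during training is one curriculum of the kind mentioned above, usually called scheduled sampling.

```python
import random

def rollout_inputs(text, predict_next, p_model, rng):
    """Return the input characters a model would see while reading text.

    At each step, feed the model's own guess with probability p_model,
    otherwise feed the true character (teacher forcing).
    """
    inputs = [text[0]]
    for true_next in text[1:-1]:
        guess = predict_next(inputs[-1])
        inputs.append(guess if rng.random() < p_model else true_next)
    return "".join(inputs)

def toy_model(ch):
    # A deliberately bad "model": always predicts the next letter code point.
    return chr(ord(ch) + 1)

rng = random.Random(0)
print(rollout_inputs("hello", toy_model, p_model=0.0, rng=rng))  # hell
print(rollout_inputs("hello", toy_model, p_model=1.0, rng=rng))  # hijk
```

Under full teacher forcing the inputs equal `split_input_target`'s input sequence. Open-loop mode shows how a single early mistake propagates through every later step.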

The most important part of a custom training loop is the train step function.

Use tf.GradientTape to track the gradients. You can learn more about this approach by reading the eager execution guide.

The basic procedure is:

  1. Execute the model and calculate the loss under a tf.GradientTape.
  2. Calculate the updates and apply them to the model using the optimizer.
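To make the two steps concrete without TensorFlow, the plain-Python sketch below trains a one-parameter model y_hat = w * x with a mean-squared-error loss. Where `tf.GradientTape` would compute the gradient automatically, the sketch derives it by hand, and a plain gradient-descent update stands in for the optimizer. This `train_step` is a hypothetical stand-alone function, not the Keras method.

```python
def train_step(w, batch, learning_rate):
    """One gradient-descent step for the model y_hat = w * x."""
    xs, ys = batch
    n = len(xs)
    # Step 1: run the model and compute the mean squared error loss.
    preds = [w * x for x in xs]
    loss = sum((p - y) ** 2 for p, y in zip(preds, ys)) / n
    # Step 2: compute dloss/dw by hand and apply the update.
    grad = sum(2 * (p - y) * x for p, y, x in zip(preds, ys, xs)) / n
    return w - learning_rate * grad, loss

w = 0.0
batch = ([1.0, 2.0, 3.0], [2.0, 4.0, 6.0])  # data generated by y = 2x
for step in range(50):
    w, loss = train_step(w, batch, learning_rate=0.05)
print(round(w, 4))  # close to 2.0
```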
class CustomTraining(MyModel):
  @tf.function
  def train_step(self, inputs):
    inputs, labels = inputs
    with tf.GradientTape() as tape:
      predictions = self(inputs, training=True)
      loss = self.loss(labels, predictions)
    # Differentiate with respect to this model's own variables,
    # not the global `model` object.
    grads = tape.gradient(loss, self.trainable_variables)
    self.optimizer.apply_gradients(zip(grads, self.trainable_variables))

    return {'loss': loss}

The above implementation of the train_step method follows Keras' train_step conventions. This is optional, but it lets you change the behavior of the train step and still use Keras' Model.compile and Model.fit methods.

model = CustomTraining(
    vocab_size=len(ids_from_chars.get_vocabulary()),
    embedding_dim=embedding_dim,
    rnn_units=rnn_units)
model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
model.fit(dataset, epochs=1)
172/172 [==============================] - 7s 24ms/step - loss: 2.7157
<keras.callbacks.History at 0x7f784c347d10>

Or, if you need more control, you can write your own complete custom training loop:

EPOCHS = 10

mean = tf.metrics.Mean()

for epoch in range(EPOCHS):
    start = time.time()

    mean.reset_states()
    for (batch_n, (inp, target)) in enumerate(dataset):
        logs = model.train_step([inp, target])
        mean.update_state(logs['loss'])

        if batch_n % 50 == 0:
            template = f"Epoch {epoch+1} Batch {batch_n} Loss {logs['loss']:.4f}"
            print(template)

    # saving (checkpoint) the model every 5 epochs
    if (epoch + 1) % 5 == 0:
        model.save_weights(checkpoint_prefix.format(epoch=epoch))

    print()
    print(f'Epoch {epoch+1} Loss: {mean.result().numpy():.4f}')
    print(f'Time taken for 1 epoch {time.time() - start:.2f} sec')
    print("_"*80)

model.save_weights(checkpoint_prefix.format(epoch=epoch))
Epoch 1 Batch 0 Loss 2.1894
Epoch 1 Batch 50 Loss 2.0428
Epoch 1 Batch 100 Loss 1.9478
Epoch 1 Batch 150 Loss 1.8134

Epoch 1 Loss: 1.9878
Time taken for 1 epoch 5.99 sec
________________________________________________________________________________
Epoch 2 Batch 0 Loss 1.8471
Epoch 2 Batch 50 Loss 1.7652
Epoch 2 Batch 100 Loss 1.6805
Epoch 2 Batch 150 Loss 1.6089

Epoch 2 Loss: 1.7063
Time taken for 1 epoch 5.22 sec
________________________________________________________________________________
Epoch 3 Batch 0 Loss 1.5877
Epoch 3 Batch 50 Loss 1.5644
Epoch 3 Batch 100 Loss 1.6012
Epoch 3 Batch 150 Loss 1.5249

Epoch 3 Loss: 1.5441
Time taken for 1 epoch 5.39 sec
________________________________________________________________________________
Epoch 4 Batch 0 Loss 1.4675
Epoch 4 Batch 50 Loss 1.3992
Epoch 4 Batch 100 Loss 1.4202
Epoch 4 Batch 150 Loss 1.4764

Epoch 4 Loss: 1.4450
Time taken for 1 epoch 5.19 sec
________________________________________________________________________________
Epoch 5 Batch 0 Loss 1.3906
Epoch 5 Batch 50 Loss 1.3484
Epoch 5 Batch 100 Loss 1.3649
Epoch 5 Batch 150 Loss 1.3644

Epoch 5 Loss: 1.3776
Time taken for 1 epoch 5.48 sec
________________________________________________________________________________
Epoch 6 Batch 0 Loss 1.2946
Epoch 6 Batch 50 Loss 1.3350
Epoch 6 Batch 100 Loss 1.2798
Epoch 6 Batch 150 Loss 1.3575

Epoch 6 Loss: 1.3250
Time taken for 1 epoch 5.31 sec
________________________________________________________________________________
Epoch 7 Batch 0 Loss 1.1956
Epoch 7 Batch 50 Loss 1.2781
Epoch 7 Batch 100 Loss 1.2646
Epoch 7 Batch 150 Loss 1.3028

Epoch 7 Loss: 1.2797
Time taken for 1 epoch 5.43 sec
________________________________________________________________________________
Epoch 8 Batch 0 Loss 1.2646
Epoch 8 Batch 50 Loss 1.2904
Epoch 8 Batch 100 Loss 1.2497
Epoch 8 Batch 150 Loss 1.2201

Epoch 8 Loss: 1.2387
Time taken for 1 epoch 5.45 sec
________________________________________________________________________________
Epoch 9 Batch 0 Loss 1.1981
Epoch 9 Batch 50 Loss 1.1663
Epoch 9 Batch 100 Loss 1.2200
Epoch 9 Batch 150 Loss 1.1986

Epoch 9 Loss: 1.1983
Time taken for 1 epoch 5.25 sec
________________________________________________________________________________
Epoch 10 Batch 0 Loss 1.1464
Epoch 10 Batch 50 Loss 1.1448
Epoch 10 Batch 100 Loss 1.1566
Epoch 10 Batch 150 Loss 1.1913

Epoch 10 Loss: 1.1587
Time taken for 1 epoch 5.44 sec
________________________________________________________________________________