Text generation with an RNN



This tutorial demonstrates how to generate text using a character-based RNN. You will work with a dataset of Shakespeare's writing from Andrej Karpathy's The Unreasonable Effectiveness of Recurrent Neural Networks. Given a sequence of characters from this data ("Shakespear"), train a model to predict the next character in the sequence ("e"). Longer sequences of text can be generated by calling the model repeatedly.

This tutorial includes runnable code implemented using tf.keras and eager execution. The following is sample output from the model in this tutorial after it was trained for 30 epochs and started with the prompt "Q":

QUEENE:
I had thought thou hadst a Roman; for the oracle,
Thus by All bids the man against the word,
Which are so weak of care, by old care done;
Your children were in your holy love,
And the precipitation through the bleeding throne.

BISHOP OF ELY:
Marry, and will, my lord, to weep in such a one were prettiest;
Yet now I was adopted heir
Of the world's lamentable day,
To watch the next way with his father with his face?

ESCALUS:
The cause why then we are all resolved more sons.

VOLUMNIA:
O, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, it is no sin it should be dead,
And love and pale as any will to that word.

QUEEN ELIZABETH:
But how long have I heard the soul for this world,
And show his hands of life be proved to stand.

PETRUCHIO:
I say he look'd on, if I must be content
To stay him from the fatal of our country's bliss.
His lordship pluck'd from this sentence then for prey,
And then let us twain, being the moon,
were she such a case as fills m

While some of the sentences are grammatical, most do not make sense. The model has not learned the meaning of words, but consider:

  • The model is character-based. When training started, the model did not know how to spell an English word, or that words were even a unit of text.

  • The structure of the output resembles a play: blocks of text generally begin with a speaker name in all capital letters, as in the dataset.

  • As demonstrated below, the model is trained on small batches of text (100 characters each), and is still able to generate a longer sequence of text with coherent structure.

Setup

Import TensorFlow and other libraries

import tensorflow as tf

import numpy as np
import os
import time

Download the Shakespeare dataset

Change the following line to run this code on your own data.

path_to_file = tf.keras.utils.get_file('shakespeare.txt', 'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt')
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt
1115394/1115394 [==============================] - 0s 0us/step

Read the data

First, take a look at the text:

# Read, then decode for py2 compat.
text = open(path_to_file, 'rb').read().decode(encoding='utf-8')
# length of text is the number of characters in it
print(f'Length of text: {len(text)} characters')
Length of text: 1115394 characters
# Take a look at the first 250 characters in text
print(text[:250])
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.
# The unique characters in the file
vocab = sorted(set(text))
print(f'{len(vocab)} unique characters')
65 unique characters

Process the text

Vectorize the text

Before training, you need to convert the strings to a numerical representation.

The tf.keras.layers.StringLookup layer can convert each character into a numeric ID. It just needs the text to be split into tokens first.

example_texts = ['abcdefg', 'xyz']

chars = tf.strings.unicode_split(example_texts, input_encoding='UTF-8')
chars
<tf.RaggedTensor [[b'a', b'b', b'c', b'd', b'e', b'f', b'g'], [b'x', b'y', b'z']]>

Now create the tf.keras.layers.StringLookup layer:

ids_from_chars = tf.keras.layers.StringLookup(
    vocabulary=list(vocab), mask_token=None)

It converts from tokens to character IDs:

ids = ids_from_chars(chars)
ids
<tf.RaggedTensor [[40, 41, 42, 43, 44, 45, 46], [63, 64, 65]]>

Since the goal of this tutorial is to generate text, it will also be important to invert this representation and recover human-readable strings from it. For this you can use tf.keras.layers.StringLookup(..., invert=True).

chars_from_ids = tf.keras.layers.StringLookup(
    vocabulary=ids_from_chars.get_vocabulary(), invert=True, mask_token=None)

This layer recovers the characters from the vectors of IDs, and returns them as a tf.RaggedTensor of characters:

chars = chars_from_ids(ids)
chars
<tf.RaggedTensor [[b'a', b'b', b'c', b'd', b'e', b'f', b'g'], [b'x', b'y', b'z']]>

You can use tf.strings.reduce_join to join the characters back into strings.

tf.strings.reduce_join(chars, axis=-1).numpy()
array([b'abcdefg', b'xyz'], dtype=object)
def text_from_ids(ids):
  return tf.strings.reduce_join(chars_from_ids(ids), axis=-1)
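To make the round trip concrete without TensorFlow, here is a minimal plain-Python sketch of the same idea. It assumes the lookup behaves as configured above (mask_token=None with one out-of-vocabulary slot), so index 0 is "[UNK]" and the sorted characters start at 1; that offset is why 'a' maps to 40 in the output above, since the 39 characters that sort before it occupy IDs 1 to 39. The helper names build_lookup, encode and decode are hypothetical and not part of the TensorFlow API.

```python
# Plain-Python sketch of a character <-> ID lookup, assuming the behavior
# configured above: mask_token=None and one out-of-vocabulary slot at index 0.
UNK = "[UNK]"


def build_lookup(text):
    """Return (char_to_id, id_to_char) dictionaries for the text's characters."""
    vocab = [UNK] + sorted(set(text))
    char_to_id = {ch: i for i, ch in enumerate(vocab)}
    id_to_char = dict(enumerate(vocab))
    return char_to_id, id_to_char


def encode(s, char_to_id):
    # Characters missing from the vocabulary map to the "[UNK]" index 0.
    return [char_to_id.get(ch, 0) for ch in s]


def decode(ids, id_to_char):
    # Concatenate the characters, much like tf.strings.reduce_join does above.
    return "".join(id_to_char[i] for i in ids)


char_to_id, id_to_char = build_lookup("First Citizen:\nBefore we proceed")
example_ids = encode("Citizen", char_to_id)
print(example_ids)
print(decode(example_ids, id_to_char))
```

The names deliberately differ from ids_from_chars and chars_from_ids so that running the sketch in the same notebook does not overwrite the tutorial's layers.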

The prediction task

Given a character, or a sequence of characters, what is the most probable next character? This is the task you're training the model to perform. The input to the model will be a sequence of characters, and you train the model to predict the output: the following character at each time step.

Since RNNs maintain an internal state that depends on the previously seen elements, the question at each step becomes: given all the characters seen until this moment, what is the next character?

Create training examples and targets

Next, divide the text into example sequences. Each input sequence will contain seq_length characters from the text.

For each input sequence, the corresponding targets contain the same length of text, except shifted one character to the right.

So break the text into chunks of seq_length+1. For example, say seq_length is 4 and our text is "Hello". The input sequence would be "Hell", and the target sequence "ello".
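The chunk-and-shift idea can be sketched in plain Python with the "Hello" example. This is an illustration only: make_examples is a hypothetical helper, and the tutorial itself does the same work with tf.data (a batch call with drop_remainder=True, followed by split_input_target).

```python
def make_examples(text, seq_length):
    """Split text into (input, target) pairs, each seq_length characters long.

    Every chunk holds seq_length + 1 characters and a trailing partial chunk is
    dropped, mirroring the tutorial's batch(seq_length + 1, drop_remainder=True).
    """
    chunk = seq_length + 1
    examples = []
    for start in range(0, len(text) - chunk + 1, chunk):
        piece = text[start:start + chunk]
        # The input drops the last character; the target drops the first.
        examples.append((piece[:-1], piece[1:]))
    return examples


print(make_examples("Hello", 4))         # [('Hell', 'ello')]
print(make_examples("Hello world!", 4))  # two pairs; the leftover "d!" is dropped
```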

To do this, first use the tf.data.Dataset.from_tensor_slices function to convert the text vector into a stream of character indices.

all_ids = ids_from_chars(tf.strings.unicode_split(text, 'UTF-8'))
all_ids
<tf.Tensor: shape=(1115394,), dtype=int64, numpy=array([19, 48, 57, ..., 46,  9,  1])>
ids_dataset = tf.data.Dataset.from_tensor_slices(all_ids)
for ids in ids_dataset.take(10):
    print(chars_from_ids(ids).numpy().decode('utf-8'))
F
i
r
s
t
 
C
i
t
i
seq_length = 100

The batch method lets you easily convert these individual characters to sequences of the desired size.

sequences = ids_dataset.batch(seq_length+1, drop_remainder=True)

for seq in sequences.take(1):
  print(chars_from_ids(seq))
tf.Tensor(
[b'F' b'i' b'r' b's' b't' b' ' b'C' b'i' b't' b'i' b'z' b'e' b'n' b':'
 b'\n' b'B' b'e' b'f' b'o' b'r' b'e' b' ' b'w' b'e' b' ' b'p' b'r' b'o'
 b'c' b'e' b'e' b'd' b' ' b'a' b'n' b'y' b' ' b'f' b'u' b'r' b't' b'h'
 b'e' b'r' b',' b' ' b'h' b'e' b'a' b'r' b' ' b'm' b'e' b' ' b's' b'p'
 b'e' b'a' b'k' b'.' b'\n' b'\n' b'A' b'l' b'l' b':' b'\n' b'S' b'p' b'e'
 b'a' b'k' b',' b' ' b's' b'p' b'e' b'a' b'k' b'.' b'\n' b'\n' b'F' b'i'
 b'r' b's' b't' b' ' b'C' b'i' b't' b'i' b'z' b'e' b'n' b':' b'\n' b'Y'
 b'o' b'u' b' '], shape=(101,), dtype=string)

It's easier to see what this is doing if you join the tokens back into strings:

for seq in sequences.take(5):
  print(text_from_ids(seq).numpy())
b'First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou '
b'are all resolved rather to die than to famish?\n\nAll:\nResolved. resolved.\n\nFirst Citizen:\nFirst, you k'
b"now Caius Marcius is chief enemy to the people.\n\nAll:\nWe know't, we know't.\n\nFirst Citizen:\nLet us ki"
b"ll him, and we'll have corn at our own price.\nIs't a verdict?\n\nAll:\nNo more talking on't; let it be d"
b'one: away, away!\n\nSecond Citizen:\nOne word, good citizens.\n\nFirst Citizen:\nWe are accounted poor citi'

For training you'll need a dataset of (input, label) pairs, where input and label are both sequences. At each time step the input is the current character and the label is the next character.

Here's a function that takes a sequence as input, duplicates it, and shifts it to align the input and label for each timestep:

def split_input_target(sequence):
    input_text = sequence[:-1]
    target_text = sequence[1:]
    return input_text, target_text
split_input_target(list("Tensorflow"))
(['T', 'e', 'n', 's', 'o', 'r', 'f', 'l', 'o'],
 ['e', 'n', 's', 'o', 'r', 'f', 'l', 'o', 'w'])
dataset = sequences.map(split_input_target)
for input_example, target_example in dataset.take(1):
    print("Input :", text_from_ids(input_example).numpy())
    print("Target:", text_from_ids(target_example).numpy())
Input : b'First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou'
Target: b'irst Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou '

Create training batches

You used tf.data to split the text into manageable sequences. But before feeding this data into the model, you need to shuffle the data and pack it into batches.

# Batch size
BATCH_SIZE = 64

# Buffer size to shuffle the dataset
# (TF data is designed to work with possibly infinite sequences,
# so it doesn't attempt to shuffle the entire sequence in memory. Instead,
# it maintains a buffer in which it shuffles elements).
BUFFER_SIZE = 10000

dataset = (
    dataset
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(tf.data.experimental.AUTOTUNE))

dataset
<PrefetchDataset element_spec=(TensorSpec(shape=(64, 100), dtype=tf.int64, name=None), TensorSpec(shape=(64, 100), dtype=tf.int64, name=None))>
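The shuffle buffer described in the comment above can be illustrated with a small plain-Python sketch of the general technique: fill a fixed-size buffer, emit a random element from it, and refill the freed slot from the input. It sketches the idea only and is not tf.data's implementation; buffered_shuffle is a hypothetical name.

```python
import random


def buffered_shuffle(items, buffer_size, seed=None):
    """Yield items in a locally shuffled order using a fixed-size buffer.

    At most buffer_size items are buffered: once the buffer is full, each
    output is drawn uniformly from it and the freed slot takes the next input.
    """
    rng = random.Random(seed)
    buffer = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)  # still filling the buffer
            continue
        idx = rng.randrange(buffer_size)
        yield buffer[idx]
        buffer[idx] = item
    # Input exhausted: emit what remains in random order.
    rng.shuffle(buffer)
    yield from buffer


print(list(buffered_shuffle(range(20), buffer_size=5, seed=0)))
```

An element can only move a bounded distance earlier than its original position, so a small buffer gives only a local shuffle. For this dataset, 1115394 // 101 = 11043 sequences, so BUFFER_SIZE = 10000 spans most of it.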

Build the model

This section defines the model as a keras.Model subclass (for details, see Making new Layers and Models via subclassing).

This model has three layers:

  • tf.keras.layers.Embedding: The input layer. A trainable lookup table that will map each character-ID to a vector with embedding_dim dimensions;
  • tf.keras.layers.GRU: A type of RNN with size units=rnn_units (You can also use an LSTM layer here.)
  • tf.keras.layers.Dense: The output layer, with vocab_size outputs. It outputs one logit for each character in the vocabulary; these are the unnormalized log-likelihoods of each character according to the model.
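To show how the GRU carries information forward from one character to the next, here is a scalar sketch of one GRU step in plain Python. It assumes the reset_after formulation that tf.keras.layers.GRU uses by default in TensorFlow 2 (the reset gate is applied after the recurrent matrix multiplication), and the weights are made-up numbers rather than anything the tutorial trains.

```python
import math


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def gru_step(x, h, p):
    """One scalar GRU step (reset_after form); p holds hypothetical weights."""
    z = sigmoid(p["Wz"] * x + p["Uz"] * h + p["bz"])  # update gate
    r = sigmoid(p["Wr"] * x + p["Ur"] * h + p["br"])  # reset gate
    # The reset gate scales the recurrent term, including its own bias.
    candidate = math.tanh(p["Wh"] * x + r * (p["Uh"] * h + p["bhr"]) + p["bh"])
    # The new state interpolates between the previous state and the candidate.
    return z * h + (1.0 - z) * candidate


PARAMS = {"Wz": 0.5, "Uz": -0.3, "bz": 0.0,
          "Wr": 0.8, "Ur": 0.2, "br": 0.1,
          "Wh": 1.2, "Uh": 0.7, "bhr": 0.0, "bh": -0.1}


def run_gru(xs, p=PARAMS):
    h = 0.0  # zero initial state
    for x in xs:
        h = gru_step(x, h, p)
    return h


print(run_gru([1.0, 0.0, 0.0]), run_gru([-1.0, 0.0, 0.0]))
```

Changing only the first input changes the final state, which is how the model can condition each prediction on earlier characters.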
# Length of the vocabulary in StringLookup Layer
vocab_size = len(ids_from_chars.get_vocabulary())

# The embedding dimension
embedding_dim = 256

# Number of RNN units
rnn_units = 1024
class MyModel(tf.keras.Model):
  def __init__(self, vocab_size, embedding_dim, rnn_units):
    super().__init__()
    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
    self.gru = tf.keras.layers.GRU(rnn_units,
                                   return_sequences=True,
                                   return_state=True)
    self.dense = tf.keras.layers.Dense(vocab_size)

  def call(self, inputs, states=None, return_state=False, training=False):
    x = inputs
    x = self.embedding(x, training=training)
    if states is None:
      states = self.gru.get_initial_state(x)
    x, states = self.gru(x, initial_state=states, training=training)
    x = self.dense(x, training=training)

    if return_state:
      return x, states
    else:
      return x
model = MyModel(
    vocab_size=vocab_size,
    embedding_dim=embedding_dim,
    rnn_units=rnn_units)

For each character the model looks up the embedding, runs the GRU one timestep with the embedding as input, and applies the dense layer to generate logits predicting the log-likelihood of the next character:

Figure: A drawing of the data passing through the model.

Try the model

Now run the model to see that it behaves as expected.

First check the shape of the output:

for input_example_batch, target_example_batch in dataset.take(1):
    example_batch_predictions = model(input_example_batch)
    print(example_batch_predictions.shape, "# (batch_size, sequence_length, vocab_size)")
(64, 100, 66) # (batch_size, sequence_length, vocab_size)

In the above example the sequence length of the input is 100, but the model can be run on inputs of any length:

model.summary()
Model: "my_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 embedding (Embedding)       multiple                  16896     
                                                                 
 gru (GRU)                   multiple                  3938304   
                                                                 
 dense (Dense)               multiple                  67650     
                                                                 
=================================================================
Total params: 4,022,850
Trainable params: 4,022,850
Non-trainable params: 0
_________________________________________________________________
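The parameter counts in the summary can be reproduced by hand. The sketch below assumes vocab_size = 66 (the 65 characters plus "[UNK]"), embedding_dim = 256 and rnn_units = 1024 as set above, and the default GRU with reset_after=True, which keeps two bias vectors per gate.

```python
def embedding_params(vocab_size, embedding_dim):
    return vocab_size * embedding_dim  # one trainable vector per token ID


def gru_params(input_dim, units, reset_after=True):
    # Three gate blocks, each with an input kernel, a recurrent kernel and
    # bias(es); reset_after=True keeps separate input and recurrent biases.
    n_bias = 2 if reset_after else 1
    return 3 * (input_dim * units + units * units + n_bias * units)


def dense_params(input_dim, units):
    return input_dim * units + units  # kernel plus bias


counts = {
    "embedding": embedding_params(66, 256),
    "gru": gru_params(256, 1024),
    "dense": dense_params(1024, 66),
}
print(counts, sum(counts.values()))
```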

To get actual predictions from the model, you need to sample from the output distribution to obtain character indices. This distribution is defined by the logits over the character vocabulary.
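In the tutorial, tf.random.categorical does this sampling directly from the logits. The plain-Python sketch below shows the same idea for a single time step: convert logits to probabilities with a softmax, then draw an index in proportion to them. softmax and sample_index are illustrative helpers, not TensorFlow functions, and the logits are made up.

```python
import math
import random


def softmax(logits):
    # Subtract the maximum before exponentiating for numerical stability.
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_index(logits, rng):
    """Draw one index i with probability softmax(logits)[i]."""
    probs = softmax(logits)
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]


rng = random.Random(0)
step_logits = [2.0, 0.5, -1.0, 0.0]  # made-up logits for a 4-character vocabulary
draws = [sample_index(step_logits, rng) for _ in range(1000)]
print([draws.count(i) for i in range(len(step_logits))])
```

Sampling, rather than always taking the highest-scoring index, is what lets repeated generation produce varied text.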

Try it for the first example in the batch:

sampled_indices = tf.random.categorical(example_batch_predictions[0], num_samples=1)
sampled_indices = tf.squeeze(sampled_indices, axis=-1).numpy()

This gives you, at each timestep, a prediction of the next character index:

sampled_indices
array([49, 56, 39, 28,  8, 35, 34,  0, 11, 12, 34, 36,  4, 34, 50, 12, 12,
       54, 28, 46, 24, 41, 29, 16, 11, 59,  5, 62, 15, 33,  8, 61,  3, 30,
       10, 37, 57, 48, 41, 61, 11, 55, 43, 45, 57, 51,  6, 36,  8, 18, 32,
       30, 49,  2, 35, 47, 25, 51, 43, 49, 11, 50, 64, 37, 34, 29, 43,  4,
       51, 19,  9, 12, 18, 48, 31,  7, 35, 30, 38,  1, 35, 49, 22, 65, 48,
       13, 26, 12, 56, 25, 30, 41, 28, 27,  1, 62, 41, 34,  4, 29])

Decode these to see the text predicted by this untrained model:

print("Input:\n", text_from_ids(input_example_batch[0]).numpy())
print()
print("Next Char Predictions:\n", text_from_ids(sampled_indices).numpy())
Input:
 b"! I come, I come!\nWho knocks so hard? whence come you? what's your will?\n\nNurse:\n\nFRIAR LAURENCE:\nWe"

Next Char Predictions:
 b"jqZO-VU[UNK]:;UW\\(Uk;;oOgKbPC:t&wBT-v!Q3Xribv:pdfrl'W-ESQj VhLldj:kyXUPd\\)lF.;EiR,VQY\nVjIzi?M;qLQbON\nwbU$P"

Train the model

At this point the problem can be treated as a standard classification problem: given the previous RNN state and the input at this time step, predict the class of the next character.

Attach an optimizer and a loss function

The standard tf.keras.losses.sparse_categorical_crossentropy loss function works in this case because it is applied across the last dimension of the predictions.

Because your model returns logits, you need to set the from_logits flag.
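What from_logits=True asks the loss to do can be sketched in plain Python: apply a log-softmax over the last dimension, pick out the log-probability of the true next character at each position, and average the negatives. This is a simplified sketch for illustration (no batching, weighting or masking), and sparse_ce_from_logits is a hypothetical name.

```python
import math


def log_softmax(logits):
    # log softmax(x)_i = x_i - logsumexp(x), computed stably.
    m = max(logits)
    lse = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - lse for v in logits]


def sparse_ce_from_logits(labels, logits_seq):
    """Mean of -log p(label) over all positions.

    labels: one integer class ID per time step.
    logits_seq: one logit vector (length = vocabulary size) per time step.
    """
    losses = [-log_softmax(logits)[y] for y, logits in zip(labels, logits_seq)]
    return sum(losses) / len(losses)


uniform = [[0.0] * 66 for _ in range(3)]            # equal logits at 3 positions
print(sparse_ce_from_logits([5, 17, 40], uniform))  # about log(66), roughly 4.19
```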

loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True)
example_batch_mean_loss = loss(target_example_batch, example_batch_predictions)
print("Prediction shape: ", example_batch_predictions.shape, " # (batch_size, sequence_length, vocab_size)")
print("Mean loss:        ", example_batch_mean_loss)
Prediction shape:  (64, 100, 66)  # (batch_size, sequence_length, vocab_size)
Mean loss:         tf.Tensor(4.1897926, shape=(), dtype=float32)

A newly initialized model shouldn't be too sure of itself: the output logits should all have similar magnitudes. To confirm this, you can check that the exponential of the mean loss is approximately equal to the vocabulary size. A much higher loss means the model is confident in its wrong answers and is badly initialized:

tf.exp(example_batch_mean_loss).numpy()
66.0091
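The check works because, for equal logits c over a vocabulary of V characters, every predicted probability is 1/V, so the loss for any true character y is log V:

```latex
p_i = \frac{e^{c}}{\sum_{j=1}^{V} e^{c}} = \frac{1}{V},
\qquad
\mathcal{L} = -\log p_{y} = \log V,
\qquad
e^{\mathcal{L}} = V
```

With V = 66, log 66 ≈ 4.19, in line with the mean loss of 4.1897926 and its exponential of 66.0091 printed above.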

Configure the training procedure using the tf.keras.Model.compile method. Use tf.keras.optimizers.Adam with default arguments and the loss function.

model.compile(optimizer='adam', loss=loss)

Configure checkpoints

Use a tf.keras.callbacks.ModelCheckpoint to ensure that checkpoints are saved during training:

# Directory where the checkpoints will be saved
checkpoint_dir = './training_checkpoints'
# Name of the checkpoint files
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")

checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_prefix,
    save_weights_only=True)

Execute the training

To keep training time reasonable, train the model for 20 epochs. In Colab, set the runtime to GPU for faster training.

EPOCHS = 20
history = model.fit(dataset, epochs=EPOCHS, callbacks=[checkpoint_callback])
Epoch 1/20
172/172 [==============================] - 10s 41ms/step - loss: 2.7038
Epoch 2/20
172/172 [==============================] - 8s 38ms/step - loss: 1.9803
Epoch 3/20
172/172 [==============================] - 7s 37ms/step - loss: 1.7038
Epoch 4/20
172/172 [==============================] - 8s 38ms/step - loss: 1.5428
Epoch 5/20
172/172 [==============================] - 8s 38ms/step - loss: 1.4445
Epoch 6/20
172/172 [==============================] - 7s 37ms/step - loss: 1.3770
Epoch 7/20
172/172 [==============================] - 7s 37ms/step - loss: 1.3242
Epoch 8/20
172/172 [==============================] - 7s 37ms/step - loss: 1.2800
Epoch 9/20
172/172 [==============================] - 7s 37ms/step - loss: 1.2399
Epoch 10/20
172/172 [==============================] - 7s 37ms/step - loss: 1.1986
Epoch 11/20
172/172 [==============================] - 7s 37ms/step - loss: 1.1592
Epoch 12/20
172/172 [==============================] - 7s 37ms/step - loss: 1.1176
Epoch 13/20
172/172 [==============================] - 8s 37ms/step - loss: 1.0734
Epoch 14/20
172/172 [==============================] - 7s 37ms/step - loss: 1.0273
Epoch 15/20
172/172 [==============================] - 7s 37ms/step - loss: 0.9769
Epoch 16/20
172/172 [==============================] - 7s 37ms/step - loss: 0.9258
Epoch 17/20
172/172 [==============================] - 8s 37ms/step - loss: 0.8732
Epoch 18/20
172/172 [==============================] - 7s 37ms/step - loss: 0.8206
Epoch 19/20
172/172 [==============================] - 8s 37ms/step - loss: 0.7692
Epoch 20/20
172/172 [==============================] - 7s 37ms/step - loss: 0.7203
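The 172 steps per epoch in the log follow from the dataset construction. Assuming the values printed earlier (1115394 characters, seq_length = 100, BATCH_SIZE = 64) and that both batch calls drop their remainders, the arithmetic is:

```python
n_chars = 1115394      # "Length of text" printed earlier
chunk_len = 100 + 1    # seq_length + 1 characters per sequence
batch = 64             # BATCH_SIZE

n_sequences = n_chars // chunk_len      # batch(seq_length + 1, drop_remainder=True)
steps_per_epoch = n_sequences // batch  # batch(BATCH_SIZE, drop_remainder=True)
print(n_sequences, steps_per_epoch)     # 11043 172
```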

Generate text

The simplest way to generate text with this model is to run it in a loop, and keep track of the model's internal state as you execute it.

Figure: To generate text, the model's output is fed back to the input.

Each time you call the model you pass in some text and an internal state. The model returns a prediction for the next character and its new state. Pass the prediction and state back in to continue generating text.
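The feedback loop can be sketched without TensorFlow. Below, a character-bigram counter stands in for the trained RNN (an assumption made only to keep the example small and runnable), and a hypothetical bigram_step function follows the same contract as the single-step model: it takes new input text plus a state and returns a sampled next character plus the updated state.

```python
import random
from collections import Counter, defaultdict


def train_bigram(text):
    """Count which character follows which; a stand-in for the trained RNN."""
    table = defaultdict(Counter)
    for a, b in zip(text, text[1:]):
        table[a][b] += 1
    return table


def bigram_step(table, inputs, state, rng):
    """Hypothetical one-step model: (new input text, state) -> (next char, state).

    For this bigram stand-in the state is just the last character consumed;
    the RNN instead carries its GRU hidden state between calls.
    """
    for ch in inputs:
        state = ch
    followers = table.get(state) or Counter({" ": 1})  # fallback for unseen chars
    chars, counts = zip(*followers.items())
    return rng.choices(chars, weights=counts, k=1)[0], state


def generate_text(table, prompt, n, seed=0):
    rng = random.Random(seed)
    state, next_input, pieces = None, prompt, [prompt]
    for _ in range(n):
        # Each prediction is fed back in as the next call's input.
        next_input, state = bigram_step(table, next_input, state, rng)
        pieces.append(next_input)
    return "".join(pieces)


corpus = "First Citizen:\nBefore we proceed any further, hear me speak."
print(generate_text(train_bigram(corpus), "F", 60))
```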

The following makes a single-step prediction:

class OneStep(tf.keras.Model):
  def __init__(self, model, chars_from_ids, ids_from_chars, temperature=1.0):
    super().__init__()
    self.temperature = temperature
    self.model = model
    self.chars_from_ids = chars_from_ids
    self.ids_from_chars = ids_from_chars

    # Create a mask to prevent "[UNK]" from being generated.
    skip_ids = self.ids_from_chars(['[UNK]'])[:, None]
    sparse_mask = tf.SparseTensor(
        # Put a -inf at each bad index.
        values=[-float('inf')]*len(skip_ids),
        indices=skip_ids,
        # Match the shape to the vocabulary
        dense_shape=[len(ids_from_chars.get_vocabulary())])
    self.prediction_mask = tf.sparse.to_dense(sparse_mask)

  @tf.function
  def generate_one_step(self, inputs, states=None):
    # Convert strings to token IDs.
    input_chars = tf.strings.unicode_split(inputs, 'UTF-8')
    input_ids = self.ids_from_chars(input_chars).to_tensor()

    # Run the model.
    # predicted_logits.shape is [batch, char, next_char_logits]
    predicted_logits, states = self.model(inputs=input_ids, states=states,
                                          return_state=True)
    # Only use the last prediction.
    predicted_logits = predicted_logits[:, -1, :]
    predicted_logits = predicted_logits/self.temperature
    # Apply the prediction mask: prevent "[UNK]" from being generated.
    predicted_logits = predicted_logits + self.prediction_mask

    # Sample the output logits to generate token IDs.
    predicted_ids = tf.random.categorical(predicted_logits, num_samples=1)
    predicted_ids = tf.squeeze(predicted_ids, axis=-1)

    # Convert from token ids to characters
    predicted_chars = self.chars_from_ids(predicted_ids)

    # Return the characters and model state.
    return predicted_chars, states
one_step_model = OneStep(model, chars_from_ids, ids_from_chars)
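The prediction mask works because adding -inf to a logit drives its softmax probability to exactly zero, while adding 0.0 leaves the other logits unchanged. Here is a plain-Python sketch with a made-up five-entry vocabulary, assuming "[UNK]" sits at index 0 as in the StringLookup vocabulary above:

```python
import math


def softmax(logits):
    m = max(logits)  # finite as long as at least one entry is unmasked
    exps = [math.exp(v - m) for v in logits]  # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]


UNK_ID = 0      # assumed "[UNK]" index, as in the StringLookup vocabulary
TOY_VOCAB = 5   # made-up vocabulary size for illustration

# -inf at the "[UNK]" position and 0.0 elsewhere, like prediction_mask.
mask = [-math.inf if i == UNK_ID else 0.0 for i in range(TOY_VOCAB)]

raw_logits = [3.0, 1.0, 0.5, -0.2, 0.0]  # "[UNK]" would be the most likely here
probs = softmax([lg + mk for lg, mk in zip(raw_logits, mask)])
print(probs)
```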

Run it in a loop to generate some text. Looking at the generated text, you'll see the model knows when to capitalize, makes paragraphs, and imitates a Shakespeare-like writing vocabulary. With the small number of training epochs, it has not yet learned to form coherent sentences.

start = time.time()
states = None
next_char = tf.constant(['ROMEO:'])
result = [next_char]

for n in range(1000):
  next_char, states = one_step_model.generate_one_step(next_char, states=states)
  result.append(next_char)

result = tf.strings.join(result)
end = time.time()
print(result[0].numpy().decode('utf-8'), '\n\n' + '_'*80)
print('\nRun time:', end - start)
ROMEO:
Come, Kate, and wear thee! or a woman of prey,
And speak it thunder, not in my breed!
Will you two go to serve the gentle joint in this?
Thou king'd my joy, upon their swords!

KATHARINA:
Mistress, the case is too friendship: For which
else it be done:
His goods that he dress'd me not what I would
And Clifford stay all toward with a helmer
Am bour and minister. Good news, go with me;
In coroor lady, thirty years!

JOHN OF GAUNT:
O, Delatul king! How vanged his contract wilt tarry
To minister of the birth and tedious?
Thy tears are specked, church and hollow breath?

ISABELLA:
My oldest! I warrant you; and turn good and
Thy brother's present penselance of this.
So speak? O ray, did presently grant care woe.

HASTINGS:
What, with thy griefs, who I am loath to blame,
To this it art.

LUCENTIO:
I' like, I do indeed, not reason; knice in this
Until the enthous tault to this likewemp an:
I have lohing theer their shamed, took me;
Look to thy bark: I'll find the way; who read
Shall know your 

________________________________________________________________________________

Run time: 2.644033193588257

The easiest thing you can do to improve the results is to train it for longer (try EPOCHS = 30).

You can also experiment with a different start string, try adding another RNN layer to improve the model's accuracy, or adjust the temperature parameter to generate more or less random predictions.
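OneStep divides the logits by its temperature argument (1.0 by default) before sampling. The plain-Python sketch below, using made-up logits, shows the effect: temperatures below 1 sharpen the distribution toward the top character, and temperatures above 1 flatten it, which yields more random text.

```python
import math


def softmax_with_temperature(logits, temperature):
    """softmax(logits / temperature); lower temperature sharpens the distribution."""
    scaled = [v / temperature for v in logits]
    m = max(scaled)
    exps = [math.exp(v - m) for v in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(probs):
    # Higher entropy means a flatter, more random distribution.
    return -sum(p * math.log(p) for p in probs if p > 0)


example_logits = [2.0, 1.0, 0.1]
for t in (0.5, 1.0, 2.0):
    p = softmax_with_temperature(example_logits, t)
    print(t, [round(x, 3) for x in p], round(entropy(p), 3))
```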

If you want the model to generate text faster, the easiest thing you can do is batch the text generation. In the example below the model generates 5 outputs in about the same time it took to generate 1 above.

start = time.time()
states = None
next_char = tf.constant(['ROMEO:', 'ROMEO:', 'ROMEO:', 'ROMEO:', 'ROMEO:'])
result = [next_char]

for n in range(1000):
  next_char, states = one_step_model.generate_one_step(next_char, states=states)
  result.append(next_char)

result = tf.strings.join(result)
end = time.time()
print(result, '\n\n' + '_'*80)
print('\nRun time:', end - start)
tf.Tensor(
[b"ROMEO:\nMy work is that gold not for loving, prisoners.\n\nDUKE OF YORK:\nSpeak it again, we did commend to rule for us\nTo Blushion, that may be your court quickly.\nAsk floce, with all the loss that lies disadvension\nThus with all pracious tarder where Oxford comes!\n\nLUCENTIO:\nI start and frank, I care not flow alan.\nDespiteful tidon! shalt thou go to debish;\nMy friends, I have drawn your rors of mine,\nBut with all much sorrow ends,\nSome valumberous cradged country.\nAnd then I would he had not, like a parment\nI drink my wager or like.\n\nEDWARD:\nO brother Gloucester, for thy lord's! millson! make thee gone:\nLet Romeo, there did miss them; they\nshall have no palures: sir, points\nAnd very scaps when I may come to thundey.\n\nBUCKINGHAM:\nWhy, then he swore to make thy spiders at tears.\n\nTINGL O:\nGod save forth prince.\n\nMARCIUS:\nHe thy eyes do weep and deliver death.\nThus first mark hold! If I must pay your husband\nAnd halt my either blood at servith ope.\nA, my three-e'er it were ann lourn were their\nTh"
 b"ROMEO:\nCome, be you both a young prince to be your mooning,\nthat we will bear so pall'd your little.\n\nCAMILLO:\nWhen you may spare let her die for being so?\nOr shall I do? a sacrifice, that\nLet noble compate bitter, being a thousand,--\n\nAETES:\nGo to, go both.\n\nDUKE OF YORK:\nMadam, good queen.\n\nPAULINA:\nNot pause.\n\nTut you the orjury of the people?\n\nSICINIUS:\nHe is in wrath, as we be tagerer;\nAnd you were so put do attend thee in her.\n\nQUEEr Murderer:\nTake't; I'll say amen, sweet Oxford; what your highness,\nWe lield them myself, and not against you:\nHe does well in this destity vow;\nThis must rie his influn their names Roter.\n\nBUSHY:\nThe hourt upon us thanks,\nThy beauty hath all spenting last to tarnied,\nHave been debt buckle kindly confused\nHow to call him, and that cannot tell me to the sea,\nAnd unrood on eating and rest again to\nApparents to my enemies: you be rid\nmother: hie thee this widow in whose inclines\nPardon in this falling villain should\nlook upon his shroud; and then the shepherds"
 b"ROMEO:\nMarry, swear, in tooch mine elemies! Thou art too much,\nThat girruft which they scart as I expirious;\nBut in these scruples real hope, this one is after'd in fullow\nOf thy secrets to the ensmity in their myself,\nThat beart the lists and to be obey.\nBank lices but this: both wives, take up the better.\nI hope the lords this speechless labour,\nThat trems switting in the time hath grieved my ajused preVent.\nHix foot my tongue whereof I do; we shall have know\nThe right unbur fortune to the Tower.\n\nKING EDWARD IV:\nNext queen and make the hungry centrements\nBe known unworstipul; which henceforth have accept of?\nhere, by a thief and together: stay ay, what\nI came from heaven, mine honour,\nSage me down with ithress: and therein left am\nNorefus to to them and unto the tribunes. Make them dwfly.\nI must be cure to London on me--hold, made\nWhat hap the greater man.\n\nPROSPERO:\nThough I will discredit\nThat murder me so happy was never\nMyself your eyes addiams and Warwick to thy share,\nAnd kiss your "
 b"ROMEO:\nCome, I'll turn thee tell me I go: and well\nThey love they seek the sceptre from one so.\n\nJULIET:\nYet set down with thee departed of night\nCome that sens wed; nature cur man is\nconfessor of my mistress'd wombs.\nThis rost as I have with me too:\nThou didst depend me to the Capulets\nMake the tribunes of my power.\n\nCOMINIUS:\nAy, mistress Bolingbroke,\nI would they were destroy'd my good Camillo,\nYou sleep in blist envy Pompery in\nthe hostest wite, which is more remedies in one.\n\nLUCIO:\nJesu pair of woe, when I was command, live to\nRave like an hour.\n\nDUKE OF YORK:\nCome tomachions, 'jeach'd with blood, and only means\nFrom the people in their and roctors!\nCurses not a worthy man, and he not speak?\nYour\nput in mere confessible. I am sure,\nI would as flee--your suit is servil\nOn the fix'd of the contrary, and prepare\nnot to requite their swords. The self-stial' gaved\nto do them affairs itself to dials advantage of heart.\nTherefore, no better tale: the ladies of the young plenty death.\nO, pluck"
 b"ROMEO:\nPeace, peace Mercuryain'd with thy realm; but I would\nHave saver'd with that appellant and to bree,\nSuspict and perpetual, bind of greaters becomes\nMore untience that deliver you our complaints to do thy heat,\nPlease you to rests, if this blood procked\nThe uncle of patience and suit all,\nAnd in three precious wards tarket to rest\nUnto the crown and not rewers sworn your holy robularity.\n\nBRAKENBUR:\nHaths, and three island and flot were likelitors,\nAnd finds be no siming, friends,' quoth he;\n'Twas done: and therefore I'll unto the word:\nBut what, in me an true love's name? Is not my duty;\nBut here is Camillo was a jaw of ill.\nI doubt not, I thunk together,\nAnd serve me soon as we took act.\n\nGONZAL:\nAy, thus 'twill kiss this vice of it.\n\nKATHARINA:\nWhat ne'er art mady ministers in thy hard?\nWhen haste your sleep is balk in them? but my true reputation\nTo Annot remiss! Think I am he, become me\nnot in't; the lark intolence, which should\nsuch abused winds unfore like minds: they saw these "], shape=(5,), dtype=string) 

________________________________________________________________________________

Run time: 2.5831151008605957

Export the generator

This single-step model can easily be saved and restored, allowing you to use it anywhere a tf.saved_model is accepted.

tf.saved_model.save(one_step_model, 'one_step')
one_step_reloaded = tf.saved_model.load('one_step')
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.OneStep object at 0x7fdc704a4d60>, because it is not built.
WARNING:absl:Found untraced functions such as gru_cell_layer_call_fn, gru_cell_layer_call_and_return_conditional_losses while saving (showing 2 of 2). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: one_step/assets
states = None
next_char = tf.constant(['ROMEO:'])
result = [next_char]

for n in range(100):
  next_char, states = one_step_reloaded.generate_one_step(next_char, states=states)
  result.append(next_char)

print(tf.strings.join(result)[0].numpy().decode("utf-8"))
ROMEO:
My affairs doth unto the house of York
Us! Ruslife, my heart! Go, be so strange so.

MONTAGUE:
I wo

Advanced: Customized Training

The above training procedure is simple, but it does not give you much control. It uses teacher forcing, which prevents bad predictions from being fed back to the model, so the model never learns to recover from mistakes.

Now that you've seen how to run the model manually, next you'll implement the training loop. This gives a starting point if, for example, you want to implement curriculum learning to help stabilize the model's open-loop output.
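As one hypothetical example of such a curriculum, a technique usually called scheduled sampling gradually replaces some ground-truth inputs with the model's own predictions during training. The sketch below only shows how inputs could be mixed and scheduled; mix_inputs and linear_schedule are illustrative names, and wiring this into the RNN's train step is left out.

```python
import random


def mix_inputs(targets, predictions, p_model, rng):
    """Build next-step inputs for one sequence.

    targets: ground-truth characters (what teacher forcing always feeds).
    predictions: the model's own outputs at the same positions.
    p_model: probability of feeding the model's prediction instead of the truth;
    0.0 is pure teacher forcing and 1.0 is fully open-loop.
    """
    return [pred if rng.random() < p_model else true
            for true, pred in zip(targets, predictions)]


def linear_schedule(epoch, num_epochs, max_p=0.5):
    # Hypothetical curriculum: start from teacher forcing and ramp p_model up.
    return max_p * epoch / max(1, num_epochs - 1)


rng = random.Random(0)
for ep in range(3):
    p = linear_schedule(ep, 3)
    print(ep, p, "".join(mix_inputs(list("Hello"), list("Hxllo"), p, rng)))
```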

The most important part of a custom training loop is the train step function.

Use tf.GradientTape to track the gradients. You can learn more about this approach by reading the eager execution guide.

The basic procedure is:

  1. Execute the model and calculate the loss under a tf.GradientTape.
  2. Calculate the updates and apply them to the model using the optimizer.
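The two steps can be illustrated on a toy problem in plain Python, where the gradient is written out by hand instead of being recorded by tf.GradientTape. The trainable parameters here are just a free vector of logits for a single target class (an assumption to keep it small), and the update is plain gradient descent rather than the Adam optimizer the tutorial uses.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def toy_train_step(logits, label, learning_rate=0.5):
    """One update on a free logit vector for a single target class.

    1. Forward pass: compute probabilities and the cross-entropy loss.
    2. Backward pass: d(loss)/d(logits) = softmax(logits) - one_hot(label).
    3. Update: plain gradient descent on the logits.
    """
    probs = softmax(logits)
    step_loss = -math.log(probs[label])
    grads = [p - (1.0 if i == label else 0.0) for i, p in enumerate(probs)]
    new_logits = [w - learning_rate * g for w, g in zip(logits, grads)]
    return new_logits, step_loss


toy_logits = [0.0, 0.0, 0.0, 0.0]
losses = []
for _ in range(20):
    toy_logits, step_loss = toy_train_step(toy_logits, label=2)
    losses.append(step_loss)
print([round(v, 3) for v in losses[:5]])
```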
class CustomTraining(MyModel):
  @tf.function
  def train_step(self, inputs):
      inputs, labels = inputs
      with tf.GradientTape() as tape:
          predictions = self(inputs, training=True)
          loss = self.loss(labels, predictions)
      grads = tape.gradient(loss, self.trainable_variables)
      self.optimizer.apply_gradients(zip(grads, self.trainable_variables))

      return {'loss': loss}

The above implementation of the train_step method follows Keras' train_step conventions. This is optional, but it allows you to change the behavior of the train step and still use Keras' Model.compile and Model.fit methods.

model = CustomTraining(
    vocab_size=len(ids_from_chars.get_vocabulary()),
    embedding_dim=embedding_dim,
    rnn_units=rnn_units)
model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
model.fit(dataset, epochs=1)
172/172 [==============================] - 10s 39ms/step - loss: 2.7237
<keras.callbacks.History at 0x7fdc3031ea60>

Or, if you need more control, you can write your own complete custom training loop:

EPOCHS = 10

mean = tf.metrics.Mean()

for epoch in range(EPOCHS):
    start = time.time()

    mean.reset_states()
    for (batch_n, (inp, target)) in enumerate(dataset):
        logs = model.train_step([inp, target])
        mean.update_state(logs['loss'])

        if batch_n % 50 == 0:
            template = f"Epoch {epoch+1} Batch {batch_n} Loss {logs['loss']:.4f}"
            print(template)

    # saving (checkpoint) the model every 5 epochs
    if (epoch + 1) % 5 == 0:
        model.save_weights(checkpoint_prefix.format(epoch=epoch))

    print()
    print(f'Epoch {epoch+1} Loss: {mean.result().numpy():.4f}')
    print(f'Time taken for 1 epoch {time.time() - start:.2f} sec')
    print("_"*80)

model.save_weights(checkpoint_prefix.format(epoch=epoch))
Epoch 1 Batch 0 Loss 2.1982
Epoch 1 Batch 50 Loss 2.0437
Epoch 1 Batch 100 Loss 1.9558
Epoch 1 Batch 150 Loss 1.9079

Epoch 1 Loss: 1.9976
Time taken for 1 epoch 8.78 sec
________________________________________________________________________________
Epoch 2 Batch 0 Loss 1.8498
Epoch 2 Batch 50 Loss 1.7517
Epoch 2 Batch 100 Loss 1.7125
Epoch 2 Batch 150 Loss 1.7343

Epoch 2 Loss: 1.7276
Time taken for 1 epoch 7.37 sec
________________________________________________________________________________
Epoch 3 Batch 0 Loss 1.6001
Epoch 3 Batch 50 Loss 1.5886
Epoch 3 Batch 100 Loss 1.5683
Epoch 3 Batch 150 Loss 1.5864

Epoch 3 Loss: 1.5673
Time taken for 1 epoch 7.40 sec
________________________________________________________________________________
Epoch 4 Batch 0 Loss 1.4907
Epoch 4 Batch 50 Loss 1.4504
Epoch 4 Batch 100 Loss 1.4559
Epoch 4 Batch 150 Loss 1.5050

Epoch 4 Loss: 1.4654
Time taken for 1 epoch 7.36 sec
________________________________________________________________________________
Epoch 5 Batch 0 Loss 1.4318
Epoch 5 Batch 50 Loss 1.3978
Epoch 5 Batch 100 Loss 1.3867
Epoch 5 Batch 150 Loss 1.3937

Epoch 5 Loss: 1.3965
Time taken for 1 epoch 7.61 sec
________________________________________________________________________________
Epoch 6 Batch 0 Loss 1.3197
Epoch 6 Batch 50 Loss 1.3280
Epoch 6 Batch 100 Loss 1.3472
Epoch 6 Batch 150 Loss 1.3426

Epoch 6 Loss: 1.3437
Time taken for 1 epoch 7.33 sec
________________________________________________________________________________
Epoch 7 Batch 0 Loss 1.2784
Epoch 7 Batch 50 Loss 1.2897
Epoch 7 Batch 100 Loss 1.2868
Epoch 7 Batch 150 Loss 1.3265

Epoch 7 Loss: 1.2992
Time taken for 1 epoch 7.37 sec
________________________________________________________________________________
Epoch 8 Batch 0 Loss 1.2435
Epoch 8 Batch 50 Loss 1.2703
Epoch 8 Batch 100 Loss 1.2458
Epoch 8 Batch 150 Loss 1.2640

Epoch 8 Loss: 1.2583
Time taken for 1 epoch 7.43 sec
________________________________________________________________________________
Epoch 9 Batch 0 Loss 1.1836
Epoch 9 Batch 50 Loss 1.1948
Epoch 9 Batch 100 Loss 1.2615
Epoch 9 Batch 150 Loss 1.2307

Epoch 9 Loss: 1.2215
Time taken for 1 epoch 7.37 sec
________________________________________________________________________________
Epoch 10 Batch 0 Loss 1.1510
Epoch 10 Batch 50 Loss 1.1569
Epoch 10 Batch 100 Loss 1.1545
Epoch 10 Batch 150 Loss 1.2415

Epoch 10 Loss: 1.1831
Time taken for 1 epoch 7.60 sec
________________________________________________________________________________