
Text generation with an RNN


This tutorial demonstrates how to generate text using a character-based RNN. You will work with a dataset of Shakespeare's writing from Andrej Karpathy's The Unreasonable Effectiveness of Recurrent Neural Networks. Given a sequence of characters from this data ("Shakespear"), train a model to predict the next character in the sequence ("e"). Longer sequences of text can be generated by calling the model repeatedly.

This tutorial includes runnable code implemented using tf.keras and eager execution. The following is sample output when the model in this tutorial was trained for 30 epochs and started with the prompt "Q":

QUEENE:
I had thought thou hadst a Roman; for the oracle,
Thus by All bids the man against the word,
Which are so weak of care, by old care done;
Your children were in your holy love,
And the precipitation through the bleeding throne.

BISHOP OF ELY:
Marry, and will, my lord, to weep in such a one were prettiest;
Yet now I was adopted heir
Of the world's lamentable day,
To watch the next way with his father with his face?

ESCALUS:
The cause why then we are all resolved more sons.

VOLUMNIA:
O, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, it is no sin it should be dead,
And love and pale as any will to that word.

QUEEN ELIZABETH:
But how long have I heard the soul for this world,
And show his hands of life be proved to stand.

PETRUCHIO:
I say he look'd on, if I must be content
To stay him from the fatal of our country's bliss.
His lordship pluck'd from this sentence then for prey,
And then let us twain, being the moon,
were she such a case as fills m

While some of the sentences are grammatical, most do not make sense. The model has not learned the meaning of words, but consider:

  • The model is character-based. When training started, the model did not know how to spell an English word, or that words were even a unit of text.

  • The structure of the output resembles a play—blocks of text generally begin with a speaker name, in all capital letters similar to the dataset.

  • As demonstrated below, the model is trained on small batches of text (100 characters each), and is still able to generate a longer sequence of text with coherent structure.

Setup

Import TensorFlow and other libraries

import tensorflow as tf
from tensorflow.keras.layers.experimental import preprocessing

import numpy as np
import os
import time

Download the Shakespeare dataset

Change the following line to run this code on your own data.

path_to_file = tf.keras.utils.get_file('shakespeare.txt', 'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt')
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt
1122304/1115394 [==============================] - 0s 0us/step

Read the data

First, look at the text:

# Read, then decode for py2 compat.
text = open(path_to_file, 'rb').read().decode(encoding='utf-8')
# length of text is the number of characters in it
print(f'Length of text: {len(text)} characters')
Length of text: 1115394 characters
# Take a look at the first 250 characters in text
print(text[:250])
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.
# The unique characters in the file
vocab = sorted(set(text))
print(f'{len(vocab)} unique characters')
65 unique characters

Process the text

Vectorize the text

Before training, you need to convert the strings to a numerical representation.

The preprocessing.StringLookup layer can convert each character into a numeric ID. It just needs the text to be split into tokens first.

example_texts = ['abcdefg', 'xyz']

chars = tf.strings.unicode_split(example_texts, input_encoding='UTF-8')
chars
<tf.RaggedTensor [[b'a', b'b', b'c', b'd', b'e', b'f', b'g'], [b'x', b'y', b'z']]>

Now create the preprocessing.StringLookup layer:

ids_from_chars = preprocessing.StringLookup(
    vocabulary=list(vocab), mask_token=None)

It converts from tokens to character IDs:

ids = ids_from_chars(chars)
ids
<tf.RaggedTensor [[40, 41, 42, 43, 44, 45, 46], [63, 64, 65]]>

Since the goal of this tutorial is to generate text, it will also be important to invert this representation and recover human-readable strings from it. For this you can use preprocessing.StringLookup(..., invert=True).

chars_from_ids = tf.keras.layers.experimental.preprocessing.StringLookup(
    vocabulary=ids_from_chars.get_vocabulary(), invert=True, mask_token=None)

This layer recovers the characters from the vectors of IDs, and returns them as a tf.RaggedTensor of characters:

chars = chars_from_ids(ids)
chars
<tf.RaggedTensor [[b'a', b'b', b'c', b'd', b'e', b'f', b'g'], [b'x', b'y', b'z']]>

You can use tf.strings.reduce_join to join the characters back into strings.

tf.strings.reduce_join(chars, axis=-1).numpy()
array([b'abcdefg', b'xyz'], dtype=object)
def text_from_ids(ids):
  return tf.strings.reduce_join(chars_from_ids(ids), axis=-1)
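As a quick sanity check, text_from_ids should invert ids_from_chars and recover the example strings (a minimal sketch reusing the ids computed above):

# Round-trip check: IDs back to the original example strings.
print(text_from_ids(ids).numpy())  # array([b'abcdefg', b'xyz'], dtype=object)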

The prediction task

Given a character, or a sequence of characters, what is the most probable next character? This is the task you're training the model to perform. The input to the model will be a sequence of characters, and you train the model to predict the output—the following character at each time step.

Since RNNs maintain an internal state that depends on the previously seen elements, the question becomes: given all the characters seen so far, what is the next character?

Create training examples and targets

Next divide the text into example sequences. Each input sequence will contain seq_length characters from the text.

For each input sequence, the corresponding targets contain the same length of text, except shifted one character to the right.

So break the text into chunks of seq_length+1. For example, say seq_length is 4 and our text is "Hello". The input sequence would be "Hell", and the target sequence "ello".
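As a plain-Python sketch of that "Hello" example (the actual pipeline below uses tf.data):

chunk = list("Hello")                          # seq_length + 1 = 5 characters
input_seq, target_seq = chunk[:-1], chunk[1:]
print(input_seq)    # ['H', 'e', 'l', 'l']
print(target_seq)   # ['e', 'l', 'l', 'o']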

To do this first use the tf.data.Dataset.from_tensor_slices function to convert the text vector into a stream of character indices.

all_ids = ids_from_chars(tf.strings.unicode_split(text, 'UTF-8'))
all_ids
<tf.Tensor: shape=(1115394,), dtype=int64, numpy=array([19, 48, 57, ..., 46,  9,  1])>
ids_dataset = tf.data.Dataset.from_tensor_slices(all_ids)
for ids in ids_dataset.take(10):
    print(chars_from_ids(ids).numpy().decode('utf-8'))
F
i
r
s
t
 
C
i
t
i
seq_length = 100
examples_per_epoch = len(text)//(seq_length+1)

The batch method lets you easily convert these individual characters to sequences of the desired size.

sequences = ids_dataset.batch(seq_length+1, drop_remainder=True)

for seq in sequences.take(1):
  print(chars_from_ids(seq))
tf.Tensor(
[b'F' b'i' b'r' b's' b't' b' ' b'C' b'i' b't' b'i' b'z' b'e' b'n' b':'
 b'\n' b'B' b'e' b'f' b'o' b'r' b'e' b' ' b'w' b'e' b' ' b'p' b'r' b'o'
 b'c' b'e' b'e' b'd' b' ' b'a' b'n' b'y' b' ' b'f' b'u' b'r' b't' b'h'
 b'e' b'r' b',' b' ' b'h' b'e' b'a' b'r' b' ' b'm' b'e' b' ' b's' b'p'
 b'e' b'a' b'k' b'.' b'\n' b'\n' b'A' b'l' b'l' b':' b'\n' b'S' b'p' b'e'
 b'a' b'k' b',' b' ' b's' b'p' b'e' b'a' b'k' b'.' b'\n' b'\n' b'F' b'i'
 b'r' b's' b't' b' ' b'C' b'i' b't' b'i' b'z' b'e' b'n' b':' b'\n' b'Y'
 b'o' b'u' b' '], shape=(101,), dtype=string)

It's easier to see what this is doing if you join the tokens back into strings:

for seq in sequences.take(5):
  print(text_from_ids(seq).numpy())
b'First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou '
b'are all resolved rather to die than to famish?\n\nAll:\nResolved. resolved.\n\nFirst Citizen:\nFirst, you k'
b"now Caius Marcius is chief enemy to the people.\n\nAll:\nWe know't, we know't.\n\nFirst Citizen:\nLet us ki"
b"ll him, and we'll have corn at our own price.\nIs't a verdict?\n\nAll:\nNo more talking on't; let it be d"
b'one: away, away!\n\nSecond Citizen:\nOne word, good citizens.\n\nFirst Citizen:\nWe are accounted poor citi'

For training you'll need a dataset of (input, label) pairs, where both input and label are sequences. At each time step the input is the current character and the label is the next character.

Here's a function that takes a sequence as input, duplicates it, and shifts it to align the input and label for each timestep:

def split_input_target(sequence):
    input_text = sequence[:-1]
    target_text = sequence[1:]
    return input_text, target_text
split_input_target(list("Tensorflow"))
(['T', 'e', 'n', 's', 'o', 'r', 'f', 'l', 'o'],
 ['e', 'n', 's', 'o', 'r', 'f', 'l', 'o', 'w'])
dataset = sequences.map(split_input_target)
for input_example, target_example in dataset.take(1):
    print("Input :", text_from_ids(input_example).numpy())
    print("Target:", text_from_ids(target_example).numpy())
Input : b'First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou'
Target: b'irst Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou '

Create training batches

You used tf.data to split the text into manageable sequences. But before feeding this data into the model, you need to shuffle the data and pack it into batches.

# Batch size
BATCH_SIZE = 64

# Buffer size to shuffle the dataset
# (TF data is designed to work with possibly infinite sequences,
# so it doesn't attempt to shuffle the entire sequence in memory. Instead,
# it maintains a buffer in which it shuffles elements).
BUFFER_SIZE = 10000

dataset = (
    dataset
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(tf.data.experimental.AUTOTUNE))

dataset
<PrefetchDataset shapes: ((64, 100), (64, 100)), types: (tf.int64, tf.int64)>

Build The Model

This section defines the model as a keras.Model subclass (For details see Making new Layers and Models via subclassing).

This model has three layers:

  • tf.keras.layers.Embedding: The input layer. A trainable lookup table that will map each character-ID to a vector with embedding_dim dimensions;
  • tf.keras.layers.GRU: A type of RNN with size units=rnn_units (You can also use an LSTM layer here.)
  • tf.keras.layers.Dense: The output layer, with vocab_size outputs. It outputs one logit for each character in the vocabulary. These are the log-likelihood of each character according to the model.
# Length of the vocabulary in chars
vocab_size = len(vocab)

# The embedding dimension
embedding_dim = 256

# Number of RNN units
rnn_units = 1024
class MyModel(tf.keras.Model):
  def __init__(self, vocab_size, embedding_dim, rnn_units):
    super().__init__()
    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
    self.gru = tf.keras.layers.GRU(rnn_units,
                                   return_sequences=True,
                                   return_state=True)
    self.dense = tf.keras.layers.Dense(vocab_size)

  def call(self, inputs, states=None, return_state=False, training=False):
    x = inputs
    x = self.embedding(x, training=training)
    if states is None:
      states = self.gru.get_initial_state(x)
    x, states = self.gru(x, initial_state=states, training=training)
    x = self.dense(x, training=training)

    if return_state:
      return x, states
    else:
      return x
model = MyModel(
    # Be sure the vocabulary size matches the `StringLookup` layers.
    vocab_size=len(ids_from_chars.get_vocabulary()),
    embedding_dim=embedding_dim,
    rnn_units=rnn_units)

For each character the model looks up the embedding, runs the GRU one timestep with the embedding as input, and applies the dense layer to generate logits predicting the log-likelihood of the next character:

A drawing of the data passing through the model
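To make this concrete, here's a minimal sketch (the names sample_ids, x, state, and logits are just for illustration) that traces the shapes through each layer for one short input string:

sample_ids = ids_from_chars(tf.strings.unicode_split(['Shakespear'], 'UTF-8')).to_tensor()
x = model.embedding(sample_ids)   # (1, 10, embedding_dim)
x, state = model.gru(x)           # (1, 10, rnn_units) and (1, rnn_units)
logits = model.dense(x)           # (1, 10, vocab_size)
print(sample_ids.shape, x.shape, logits.shape)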

Try the model

Now run the model to see that it behaves as expected.

First check the shape of the output:

for input_example_batch, target_example_batch in dataset.take(1):
    example_batch_predictions = model(input_example_batch)
    print(example_batch_predictions.shape, "# (batch_size, sequence_length, vocab_size)")
(64, 100, 66) # (batch_size, sequence_length, vocab_size)

In the above example the sequence length of the input is 100 but the model can be run on inputs of any length:

model.summary()
Model: "my_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
embedding (Embedding)        multiple                  16896     
_________________________________________________________________
gru (GRU)                    multiple                  3938304   
_________________________________________________________________
dense (Dense)                multiple                  67650     
=================================================================
Total params: 4,022,850
Trainable params: 4,022,850
Non-trainable params: 0
_________________________________________________________________
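To confirm that the model really does accept inputs of any length, here's a quick check (a sketch reusing input_example_batch from the cell above) on a shorter slice:

short_batch = input_example_batch[:, :10]   # same batch, only 10 characters per sequence
print(model(short_batch).shape)             # (64, 10, 66)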

To get actual predictions from the model you need to sample from the output distribution to get character indices. This distribution is defined by the logits over the character vocabulary.

Try it for the first example in the batch:

sampled_indices = tf.random.categorical(example_batch_predictions[0], num_samples=1)
sampled_indices = tf.squeeze(sampled_indices, axis=-1).numpy()

This gives us, at each timestep, a prediction of the next character index:

sampled_indices
array([60,  4, 65, 42, 39, 24, 16, 50, 49,  8, 10, 26, 53,  5,  2, 64, 42,
       23,  1, 41, 50, 42, 57, 30, 40, 51, 45, 63, 48, 38, 51, 51, 32, 58,
       29, 61, 58, 33,  2, 19, 17, 61, 39, 21, 44,  8, 40, 24, 60, 64, 60,
       63,  6, 13, 51,  2, 51, 23, 41, 37, 61, 27, 33,  5, 24, 64, 11, 13,
       55, 65,  1, 19,  0, 15, 29,  7, 56, 57, 36, 30,  8, 38, 50, 33,  0,
       41, 42,  0, 37, 36, 29, 16, 13, 56, 64, 65, 17, 27, 17, 28])

Decode these to see the text predicted by this untrained model:

print("Input:\n", text_from_ids(input_example_batch[0]).numpy())
print()
print("Next Char Predictions:\n", text_from_ids(sampled_indices).numpy())
Input:
 b"INIUS:\nHe's sentenced; no more hearing.\n\nCOMINIUS:\nLet me speak:\nI have been consul, and can show fo"

Next Char Predictions:
 b"u$zcZKCkj-3Mn& ycJ\nbkcrQalfxiYllSsPvsT FDvZHe-aKuyux'?l lJbXvNT&Ky:?pz\nF[UNK]BP,qrWQ-YkT[UNK]bc[UNK]XWPC?qyzDNDO"

Train the model

At this point the problem can be treated as a standard classification problem. Given the previous RNN state, and the input this time step, predict the class of the next character.

Attach an optimizer, and a loss function

The standard tf.keras.losses.sparse_categorical_crossentropy loss function works in this case because it is applied across the last dimension of the predictions.

Because your model returns logits, you need to set the from_logits flag.

loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True)
example_batch_loss = loss(target_example_batch, example_batch_predictions)
mean_loss = example_batch_loss.numpy().mean()
print("Prediction shape: ", example_batch_predictions.shape, " # (batch_size, sequence_length, vocab_size)")
print("Mean loss:        ", mean_loss)
Prediction shape:  (64, 100, 66)  # (batch_size, sequence_length, vocab_size)
Mean loss:         4.1912417

A newly initialized model shouldn't be too sure of itself; the output logits should all have similar magnitudes. To confirm this you can check that the exponential of the mean loss is approximately equal to the vocabulary size. A much higher loss means the model is sure of its wrong answers, and is badly initialized:

tf.exp(mean_loss).numpy()
66.10483

Configure the training procedure using the tf.keras.Model.compile method. Use tf.keras.optimizers.Adam with default arguments and the loss function.

model.compile(optimizer='adam', loss=loss)

Configure checkpoints

Use a tf.keras.callbacks.ModelCheckpoint to ensure that checkpoints are saved during training:

# Directory where the checkpoints will be saved
checkpoint_dir = './training_checkpoints'
# Name of the checkpoint files
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")

checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_prefix,
    save_weights_only=True)

Execute the training

To keep training time reasonable, use 20 epochs to train the model. In Colab, set the runtime to GPU for faster training.

EPOCHS = 20
history = model.fit(dataset, epochs=EPOCHS, callbacks=[checkpoint_callback])
Epoch 1/20
172/172 [==============================] - 6s 25ms/step - loss: 2.7484
Epoch 2/20
172/172 [==============================] - 5s 25ms/step - loss: 2.0113
Epoch 3/20
172/172 [==============================] - 5s 25ms/step - loss: 1.7361
Epoch 4/20
172/172 [==============================] - 5s 25ms/step - loss: 1.5733
Epoch 5/20
172/172 [==============================] - 5s 25ms/step - loss: 1.4716
Epoch 6/20
172/172 [==============================] - 5s 25ms/step - loss: 1.4012
Epoch 7/20
172/172 [==============================] - 5s 25ms/step - loss: 1.3479
Epoch 8/20
172/172 [==============================] - 5s 25ms/step - loss: 1.3025
Epoch 9/20
172/172 [==============================] - 5s 25ms/step - loss: 1.2632
Epoch 10/20
172/172 [==============================] - 5s 25ms/step - loss: 1.2257
Epoch 11/20
172/172 [==============================] - 5s 25ms/step - loss: 1.1876
Epoch 12/20
172/172 [==============================] - 5s 25ms/step - loss: 1.1495
Epoch 13/20
172/172 [==============================] - 5s 25ms/step - loss: 1.1101
Epoch 14/20
172/172 [==============================] - 5s 25ms/step - loss: 1.0685
Epoch 15/20
172/172 [==============================] - 5s 25ms/step - loss: 1.0240
Epoch 16/20
172/172 [==============================] - 5s 25ms/step - loss: 0.9762
Epoch 17/20
172/172 [==============================] - 5s 25ms/step - loss: 0.9275
Epoch 18/20
172/172 [==============================] - 5s 25ms/step - loss: 0.8771
Epoch 19/20
172/172 [==============================] - 5s 25ms/step - loss: 0.8254
Epoch 20/20
172/172 [==============================] - 5s 25ms/step - loss: 0.7756

Generate text

The simplest way to generate text with this model is to run it in a loop, and keep track of the model's internal state as you execute it.

To generate text the model's output is fed back to the input

Each time you call the model you pass in some text and an internal state. The model returns a prediction for the next character and its new state. Pass the prediction and state back in to continue generating text.

The following makes a single-step prediction:

class OneStep(tf.keras.Model):
  def __init__(self, model, chars_from_ids, ids_from_chars, temperature=1.0):
    super().__init__()
    self.temperature = temperature
    self.model = model
    self.chars_from_ids = chars_from_ids
    self.ids_from_chars = ids_from_chars

    # Create a mask to prevent "[UNK]" from being generated.
    skip_ids = self.ids_from_chars(['[UNK]'])[:, None]
    sparse_mask = tf.SparseTensor(
        # Put a -inf at each bad index.
        values=[-float('inf')]*len(skip_ids),
        indices=skip_ids,
        # Match the shape to the vocabulary
        dense_shape=[len(ids_from_chars.get_vocabulary())])
    self.prediction_mask = tf.sparse.to_dense(sparse_mask)

  @tf.function
  def generate_one_step(self, inputs, states=None):
    # Convert strings to token IDs.
    input_chars = tf.strings.unicode_split(inputs, 'UTF-8')
    input_ids = self.ids_from_chars(input_chars).to_tensor()

    # Run the model.
    # predicted_logits.shape is [batch, char, next_char_logits]
    predicted_logits, states = self.model(inputs=input_ids, states=states,
                                          return_state=True)
    # Only use the last prediction.
    predicted_logits = predicted_logits[:, -1, :]
    predicted_logits = predicted_logits/self.temperature
    # Apply the prediction mask: prevent "[UNK]" from being generated.
    predicted_logits = predicted_logits + self.prediction_mask

    # Sample the output logits to generate token IDs.
    predicted_ids = tf.random.categorical(predicted_logits, num_samples=1)
    predicted_ids = tf.squeeze(predicted_ids, axis=-1)

    # Convert from token ids to characters
    predicted_chars = self.chars_from_ids(predicted_ids)

    # Return the characters and model state.
    return predicted_chars, states
one_step_model = OneStep(model, chars_from_ids, ids_from_chars)

Run it in a loop to generate some text. Looking at the generated text, you'll see the model knows when to capitalize, makes paragraphs, and imitates a Shakespeare-like vocabulary. With the small number of training epochs, it has not yet learned to form coherent sentences.

start = time.time()
states = None
next_char = tf.constant(['ROMEO:'])
result = [next_char]

for n in range(1000):
  next_char, states = one_step_model.generate_one_step(next_char, states=states)
  result.append(next_char)

result = tf.strings.join(result)
end = time.time()
print(result[0].numpy().decode('utf-8'), '\n\n' + '_'*80)
print('\nRun time:', end - start)
ROMEO:
To whine First, Green, we shall go above
To sweet cossession; peace, you fortune's corse,
As is the bridegroom in the bosom of the
house, Lord Rober as he, on the wreck;
Therefore, all our souls by tume, when the old fige?
Father, become of George, fortune prophecy;
For, as thou hast warm'd, led bying how himself
concluded into the Tower into Laten and unsaside.
And by any hand that I, Jove forfoods.

KING RICHARD II:
Right tale: it is my man, wouldst remember who
Though Montague our than his pride.

OXFORD:
Come, come, when thou wert commander,
How thou mayness will be committed to a moleholds
me down and show your cancell'd,
It were to ussurage itself and so well:
Like to a longer death, make peers not himself
But sad you rascal. But 'twas a father
I should be gone affords: for thou will stay
fully.

First Servingman:
When, I both, and with it doth my oath,
I know him do marry Bolingbroke:
Why, she cames nightly of a tear.

POLIXENES:
Whiles I were it.

MERCUTIO:
Ay, but your whoras 

________________________________________________________________________________

Run time: 2.2399308681488037

The easiest thing you can do to improve the results is to train it for longer (try EPOCHS = 30).

You can also experiment with a different start string, try adding another RNN layer to improve the model's accuracy, or adjust the temperature parameter to generate more or less random predictions.
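For example, a lower temperature makes the sampling more conservative, while a higher one makes it more random. Here's a minimal sketch reusing the OneStep class defined above (temperature=0.5 is just one value to try):

low_temp_model = OneStep(model, chars_from_ids, ids_from_chars, temperature=0.5)

states = None
next_char = tf.constant(['ROMEO:'])
result = [next_char]

for n in range(200):
  next_char, states = low_temp_model.generate_one_step(next_char, states=states)
  result.append(next_char)

print(tf.strings.join(result)[0].numpy().decode('utf-8'))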

If you want the model to generate text faster, the easiest thing you can do is batch the text generation. In the example below the model generates 5 outputs in about the same time it took to generate 1 above.

start = time.time()
states = None
next_char = tf.constant(['ROMEO:', 'ROMEO:', 'ROMEO:', 'ROMEO:', 'ROMEO:'])
result = [next_char]

for n in range(1000):
  next_char, states = one_step_model.generate_one_step(next_char, states=states)
  result.append(next_char)

result = tf.strings.join(result)
end = time.time()
print(result, '\n\n' + '_'*80)
print('\nRun time:', end - start)
tf.Tensor(
[b"ROMEO:\nLet me say get there; but I dare done:\nFather, I beseech your grace, look up:\nThis dight I to lave from Axe,\nMay be tistal like an example, though before no other\nAn honourable carrion. Aufidius,\nWhich times my leviet, fate-water than stedlers,\nOf thy bending treasure is from succome.\nFear itself is father is farewell compliment\nEngland in which should be more grate.\nHark ye, so! There is no wooirly sail!\nHe was this cannot be. A pathy Varcius!\n\nCORIOLANUS:\nI dare not Catesby:\nThy father should be thought a thing is held!\n\nHORTENSIO:\nSaw doth he came aside his face stolen.\n3 KING HENRY VI\n\nFLORDINA:\nKing of Hermione,\nHeating this move of this place, which, still wind,\nBut to't, what with proud steed, was Corioli, I'll bring him speak.\nThis is the lead of all, and give it for the mind:\nHow please you, westable with your bremish,\nI nemed intended to cry, acquaint:\nwhen he goes, who have strong'd them, but they\ncan we rejouce thee more; and you were sedious,\nThou camest, my knows be! Whe"
 b"ROMEO:\nThis is Padis to hear a penitent toad:\nO, let me weep, and pardon, to bring thee to\nTen thousand chances with rain others,\nAnd bold as 'twere retemm'd into our bitts.\n\nKING RICHARD III:\nO sir, you are not angry thing! at his reputa\nMust die, my dear lord Edward's death!\n\nJULIET:\nThis shall desire to live or drunkes and\nto see it instantly to prison, and send tide to off.\nBut, what, so it hath need of twenty\nWere reprieve, to wapet all in loughs;\nHidest what of that which stock her pleaser,\nMade odds for bads to make a penpeture of this last,\nNo doubt not adopter-time, when an a\nhand ransom'd hands\nThat thus extremes you honestion: never\nsent fortunes for our tuble, and my soul to tar;\nFive ranks with coming sooth, as if the world,\nThou cousin's members: Richard, dear, do\nFlatter weep and see thee gone,\nAnd hath the dishonour'd informat of our kind.\n\nROMEO:\nUnhappy, visage, we will recreating than we fear.\n\nROMEO:\nThou canst not speak. Good motrow, house! Very come.\n\nHORTENSIO:\nThan da"
 b"ROMEO:\nStay awhile, Aumer, and whip: here's at land as any heaven?\n\nFirst Servant:\nHe's deceiven, O, welcome not;\nAnd so with sighs; the nobility\nof yok'd friend.'\nYou, sir, be mad--quativike.\n\nFirst Herand:\nHadst thou no cruel, to say the street'st\nThou hadst auded leasant wayde of hell\nAnd raise his name and make his seas and queen.\n\nISABELLA:\nO, good sir, she blessed with she.\n\nGRUMIO:\nAnd he shall lead forth any glory of mine eyes.\nIndeed, I do bend my smades and Jack; I am the city\nInterditation, safety from his mon;\nFor she rubs and how sent you: our ignorances\nAre flower. My peace will not shame; forswear it?\n\nPOMPEY:\nHarp! you here shall purchase of this state\nMades impossible how you shall.\n\nDUKE VINCENTIO:\nBy heaven, but thankful you were not at maid\nWhich we are mean-born back.\n\nLUCIO:\nSo shall you goes myself? O, let me hate,\nJive bonderous worships: here's the devil's join?\n\nFLORIZEL:\nMook and repent?\nIt strange inwered with the bastardy and the\ndaintings not to curse thine ear."
 b"ROMEO:\nThis state astent of supply stread'st,\nI wish you from disgrave to lay men innocent.\nHere in Vienna truly, by your company.\n\nHORTENSIO:\nI have said whose several coals.\n\nQUEEN ELIZABETH:\nBring your springs blest those that are word,\nif they have done, too fair, to dark to live.\n\nLEONTES:\nIt is sure as this?\n\nFirst Murderer:\nOf this extreme credit, tempted menbely!\n\nGREMIO:\nThe covering swells and makes himself wherein ever\nIn thy tongue bless our tender babes:\nAll payard, and Edward studied with wenches,\nBrows aids, I would not have you against your place.\nFather; yet my mother royal dyess\nOrder his project; but she shall be more\nTo stand these service, were they pined,--this is\nthe she's and unhappy, Henry, did Richard hearts\nOf all the rest: being true Rutland; but at last\nIn warrants not that shall do and child.\n\nWARWICK:\nNone: but to London will they are all frail.\n\nThird Gentleman:\nWhy to the poxing heart she?\n\nFirst Keeper:\nI' food somewhat is your grace command?\nWhat is it fool"
 b"ROMEO:\nThey shall not need comes my son Exeding here,\nAll shares all for aud: caption; can I not, I,\nThis 'leguner, how it is but clothes; but\nI mean our prison; 'twas tyrant brings\nLike shades and me to put in home, seen so\nno bound upon your happy hands into my head.\nIf; be gone in 'Coriolanus!\n\nMOPSA:\nIt shall be your would proceed:\nTools and fadians mann weep and singural fores\nTo you ado.\n\nFirst Soldier:\nWe promised well, Peter's Church his honour!\nI am made in safety.\n\nJULIET:\nAlmost helping then, among the Engelance;\nTo see it in the sky, but induced by great monumness,\nAnd begin himself and any fall to be to-morrow\nOr Aufidius, which hath to happy by his head;\nFor Edsurp at that e'er I cannot puep,\nWhom God hath sworn thanks and pluck'd them then;\nAnd what to fall, forthwith with a little worm a\nGain? that ever we come hither of our friends\nEdward more slight stolight: diad-mangeried by the whole\nVoucheth inhocestry and when the heels;\nTwelf, Sainol, Lucio, dost thou kneel'd? O heavy"], shape=(5,), dtype=string) 

________________________________________________________________________________

Run time: 2.115058660507202

Export the generator

This single-step model can easily be saved and restored, allowing you to use it anywhere a tf.saved_model is accepted.

tf.saved_model.save(one_step_model, 'one_step')
one_step_reloaded = tf.saved_model.load('one_step')
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.OneStep object at 0x7f12d45c8950>, because it is not built.
WARNING:absl:Found untraced functions such as gru_cell_layer_call_fn, gru_cell_layer_call_and_return_conditional_losses, gru_cell_layer_call_fn, gru_cell_layer_call_and_return_conditional_losses, gru_cell_layer_call_and_return_conditional_losses while saving (showing 5 of 5). These functions will not be directly callable after loading.
WARNING:tensorflow:FOR KERAS USERS: The object that you are saving contains one or more Keras models or layers. If you are loading the SavedModel with `tf.keras.models.load_model`, continue reading (otherwise, you may ignore the following instructions). Please change your code to save with `tf.keras.models.save_model` or `model.save`, and confirm that the file "keras.metadata" exists in the export directory. In the future, Keras will only load the SavedModels that have this file. In other words, `tf.saved_model.save` will no longer write SavedModels that can be recovered as Keras models (this will apply in TF 2.5).

FOR DEVS: If you are overwriting _tracking_metadata in your class, this property has been used to save metadata in the SavedModel. The metadta field will be deprecated soon, so please move the metadata to a different file.
INFO:tensorflow:Assets written to: one_step/assets
states = None
next_char = tf.constant(['ROMEO:'])
result = [next_char]

for n in range(100):
  next_char, states = one_step_reloaded.generate_one_step(next_char, states=states)
  result.append(next_char)

print(tf.strings.join(result)[0].numpy().decode("utf-8"))
ROMEO:
Thus rid live form: then put were no lesser, young
Have been sour than she's enough, you're werew a

Advanced: Customized Training

The above training procedure is simple, but does not give you much control. It uses teacher forcing, which prevents bad predictions from being fed back to the model, so the model never learns to recover from mistakes.

Now that you've seen how to run the model manually, next you'll implement the training loop. This gives a starting point if, for example, you want to implement curriculum learning to help stabilize the model's open-loop output.

The most important part of a custom training loop is the train step function.

Use tf.GradientTape to track the gradients. You can learn more about this approach by reading the eager execution guide.

The basic procedure is:

  1. Execute the model and calculate the loss under a tf.GradientTape.
  2. Calculate the updates and apply them to the model using the optimizer.
class CustomTraining(MyModel):
  @tf.function
  def train_step(self, inputs):
      inputs, labels = inputs
      with tf.GradientTape() as tape:
          predictions = self(inputs, training=True)
          loss = self.loss(labels, predictions)
      grads = tape.gradient(loss, self.trainable_variables)
      self.optimizer.apply_gradients(zip(grads, self.trainable_variables))

      return {'loss': loss}

The above implementation of the train_step method follows Keras' train_step conventions. This is optional, but it allows you to change the behavior of the train step and still use Keras' Model.compile and Model.fit methods.

model = CustomTraining(
    vocab_size=len(ids_from_chars.get_vocabulary()),
    embedding_dim=embedding_dim,
    rnn_units=rnn_units)
model.compile(optimizer = tf.keras.optimizers.Adam(),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
model.fit(dataset, epochs=1)
172/172 [==============================] - 7s 25ms/step - loss: 2.7120
<tensorflow.python.keras.callbacks.History at 0x7f12c6fb1b10>

Or if you need more control, you can write your own complete custom training loop:

EPOCHS = 10

mean = tf.metrics.Mean()

for epoch in range(EPOCHS):
    start = time.time()

    mean.reset_states()
    for (batch_n, (inp, target)) in enumerate(dataset):
        logs = model.train_step([inp, target])
        mean.update_state(logs['loss'])

        if batch_n % 50 == 0:
            template = f"Epoch {epoch+1} Batch {batch_n} Loss {logs['loss']:.4f}"
            print(template)

    # saving (checkpoint) the model every 5 epochs
    if (epoch + 1) % 5 == 0:
        model.save_weights(checkpoint_prefix.format(epoch=epoch))

    print()
    print(f'Epoch {epoch+1} Loss: {mean.result().numpy():.4f}')
    print(f'Time taken for 1 epoch {time.time() - start:.2f} sec')
    print("_"*80)

model.save_weights(checkpoint_prefix.format(epoch=epoch))
Epoch 1 Batch 0 Loss 2.2023
Epoch 1 Batch 50 Loss 2.0291
Epoch 1 Batch 100 Loss 1.9347
Epoch 1 Batch 150 Loss 1.8428

Epoch 1 Loss: 1.9860
Time taken for 1 epoch 5.76 sec
________________________________________________________________________________
Epoch 2 Batch 0 Loss 1.8406
Epoch 2 Batch 50 Loss 1.7190
Epoch 2 Batch 100 Loss 1.6493
Epoch 2 Batch 150 Loss 1.6515

Epoch 2 Loss: 1.7098
Time taken for 1 epoch 5.24 sec
________________________________________________________________________________
Epoch 3 Batch 0 Loss 1.6022
Epoch 3 Batch 50 Loss 1.5800
Epoch 3 Batch 100 Loss 1.5578
Epoch 3 Batch 150 Loss 1.5079

Epoch 3 Loss: 1.5490
Time taken for 1 epoch 5.22 sec
________________________________________________________________________________
Epoch 4 Batch 0 Loss 1.4763
Epoch 4 Batch 50 Loss 1.3982
Epoch 4 Batch 100 Loss 1.4232
Epoch 4 Batch 150 Loss 1.3946

Epoch 4 Loss: 1.4516
Time taken for 1 epoch 5.26 sec
________________________________________________________________________________
Epoch 5 Batch 0 Loss 1.3862
Epoch 5 Batch 50 Loss 1.3500
Epoch 5 Batch 100 Loss 1.4086
Epoch 5 Batch 150 Loss 1.3583

Epoch 5 Loss: 1.3848
Time taken for 1 epoch 5.46 sec
________________________________________________________________________________
Epoch 6 Batch 0 Loss 1.3203
Epoch 6 Batch 50 Loss 1.3539
Epoch 6 Batch 100 Loss 1.3335
Epoch 6 Batch 150 Loss 1.3703

Epoch 6 Loss: 1.3331
Time taken for 1 epoch 5.33 sec
________________________________________________________________________________
Epoch 7 Batch 0 Loss 1.2912
Epoch 7 Batch 50 Loss 1.2518
Epoch 7 Batch 100 Loss 1.3340
Epoch 7 Batch 150 Loss 1.3126

Epoch 7 Loss: 1.2878
Time taken for 1 epoch 5.26 sec
________________________________________________________________________________
Epoch 8 Batch 0 Loss 1.2496
Epoch 8 Batch 50 Loss 1.1985
Epoch 8 Batch 100 Loss 1.2753
Epoch 8 Batch 150 Loss 1.2512

Epoch 8 Loss: 1.2466
Time taken for 1 epoch 5.32 sec
________________________________________________________________________________
Epoch 9 Batch 0 Loss 1.1532
Epoch 9 Batch 50 Loss 1.2398
Epoch 9 Batch 100 Loss 1.1850
Epoch 9 Batch 150 Loss 1.2001

Epoch 9 Loss: 1.2064
Time taken for 1 epoch 5.30 sec
________________________________________________________________________________
Epoch 10 Batch 0 Loss 1.1392
Epoch 10 Batch 50 Loss 1.1360
Epoch 10 Batch 100 Loss 1.1426
Epoch 10 Batch 150 Loss 1.1648

Epoch 10 Loss: 1.1669
Time taken for 1 epoch 5.52 sec
________________________________________________________________________________