This tutorial demonstrates how to generate text using a character-based RNN. You will work with a dataset of Shakespeare's writing from Andrej Karpathy's The Unreasonable Effectiveness of Recurrent Neural Networks. Given a sequence of characters from this data ("Shakespear"), train a model to predict the next character in the sequence ("e"). Longer sequences of text can be generated by calling the model repeatedly.
This tutorial includes runnable code implemented using tf.keras and eager execution. The following is sample output from the model in this tutorial after it was trained for 30 epochs and started with the prompt "Q":
QUEENE: I had thought thou hadst a Roman; for the oracle, Thus by All bids the man against the word, Which are so weak of care, by old care done; Your children were in your holy love, And the precipitation through the bleeding throne. BISHOP OF ELY: Marry, and will, my lord, to weep in such a one were prettiest; Yet now I was adopted heir Of the world's lamentable day, To watch the next way with his father with his face? ESCALUS: The cause why then we are all resolved more sons. VOLUMNIA: O, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, it is no sin it should be dead, And love and pale as any will to that word. QUEEN ELIZABETH: But how long have I heard the soul for this world, And show his hands of life be proved to stand. PETRUCHIO: I say he look'd on, if I must be content To stay him from the fatal of our country's bliss. His lordship pluck'd from this sentence then for prey, And then let us twain, being the moon, were she such a case as fills m
While some of the sentences are grammatical, most do not make sense. The model has not learned the meaning of words, but consider:
The model is character-based. When training started, the model did not know how to spell an English word, or that words were even a unit of text.
The structure of the output resembles a play: blocks of text generally begin with a speaker name in all capital letters, similar to the dataset.
As demonstrated below, the model is trained on small batches of text (100 characters each), and is still able to generate a longer sequence of text with coherent structure.
Setup
Import TensorFlow and other libraries
import tensorflow as tf
from tensorflow.keras.layers.experimental import preprocessing
import numpy as np
import os
import time
Download the Shakespeare dataset
Change the following line to run this code on your own data.
path_to_file = tf.keras.utils.get_file('shakespeare.txt', 'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt')
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt 1122304/1115394 [==============================] - 0s 0us/step
Read the data
First, take a look at the text:
# Read, then decode for py2 compat.
text = open(path_to_file, 'rb').read().decode(encoding='utf-8')
# length of text is the number of characters in it
print('Length of text: {} characters'.format(len(text)))
Length of text: 1115394 characters
# Take a look at the first 250 characters in text
print(text[:250])
First Citizen: Before we proceed any further, hear me speak. All: Speak, speak. First Citizen: You are all resolved rather to die than to famish? All: Resolved. resolved. First Citizen: First, you know Caius Marcius is chief enemy to the people.
# The unique characters in the file
vocab = sorted(set(text))
print('{} unique characters'.format(len(vocab)))
65 unique characters
Process the text
Vectorize the text
Before training, you need to convert the strings to a numerical representation.
The preprocessing.StringLookup layer can convert each character into a numeric ID. It just needs the text to be split into tokens first.
example_texts = ['abcdefg', 'xyz']
chars = tf.strings.unicode_split(example_texts, input_encoding='UTF-8')
chars
<tf.RaggedTensor [[b'a', b'b', b'c', b'd', b'e', b'f', b'g'], [b'x', b'y', b'z']]>
Now create the preprocessing.StringLookup layer:
ids_from_chars = preprocessing.StringLookup(
    vocabulary=list(vocab))
It converts from tokens to character IDs, padding with 0:
ids = ids_from_chars(chars)
ids
<tf.RaggedTensor [[41, 42, 43, 44, 45, 46, 47], [64, 65, 66]]>
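The IDs above are consistent with a lookup table that reserves ID 0 for the empty mask token "" and ID 1 for "[UNK]" (the same two tokens the OneStep class below masks out). Under that layout the 65 unique characters occupy IDs 2 to 66, which also explains why the model output later has 67 entries. The following pure-Python sketch mimics both directions of the lookup under that assumption. build_lookup, lookup_ids, and lookup_chars are hypothetical helpers, not TensorFlow APIs.

```python
# Hypothetical helpers mimicking the character <-> ID tables built by StringLookup,
# ASSUMING ID 0 = "" (mask token) and ID 1 = "[UNK]" (out-of-vocabulary token).
def build_lookup(vocab):
    table = ["", "[UNK]"] + sorted(vocab)
    char_to_id = {ch: i for i, ch in enumerate(table)}
    return table, char_to_id


def lookup_ids(chars, char_to_id):
    # Characters missing from the vocabulary fall back to the [UNK] ID (1).
    return [char_to_id.get(ch, 1) for ch in chars]


def lookup_chars(ids, table):
    # Inverse direction, like StringLookup(..., invert=True).
    return [table[i] for i in ids]


table, char_to_id = build_lookup(set("abcdefgxyz"))
ids = lookup_ids(list("abc?"), char_to_id)
print(ids)                                   # [2, 3, 4, 1]
print("".join(lookup_chars(ids, table)))     # abc[UNK]
```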
Since the goal of this tutorial is to generate text, it will also be important to invert this representation and recover human-readable strings from it. For this you can use preprocessing.StringLookup(..., invert=True).
chars_from_ids = tf.keras.layers.experimental.preprocessing.StringLookup(
    vocabulary=ids_from_chars.get_vocabulary(), invert=True)
This layer recovers the characters from the vectors of IDs, and returns them as a tf.RaggedTensor of characters:
chars = chars_from_ids(ids)
chars
<tf.RaggedTensor [[b'a', b'b', b'c', b'd', b'e', b'f', b'g'], [b'x', b'y', b'z']]>
You can use tf.strings.reduce_join to join the characters back into strings.
tf.strings.reduce_join(chars, axis=-1).numpy()
array([b'abcdefg', b'xyz'], dtype=object)
def text_from_ids(ids):
    return tf.strings.reduce_join(chars_from_ids(ids), axis=-1)
The prediction task
Given a character, or a sequence of characters, what is the most probable next character? This is the task you're training the model to perform. The input to the model will be a sequence of characters, and you train the model to predict the output—the following character at each time step.
Since RNNs maintain an internal state that depends on the previously seen elements, the question at each step becomes: given all the characters seen up to this moment, what is the next character?
Create training examples and targets
Next, divide the text into example sequences. Each input sequence will contain seq_length characters from the text.
For each input sequence, the corresponding targets contain the same length of text, except shifted one character to the right.
So break the text into chunks of seq_length+1. For example, say seq_length is 4 and the text is "Hello". The input sequence would be "Hell", and the target sequence "ello".
To do this, first use the tf.data.Dataset.from_tensor_slices function to convert the text vector into a stream of character indices.
all_ids = ids_from_chars(tf.strings.unicode_split(text, 'UTF-8'))
all_ids
<tf.Tensor: shape=(1115394,), dtype=int64, numpy=array([20, 49, 58, ..., 47, 10, 2])>
ids_dataset = tf.data.Dataset.from_tensor_slices(all_ids)
for ids in ids_dataset.take(10):
    print(chars_from_ids(ids).numpy().decode('utf-8'))
F i r s t C i t i
seq_length = 100
examples_per_epoch = len(text)//(seq_length+1)
The batch method lets you easily convert these individual characters to sequences of the desired size.
sequences = ids_dataset.batch(seq_length+1, drop_remainder=True)
for seq in sequences.take(1):
    print(chars_from_ids(seq))
tf.Tensor( [b'F' b'i' b'r' b's' b't' b' ' b'C' b'i' b't' b'i' b'z' b'e' b'n' b':' b'\n' b'B' b'e' b'f' b'o' b'r' b'e' b' ' b'w' b'e' b' ' b'p' b'r' b'o' b'c' b'e' b'e' b'd' b' ' b'a' b'n' b'y' b' ' b'f' b'u' b'r' b't' b'h' b'e' b'r' b',' b' ' b'h' b'e' b'a' b'r' b' ' b'm' b'e' b' ' b's' b'p' b'e' b'a' b'k' b'.' b'\n' b'\n' b'A' b'l' b'l' b':' b'\n' b'S' b'p' b'e' b'a' b'k' b',' b' ' b's' b'p' b'e' b'a' b'k' b'.' b'\n' b'\n' b'F' b'i' b'r' b's' b't' b' ' b'C' b'i' b't' b'i' b'z' b'e' b'n' b':' b'\n' b'Y' b'o' b'u' b' '], shape=(101,), dtype=string)
It's easier to see what this is doing if you join the tokens back into strings:
for seq in sequences.take(5):
    print(text_from_ids(seq).numpy())
b'First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou ' b'are all resolved rather to die than to famish?\n\nAll:\nResolved. resolved.\n\nFirst Citizen:\nFirst, you k' b"now Caius Marcius is chief enemy to the people.\n\nAll:\nWe know't, we know't.\n\nFirst Citizen:\nLet us ki" b"ll him, and we'll have corn at our own price.\nIs't a verdict?\n\nAll:\nNo more talking on't; let it be d" b'one: away, away!\n\nSecond Citizen:\nOne word, good citizens.\n\nFirst Citizen:\nWe are accounted poor citi'
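Because batch is called with drop_remainder=True here, and again later when sequences are grouped into batches of 64, incomplete pieces are discarded. That means the number of training steps per epoch can be worked out by hand. The result matches the 172/172 steps reported in the training log further down:

```python
text_length = 1115394      # "Length of text" printed earlier
seq_length = 100
batch_size = 64            # BATCH_SIZE used later in this tutorial

# Full sequences of seq_length + 1 characters (incomplete last chunk dropped).
examples_per_epoch = text_length // (seq_length + 1)
# Full batches of sequences (incomplete last batch dropped).
steps_per_epoch = examples_per_epoch // batch_size
# Characters at the end of the text that never make it into a sequence.
leftover_chars = text_length - examples_per_epoch * (seq_length + 1)

print(examples_per_epoch, steps_per_epoch, leftover_chars)  # 11043 172 51
```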
For training you'll need a dataset of (input, label) pairs, where input and label are sequences. At each time step the input is the current character and the label is the next character.
Here's a function that takes a sequence as input, duplicates it, and shifts it to align the input and label for each timestep:
def split_input_target(sequence):
    input_text = sequence[:-1]
    target_text = sequence[1:]
    return input_text, target_text
split_input_target(list("Tensorflow"))
(['T', 'e', 'n', 's', 'o', 'r', 'f', 'l', 'o'], ['e', 'n', 's', 'o', 'r', 'f', 'l', 'o', 'w'])
dataset = sequences.map(split_input_target)
for input_example, target_example in dataset.take(1):
    print("Input :", text_from_ids(input_example).numpy())
    print("Target:", text_from_ids(target_example).numpy())
Input : b'First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou' Target: b'irst Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou '
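The steps so far can be mirrored on a plain Python string: cut the text into chunks of seq_length+1 characters, drop any incomplete final chunk, then split each chunk with split_input_target. This only illustrates the logic and does not replace the tf.data pipeline. make_examples is a hypothetical helper.

```python
def split_input_target(sequence):
    # Input drops the last character; target drops the first.
    return sequence[:-1], sequence[1:]


def make_examples(text, seq_length):
    # Cut the text into non-overlapping chunks of seq_length + 1 characters,
    # dropping an incomplete final chunk (like batch(..., drop_remainder=True)).
    chunk = seq_length + 1
    n_chunks = len(text) // chunk
    chunks = [text[i * chunk:(i + 1) * chunk] for i in range(n_chunks)]
    return [split_input_target(c) for c in chunks]


print(make_examples("Hello", 4))         # [('Hell', 'ello')]
print(make_examples("Hello world", 4))   # [('Hell', 'ello'), (' wor', 'worl')]
```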
Create training batches
You used tf.data to split the text into manageable sequences. But before feeding this data into the model, you need to shuffle the data and pack it into batches.
# Batch size
BATCH_SIZE = 64
# Buffer size to shuffle the dataset
# (TF data is designed to work with possibly infinite sequences,
# so it doesn't attempt to shuffle the entire sequence in memory. Instead,
# it maintains a buffer in which it shuffles elements).
BUFFER_SIZE = 10000
dataset = (
dataset
.shuffle(BUFFER_SIZE)
.batch(BATCH_SIZE, drop_remainder=True)
.prefetch(tf.data.experimental.AUTOTUNE))
dataset
<PrefetchDataset shapes: ((64, 100), (64, 100)), types: (tf.int64, tf.int64)>
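The comments above describe a buffered shuffle. Below is a simplified pure-Python sketch of that idea. buffered_shuffle is a hypothetical name, and tf.data's real implementation differs in its details. The buffer used here (10,000) is smaller than the 11,043 sequences in the dataset, so the shuffle is close to, but not exactly, a full uniform shuffle.

```python
import random


def buffered_shuffle(items, buffer_size, seed=0):
    """Yield items in a partially shuffled order using a fixed-size buffer."""
    rng = random.Random(seed)
    buffer = []
    for item in items:
        buffer.append(item)
        if len(buffer) == buffer_size:
            # The buffer is full: emit one of its elements at random.
            yield buffer.pop(rng.randrange(len(buffer)))
    # Input exhausted: drain the remaining elements in random order.
    while buffer:
        yield buffer.pop(rng.randrange(len(buffer)))


data = list(range(20))
print(list(buffered_shuffle(data, buffer_size=1)))   # buffer of 1: original order
print(list(buffered_shuffle(data, buffer_size=5)))   # only locally shuffled
```

With a buffer of size 1 nothing moves. A larger buffer lets elements travel further from their original positions.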
Build the model
This section defines the model as a keras.Model subclass (for details, see Making new Layers and Models via subclassing).
This model has three layers:
- tf.keras.layers.Embedding: The input layer. A trainable lookup table that will map each character ID to a vector with embedding_dim dimensions.
- tf.keras.layers.GRU: A type of RNN with size units=rnn_units. (You can also use an LSTM layer here.)
- tf.keras.layers.Dense: The output layer, with vocab_size outputs. It outputs one logit for each character in the vocabulary. These are the (unnormalized) log-likelihoods of each character according to the model.
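Looking up a row of the embedding table is equivalent to multiplying a one-hot vector by a (vocab_size, embedding_dim) matrix, only cheaper. Here is a small NumPy illustration. Random values stand in for learned weights, and a tiny embedding_dim is used for readability.

```python
import numpy as np

vocab_size, embedding_dim = 67, 4
rng = np.random.default_rng(0)
table = rng.normal(size=(vocab_size, embedding_dim))   # stand-in for the trainable table

ids = np.array([41, 42, 43])                # e.g. the IDs for "abc" shown earlier
by_lookup = table[ids]                      # row gather: shape (3, embedding_dim)
one_hot = np.eye(vocab_size)[ids]           # shape (3, vocab_size)
by_matmul = one_hot @ table                 # the same vectors via a matrix product

print(by_lookup.shape)                      # (3, 4)
print(np.allclose(by_lookup, by_matmul))    # True
```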
# Length of the vocabulary in chars
vocab_size = len(vocab)
# The embedding dimension
embedding_dim = 256
# Number of RNN units
rnn_units = 1024
class MyModel(tf.keras.Model):
    def __init__(self, vocab_size, embedding_dim, rnn_units):
        # Call the base constructor without passing `self` a second time.
        super().__init__()
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.gru = tf.keras.layers.GRU(rnn_units,
                                       return_sequences=True,
                                       return_state=True)
        self.dense = tf.keras.layers.Dense(vocab_size)
    def call(self, inputs, states=None, return_state=False, training=False):
        x = inputs
        x = self.embedding(x, training=training)
        # Start from the GRU's default (zero) state if none is given.
        if states is None:
            states = self.gru.get_initial_state(x)
        x, states = self.gru(x, initial_state=states, training=training)
        x = self.dense(x, training=training)
        if return_state:
            return x, states
        else:
            return x
model = MyModel(
# Be sure the vocabulary size matches the `StringLookup` layers.
vocab_size=len(ids_from_chars.get_vocabulary()),
embedding_dim=embedding_dim,
rnn_units=rnn_units)
For each character, the model looks up the embedding, runs the GRU one timestep with the embedding as input, and applies the dense layer to generate logits predicting the log-likelihood of the next character.
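The per-character computation just described can be sketched in NumPy. This sketch rests on stated assumptions and is not the Keras implementation. It uses small random weights, a textbook GRU update, and made-up small dimensions. Keras' default GRU variant (reset_after=True) applies the reset gate slightly differently. Like the real model, the sketch accepts inputs of any sequence length.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, embedding_dim, units = 67, 16, 32   # small dims for illustration only


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


# Randomly initialized parameters (stand-ins for trained weights).
E = rng.normal(scale=0.1, size=(vocab_size, embedding_dim))   # embedding table
W = rng.normal(scale=0.1, size=(embedding_dim, 3 * units))    # input kernels [z|r|h]
U = rng.normal(scale=0.1, size=(units, 3 * units))            # recurrent kernels [z|r|h]
b = np.zeros(3 * units)
W_out = rng.normal(scale=0.1, size=(units, vocab_size))       # dense layer kernel
b_out = np.zeros(vocab_size)


def gru_step(x, h):
    """Textbook GRU update for one timestep (not Keras' exact variant)."""
    xz, xr, xh = np.split(x @ W + b, 3, axis=-1)
    hz, hr = np.split(h @ U[:, :2 * units], 2, axis=-1)
    z = sigmoid(xz + hz)                                 # update gate
    r = sigmoid(xr + hr)                                 # reset gate
    h_tilde = np.tanh(xh + (r * h) @ U[:, 2 * units:])   # candidate state
    return z * h + (1.0 - z) * h_tilde                   # new hidden state


def forward(ids):
    """ids: (batch, time) integer array -> logits: (batch, time, vocab_size)."""
    batch, time = ids.shape
    h = np.zeros((batch, units))
    logits = []
    for t in range(time):
        x = E[ids[:, t]]                      # embedding lookup
        h = gru_step(x, h)                    # one GRU timestep, state carried forward
        logits.append(h @ W_out + b_out)      # dense layer: one logit per character
    return np.stack(logits, axis=1)


print(forward(rng.integers(0, vocab_size, size=(64, 100))).shape)  # (64, 100, 67)
print(forward(rng.integers(0, vocab_size, size=(1, 7))).shape)     # (1, 7, 67)
```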
Try the model
Now run the model to see that it behaves as expected.
First check the shape of the output:
for input_example_batch, target_example_batch in dataset.take(1):
    example_batch_predictions = model(input_example_batch)
    print(example_batch_predictions.shape, "# (batch_size, sequence_length, vocab_size)")
(64, 100, 67) # (batch_size, sequence_length, vocab_size)
In the above example the sequence length of the input is 100, but the model can be run on inputs of any length:
model.summary()
Model: "my_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
embedding (Embedding)        multiple                  17152
_________________________________________________________________
gru (GRU)                    multiple                  3938304
_________________________________________________________________
dense (Dense)                multiple                  68675
=================================================================
Total params: 4,024,131
Trainable params: 4,024,131
Non-trainable params: 0
_________________________________________________________________
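The parameter counts in the summary can be reproduced by hand. The GRU formula below assumes Keras' default reset_after=True. Under that setting, each of the three gates has an input kernel, a recurrent kernel, and two bias vectors:

```python
vocab_size, embedding_dim, rnn_units = 67, 256, 1024

# Embedding: one embedding_dim-sized row per vocabulary entry.
embedding_params = vocab_size * embedding_dim
# GRU (reset_after=True assumed): 3 gates x (input kernel + recurrent kernel + 2 biases).
gru_params = 3 * (embedding_dim * rnn_units + rnn_units * rnn_units + 2 * rnn_units)
# Dense: kernel plus one bias per output logit.
dense_params = rnn_units * vocab_size + vocab_size

total = embedding_params + gru_params + dense_params
print(embedding_params, gru_params, dense_params, total)
# 17152 3938304 68675 4024131
```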
To get character indices from the model, you need to sample from the output distribution. This distribution is defined by the logits over the character vocabulary.
Try it for the first example in the batch:
sampled_indices = tf.random.categorical(example_batch_predictions[0], num_samples=1)
sampled_indices = tf.squeeze(sampled_indices,axis=-1).numpy()
This gives us, at each timestep, a prediction of the next character index:
sampled_indices
array([ 7, 57, 47, 49, 4, 11, 33, 22, 23, 45, 26, 4, 16, 40, 53, 32, 27, 6, 42, 11, 43, 50, 25, 13, 52, 37, 5, 3, 35, 50, 21, 18, 26, 55, 23, 30, 6, 49, 25, 52, 11, 45, 61, 6, 52, 42, 15, 57, 40, 31, 61, 18, 52, 18, 57, 15, 8, 17, 24, 34, 58, 57, 34, 50, 64, 53, 23, 52, 56, 26, 1, 63, 35, 35, 46, 57, 24, 35, 20, 49, 31, 15, 11, 52, 41, 20, 45, 44, 50, 48, 59, 60, 46, 3, 5, 48, 28, 4, 64, 57])
Decode these to see the text predicted by this untrained model:
print("Input:\n", text_from_ids(input_example_batch[0]).numpy())
print()
print("Next Char Predictions:\n", text_from_ids(sampled_indices).numpy())
Input: b'es here?\n\nNORTHUMBERLAND:\nIt is my son, young Harry Percy,\nSent from my brother Worcester, whencesoe' Next Char Predictions: b"'qgi!3SHIeL!BZmRM&b3cjK;lW$ UjGDLoIP&iKl3eu&lbAqZQuDlDqA,CJTrqTjxmIlpL[UNK]wUUfqJUFiQA3laFedjhstf $hN!xq"
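tf.random.categorical treats each row of its input as unnormalized log-probabilities and draws an index from the corresponding softmax distribution. Here is a NumPy sketch of the same idea, for illustration only. sample_from_logits is a hypothetical helper.

```python
import numpy as np


def sample_from_logits(logits, rng):
    """Draw one index per row of a (rows, vocab) logits array."""
    # Softmax with the row max subtracted for numerical stability.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted)
    probs /= probs.sum(axis=-1, keepdims=True)
    return np.array([rng.choice(len(p), p=p) for p in probs])


rng = np.random.default_rng(0)
logits = rng.normal(size=(100, 67))   # e.g. one sequence: 100 timesteps x 67 classes
indices = sample_from_logits(logits, rng)
print(indices.shape)                  # (100,)
```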
Train the model
At this point the problem can be treated as a standard classification problem: given the previous RNN state and the input at this time step, predict the class of the next character.
Attach an optimizer and a loss function
The standard tf.keras.losses.sparse_categorical_crossentropy loss function works in this case because it is applied across the last dimension of the predictions.
Because your model returns logits, you need to set the from_logits flag.
loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True)
example_batch_loss = loss(target_example_batch, example_batch_predictions)
mean_loss = example_batch_loss.numpy().mean()
print("Prediction shape: ", example_batch_predictions.shape, " # (batch_size, sequence_length, vocab_size)")
print("Mean loss: ", mean_loss)
Prediction shape: (64, 100, 67) # (batch_size, sequence_length, vocab_size) Mean loss: 4.2050414
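Consider integer labels of shape (batch, sequence_length) and logits of shape (batch, sequence_length, vocab_size). The loss applies a log-softmax over the last axis, picks out the log-probability of each correct character, negates it, and averages over all positions. Here is a NumPy sketch of that computation. sparse_ce_from_logits is a hypothetical name, and the Keras loss additionally supports other reduction options.

```python
import numpy as np


def sparse_ce_from_logits(labels, logits):
    """Mean sparse categorical cross-entropy.

    labels: integer array of shape (...); logits: float array of shape (..., vocab).
    """
    # Numerically stable log-softmax over the last axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Log-probability assigned to the correct class at every position.
    picked = np.take_along_axis(log_probs, labels[..., None], axis=-1)[..., 0]
    return -picked.mean()


logits = np.array([[[2.0, 0.0, 0.0], [0.0, 0.0, 5.0]]])   # (batch=1, time=2, vocab=3)
labels = np.array([[0, 2]])                                # (batch=1, time=2)
print(sparse_ce_from_logits(labels, logits))
```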
A newly initialized model shouldn't be too sure of itself: the output logits should all have similar magnitudes. To confirm this, you can check that the exponential of the mean loss is approximately equal to the vocabulary size. A much higher loss means the model is sure of its wrong answers and is badly initialized:
tf.exp(mean_loss).numpy()
67.02338
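The reasoning behind this check is as follows. If every logit is equal, each character gets probability 1/vocab_size. The loss at every position is then log(vocab_size), and its exponential is exactly vocab_size. Here is a short NumPy sketch with perfectly uniform logits, for comparison with the 4.2050414 and 67.02338 printed above:

```python
import math

import numpy as np

vocab_size = 67
logits = np.zeros((64, 100, vocab_size))   # a maximally unsure model: all logits equal
labels = np.random.default_rng(0).integers(0, vocab_size, size=(64, 100))

# With equal logits every class has probability 1/vocab_size, so each
# position contributes -log(1/vocab_size) = log(vocab_size) to the loss.
log_probs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
loss = -np.take_along_axis(log_probs, labels[..., None], axis=-1).mean()

print(loss, math.log(vocab_size))   # both ≈ 4.2047
print(math.exp(loss))               # ≈ 67.0
```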
Configure the training procedure using the tf.keras.Model.compile method. Use tf.keras.optimizers.Adam with default arguments and the loss function.
model.compile(optimizer='adam', loss=loss)
Configure checkpoints
Use a tf.keras.callbacks.ModelCheckpoint to ensure that checkpoints are saved during training:
# Directory where the checkpoints will be saved
checkpoint_dir = './training_checkpoints'
# Name of the checkpoint files
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=checkpoint_prefix,
save_weights_only=True)
Execute the training
To keep training time reasonable, train the model for 20 epochs. In Colab, set the runtime to GPU for faster training.
EPOCHS = 20
history = model.fit(dataset, epochs=EPOCHS, callbacks=[checkpoint_callback])
Epoch 1/20
172/172 [==============================] - 7s 26ms/step - loss: 3.2936
Epoch 2/20
172/172 [==============================] - 5s 25ms/step - loss: 2.0830
Epoch 3/20
172/172 [==============================] - 5s 25ms/step - loss: 1.7665
Epoch 4/20
172/172 [==============================] - 5s 25ms/step - loss: 1.5841
Epoch 5/20
172/172 [==============================] - 5s 26ms/step - loss: 1.4680
Epoch 6/20
172/172 [==============================] - 6s 26ms/step - loss: 1.3950
Epoch 7/20
172/172 [==============================] - 5s 26ms/step - loss: 1.3378
Epoch 8/20
172/172 [==============================] - 5s 26ms/step - loss: 1.2896
Epoch 9/20
172/172 [==============================] - 5s 26ms/step - loss: 1.2472
Epoch 10/20
172/172 [==============================] - 5s 25ms/step - loss: 1.2054
Epoch 11/20
172/172 [==============================] - 6s 26ms/step - loss: 1.1654
Epoch 12/20
172/172 [==============================] - 6s 26ms/step - loss: 1.1214
Epoch 13/20
172/172 [==============================] - 5s 26ms/step - loss: 1.0743
Epoch 14/20
172/172 [==============================] - 5s 26ms/step - loss: 1.0294
Epoch 15/20
172/172 [==============================] - 5s 25ms/step - loss: 0.9801
Epoch 16/20
172/172 [==============================] - 5s 25ms/step - loss: 0.9275
Epoch 17/20
172/172 [==============================] - 5s 25ms/step - loss: 0.8741
Epoch 18/20
172/172 [==============================] - 5s 25ms/step - loss: 0.8177
Epoch 19/20
172/172 [==============================] - 5s 25ms/step - loss: 0.7643
Epoch 20/20
172/172 [==============================] - 5s 25ms/step - loss: 0.7131
Generate text
The simplest way to generate text with this model is to run it in a loop, and keep track of the model's internal state as you execute it.
Each time you call the model you pass in some text and an internal state. The model returns a prediction for the next character and its new state. Pass the prediction and state back in to continue generating text.
The following makes a single-step prediction:
class OneStep(tf.keras.Model):
    def __init__(self, model, chars_from_ids, ids_from_chars, temperature=1.0):
        super().__init__()
        self.temperature = temperature
        self.model = model
        self.chars_from_ids = chars_from_ids
        self.ids_from_chars = ids_from_chars
        # Create a mask to prevent "" or "[UNK]" from being generated.
        skip_ids = self.ids_from_chars(['', '[UNK]'])[:, None]
        sparse_mask = tf.SparseTensor(
            # Put a -inf at each bad index.
            values=[-float('inf')] * len(skip_ids),
            indices=skip_ids,
            # Match the shape to the vocabulary
            dense_shape=[len(ids_from_chars.get_vocabulary())])
        self.prediction_mask = tf.sparse.to_dense(sparse_mask)
    @tf.function
    def generate_one_step(self, inputs, states=None):
        # Convert strings to token IDs.
        input_chars = tf.strings.unicode_split(inputs, 'UTF-8')
        input_ids = self.ids_from_chars(input_chars).to_tensor()
        # Run the model.
        # predicted_logits.shape is [batch, char, next_char_logits]
        predicted_logits, states = self.model(inputs=input_ids, states=states,
                                              return_state=True)
        # Only use the last prediction.
        predicted_logits = predicted_logits[:, -1, :]
        predicted_logits = predicted_logits / self.temperature
        # Apply the prediction mask: prevent "" or "[UNK]" from being generated.
        predicted_logits = predicted_logits + self.prediction_mask
        # Sample the output logits to generate token IDs.
        predicted_ids = tf.random.categorical(predicted_logits, num_samples=1)
        predicted_ids = tf.squeeze(predicted_ids, axis=-1)
        # Convert from token ids to characters
        predicted_chars = self.chars_from_ids(predicted_ids)
        # Return the characters and model state.
        return predicted_chars, states
one_step_model = OneStep(model, chars_from_ids, ids_from_chars)
Run it in a loop to generate some text. Looking at the generated text, you'll see the model knows when to capitalize, makes paragraphs, and imitates a Shakespeare-like writing vocabulary. With the small number of training epochs, it has not yet learned to form coherent sentences.
start = time.time()
states = None
next_char = tf.constant(['ROMEO:'])
result = [next_char]
for n in range(1000):
    next_char, states = one_step_model.generate_one_step(next_char, states=states)
    result.append(next_char)
result = tf.strings.join(result)
end = time.time()
print(result[0].numpy().decode('utf-8'), '\n\n' + '_'*80)
print(f"\nRun time: {end - start}")
ROMEO: That's the wars. BAPTISTA: Ay, but what of Poxers? BIANCA: That's as might we part; These times the waning of women'st cares Against the heart. There's nothing be more of the house of some more friends to the crown his face, And hide as perigon'd bawd in life, My parks hear me with peace with secrets. But curst it were nothing, tutles; blessed blood and banishment: The comes or ourselves high happy half, hably born Blueping, storious passage friends: I create the curps and word proportions: Tire trouble on my kin, more dark and clapp'd out for His beauty's point of woman and weep, when I see it best me as their absence. Here on thy hand, and that sorrow danger Which thou hast disposed the queen my space; And undertake the truth, or thou art bitter With paper heaven and worth all dissels And read the slain: And live abrogg am It: And in this trouble drops rebellion, I do not make: I say shruw with all; all ill in blood, The world is strain'd, and brushing advers' taste They valiant, ________________________________________________________________________________ Run time: 2.4661831855773926
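The OneStep class keeps "" and "[UNK]" out of the generated text by adding a mask with -inf at those IDs to the logits before sampling. Since exp(-inf) is 0, those IDs end up with zero probability. Here is a NumPy sketch of the same trick. It assumes, as the StringLookup IDs earlier suggest, that the two tokens have IDs 0 and 1.

```python
import numpy as np

vocab_size = 67
skip_ids = [0, 1]            # ASSUMED IDs of "" and "[UNK]"

# Dense mask: 0 everywhere except -inf at the banned IDs.
prediction_mask = np.zeros(vocab_size)
prediction_mask[skip_ids] = -np.inf

rng = np.random.default_rng(0)
logits = rng.normal(size=(5, vocab_size))   # a batch of 5 predictions
masked = logits + prediction_mask           # the mask broadcasts over the batch

# exp(-inf) == 0, so the banned IDs get exactly zero probability.
probs = np.exp(masked - masked.max(axis=-1, keepdims=True))
probs /= probs.sum(axis=-1, keepdims=True)
samples = [rng.choice(vocab_size, p=p) for p in probs]
print(probs[:, skip_ids].max())   # 0.0
```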
The easiest thing you can do to improve the results is to train it for longer (try EPOCHS = 30).
You can also experiment with a different start string, try adding another RNN layer to improve the model's accuracy, or adjust the temperature parameter to generate more or less random predictions.
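generate_one_step divides the logits by self.temperature before sampling. Values below 1.0 sharpen the distribution and make the text more predictable. Values above 1.0 flatten it and make the text more random. Here is a NumPy sketch of the effect on a made-up set of logits:

```python
import numpy as np


def softmax_with_temperature(logits, temperature):
    # Same scaling as `predicted_logits / self.temperature` in OneStep.
    scaled = np.asarray(logits, dtype=float) / temperature
    scaled -= scaled.max()            # numerical stability
    probs = np.exp(scaled)
    return probs / probs.sum()


logits = [2.0, 1.0, 0.5, 0.0]
for t in (0.5, 1.0, 2.0):
    print(t, np.round(softmax_with_temperature(logits, t), 3))
```

The most likely character's probability grows as the temperature drops. At very high temperatures the distribution approaches uniform.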
If you want the model to generate text faster, the easiest thing you can do is batch the text generation. In the example below, the model generates 5 outputs in about the same time it took to generate 1 above.
start = time.time()
states = None
next_char = tf.constant(['ROMEO:', 'ROMEO:', 'ROMEO:', 'ROMEO:', 'ROMEO:'])
result = [next_char]
for n in range(1000):
    next_char, states = one_step_model.generate_one_step(next_char, states=states)
    result.append(next_char)
result = tf.strings.join(result)
end = time.time()
print(result, '\n\n' + '_'*80)
print(f"\nRun time: {end - start}")
tf.Tensor( [b"ROMEO:\nThe heads of the good will each any other haboch\nBy his usmove bosom!' quoth he,\nGo see his issue: wish thy cause requitths\nOf knowledged want of wicked, cryom!\nA part! call him, as great as heaven she is\nthe best: something her about hear, have hearded thee\nManquish'd by their offices of his father.\nUpor me, on my heart;\nI have reason that I infected him,\nHis realm at me as Heralfor\nOf Naples Chear'd zeap walour in the consul.\n\nCORIOLANUS:\nMore given and gruck,\nFor then my lord and following the Lordshop's malice,\nWhen thou wilt sea-more of this new-man's tormon prosper,\nTo be mine own, my most grace tappiness;\nLet me receive the crown, stand gentle thee\nFrom thine say false glound of the chairs of Richard.\n\nTRANIO:\nAnd yes, I tell thee, fellowing love.\n\nBAPTISTA:\nBut how our vastard is become a bare.\nBut I shall hope he will seem to dew\nOr whether he is lesson'd by the yeat of man.\nBut now I love a bontagner,\nA one that shall stay the wrencher, are is very hope:\n'Twere perpecul.\n\nCO" b"ROMEO:\nWhy fares, thou speak'st not adversmance a tap-ervowity?\n\nCLARENCE:\nA man affliction, raining leng-hing: let me speak\nWhich makes his sweet as anothen sovered.\nGive me thy hand, this gentleman is absent. Sir, I have\nheard and sole fortune's bloods to the people,\nWe shall not give me France, and desire to me so.\nLady courteous vextard both my heart,\nUpon the indictaty, of your command,\nAnd that did break the strength of heaven sentenced; he will come\nPerlance of Gloucester here; or else Sethix me\nAnd, mistress Barour, first, his body of a good\nfrown, it disgraced, can change the roin,\nMethinks a worthy discourse blood created!\nWhy art thou offended, my daughter? 
was the georle arm\nme that is pomp to harmour on a Carest,\nWhat do I fear the aiting deed, that will defend thee,\nFor substance in the cure of all\nwas blood aside; his punishment begin.\n\nBRUTUS:\nYou must conse:\nwill you go sleep Bunkingham back?\n\nLADY CAPULET:\nHe to a horses, why I desired\nIn me it hath a fentremies of men,\nAnd" b"ROMEO:\nThen must I combint to be much commons;\nWhich thou some state was he that can yield with the\ncreature in heaven; he cried, farewell; I mean to sue:\nBut after our faults what offence hither\nYou dance. What will meet me may suffrian coalst thy life?\n\nGLOUCESTER:\nAy, so Isabel? I bige my feet Ladow:\nTemptantion, that shall speak of prenamosa's son\nWhich shapeling loss or law;\nAnd send her day did current for me.\n3 KING HENRY VI\n\nFirst Keeper:\nThis did you deny to home? by my trespassfit to you,\nis thy art for thin.\n\nHENRY BOLINGBROKE:\nO, who is time woe a fool well delight,\nHow the poor court and time true, it\nwith deadly perish.\n\nLORD FIT:\nTybalt, what means this blood strikes made thee from thy trick:\nBut in the senate-house, wilt thou kneel\nWhen Capulett; and fearfully stand,\nAnd in this cove and lip at fires;\nFor doing worthy valour in your calling.\nI must confess, repeal a pulse\nBut not A thing that every dark.\n\nKING RICHARD III:\nBeauty hath touch'd away his last,\nThe duke, as from " b"ROMEO:\nShould I do so I promised to Crysicial:\nA favour with your sister. 
This leanous depresens\nMaster here on peace and dancing cheer;\nWhat would he slare his full women word.\nThen, till the strong star, when he did love him.\n\nWARWICK:\nThey will I with the warlike aight,\nAnd fit to prison have war, rather than death:\nBut if it perceive in a true man's\nears to what they' made faces boast of joy\nAnd kiss on him that list taste of the world;\nWe at tod you a consul, I ward\nCome; your own danger be prodectood\nMay plainly that fashions and loves as lief,\nSha, we twenty thousand changes of peace.\n\nBUCKINGHAM:\nWith these even blows whose honourable tents,\nMurder the crown, and, in a beggar's blood:\nThis blood lady cried 'twere water to to visit you. I, tell me, gentle matter for\nClaudio, to be heard, but it do me thieves;\nWith Bolingbroke depending, how hell is kerm,\nBy that, or die to-morrow 'tears me.\nAlack, and death, can shame thee for a child;\nDid gentle Cleave are the rest; O, the ship spiri" b"ROMEO:\nGod-death, nor day to-morrow, then applay\nA little shame is preporting up;\nAnd 'tis no married traitors and to speak;\nAnd, after man, there will to quench.\n\nWARWICK:\nFit, at that trade, is there a month or bides\nIn the main beam of common shade.\nWhat if thou, Errolaw,\nHe'll take the truth: but winty inacentain\nMay presently in the deep-fellow:' quoth he,\nWith close my spoke and death you make, and they\nShall have our covertable state with beauty's son.\n\nJULIET:\nThis blood spites, his grace but more than common\nOf good favour and as much as friends\nAs wicked dishings proclaim'd and scorn'd low wings.\nAJemery, my memberial damned Clarence,\nOur ambition gentlemen was far\nTo gros his friends at the cross--\nA parcel of Gloucester, at Dubbes of King Richard;\nWhere thy quarrel bodies Bianca,\nAnd thrice pawning,--\nButtis blood the blood about it in my speech,\nWhich all the torse that were the citizens,\nYet rather know I please my life.\n\nGLOUCESTER:\nCome, sirrah! 
keep you not unkinkly bitter.\n"], shape=(5,), dtype=string) ________________________________________________________________________________ Run time: 2.235074043273926
Export the generator
This single-step model can easily be saved and restored, allowing you to use it anywhere a tf.saved_model is accepted.
tf.saved_model.save(one_step_model, 'one_step')
one_step_reloaded = tf.saved_model.load('one_step')
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.OneStep object at 0x7f7eb0d04ac8>, because it is not built.
WARNING:absl:Found untraced functions such as gru_cell_layer_call_fn, gru_cell_layer_call_and_return_conditional_losses, gru_cell_layer_call_fn, gru_cell_layer_call_and_return_conditional_losses, gru_cell_layer_call_and_return_conditional_losses while saving (showing 5 of 5). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: one_step/assets
states = None
next_char = tf.constant(['ROMEO:'])
result = [next_char]
for n in range(100):
    next_char, states = one_step_reloaded.generate_one_step(next_char, states=states)
    result.append(next_char)
print(tf.strings.join(result)[0].numpy().decode("utf-8"))
WARNING:tensorflow:5 out of the last 5 calls to <function recreate_function.<locals>.restored_function_body at 0x7f7ddc1b4158> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.
WARNING:tensorflow:6 out of the last 6 calls to <function recreate_function.<locals>.restored_function_body at 0x7f7ddc1b4158> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.
ROMEO: The sun and bid his mortal back again To search, and keep we shall edward thee From this most study
Advanced: Customized Training
The above training procedure is simple, but it does not give you much control. It uses teacher forcing, which prevents bad predictions from being fed back to the model, so the model never learns to recover from its own mistakes.
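To make the teacher-forcing problem concrete, here is a small pure-Python sketch. The bigram table in `toy_predict` is a hypothetical stand-in for the trained RNN and is deliberately wrong on one transition ('t' -> 'a' instead of 't' -> 'o'). Under teacher forcing, each step sees the true previous character, so the error costs a single prediction; when the model runs open-loop on its own outputs, that one error changes every later input.

```python
# Toy illustration of teacher forcing vs. free-running (open-loop) generation.
# `toy_predict` is a hypothetical stand-in for the RNN: a bigram lookup
# that gets one transition wrong ('t' -> 'a' instead of 't' -> 'o').
TOY_TABLE = {"s": "t", "t": "a", "o": "n", "n": "e", "a": "x"}


def toy_predict(prev_char: str) -> str:
    # Contexts the table has never seen fall back to '?'.
    return TOY_TABLE.get(prev_char, "?")


def teacher_forced(text: str) -> str:
    # Each step is fed the *true* previous character from the data,
    # so an earlier wrong prediction never becomes a later input.
    return "".join(toy_predict(c) for c in text[:-1])


def free_running(start: str, steps: int) -> str:
    # Each step is fed the model's *own* previous prediction,
    # so one mistake changes every input that follows it.
    out, prev = [], start
    for _ in range(steps):
        prev = toy_predict(prev)
        out.append(prev)
    return "".join(out)


def accuracy(pred: str, target: str) -> float:
    return sum(p == t for p, t in zip(pred, target)) / len(target)


text = "stone"
target = text[1:]                               # "tone": the input shifted by one
forced_pred = teacher_forced(text)              # "tane"
open_pred = free_running(text[0], len(target))  # "tax?"
print(forced_pred, accuracy(forced_pred, target))  # tane 0.75
print(open_pred, accuracy(open_pred, target))      # tax? 0.25
```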
Now that you've seen how to run the model manually, you'll implement the training loop yourself. This gives a starting point if, for example, you want to implement curriculum learning to help stabilize the model's open-loop output.
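One curriculum aimed at exactly this gap is scheduled sampling (Bengio et al., 2015), which this tutorial does not implement: at each training step the model is fed the ground-truth previous character with probability eps and its own prediction otherwise, and eps decays from near 1 toward 0 as training progresses. The sketch below shows an inverse-sigmoid decay schedule in plain Python; the function names and the constant `k=100` are illustrative assumptions, and using it with the RNN would require feeding characters one step at a time rather than passing the whole ground-truth sequence in a single call.

```python
import math
import random


def inverse_sigmoid_decay(step: int, k: float = 100.0) -> float:
    """Probability of feeding the ground-truth character at a training step.

    Starts near 1 (pure teacher forcing) and decays toward 0 (pure
    free-running). `k` (assumed >= 1) controls how slowly it decays.
    """
    return k / (k + math.exp(step / k))


def choose_next_input(true_char: str, predicted_char: str, step: int,
                      rng: random.Random) -> str:
    # With probability eps use the ground truth (teacher forcing);
    # otherwise feed back the model's own prediction.
    eps = inverse_sigmoid_decay(step)
    return true_char if rng.random() < eps else predicted_char


for step in (0, 500, 1000):
    print(step, round(inverse_sigmoid_decay(step), 3))
# 0 0.99
# 500 0.403
# 1000 0.005
```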
The most important part of a custom training loop is the train step function.
Use tf.GradientTape to track the gradients. You can learn more about this approach by reading the eager execution guide.
The basic procedure is:
- Execute the model and calculate the loss under a tf.GradientTape.
- Calculate the updates and apply them to the model using the optimizer.
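tf.GradientTape records the forward pass so that TensorFlow can compute the gradients for you. To show only the shape of the two-step procedure, here is a dependency-free sketch that fits a one-weight model y = w * x with mean squared error: the gradient is derived by hand where the tape would normally compute it, and a plain SGD update stands in for the optimizer. The data, learning rate, and names are illustrative assumptions.

```python
def train_step(w, xs, ys, lr=0.05):
    # 1. Execute the model and compute the loss (mean squared error).
    preds = [w * x for x in xs]
    loss = sum((p - y) ** 2 for p, y in zip(preds, ys)) / len(xs)
    # 2. Compute dLoss/dw. A GradientTape would do this automatically;
    #    by hand: d/dw mean((w*x - y)^2) = mean(2 * (w*x - y) * x).
    grad = sum(2 * (p - y) * x for p, y, x in zip(preds, ys, xs)) / len(xs)
    # 3. Apply the update (plain SGD standing in for the optimizer).
    #    The returned loss is the one computed before this update.
    return w - lr * grad, loss


xs = [1.0, 2.0, 3.0]
ys = [2.0, 4.0, 6.0]  # generated by the "true" weight w = 2
w = 0.0
for step in range(50):
    w, loss = train_step(w, xs, ys)
print(round(w, 3))  # 2.0
```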
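import tensorflow as tf
class CustomTraining(MyModel):
    @tf.function
    def train_step(self, inputs):
        # Unpack the (input, target) pair for this batch.
        inputs, labels = inputs
        # Record the forward pass so the tape can differentiate the loss.
        with tf.GradientTape() as tape:
            predictions = self(inputs, training=True)
            loss = self.loss(labels, predictions)
        # Use this model's own variables (self), not a global `model`.
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        return {'loss': loss}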
The above implementation of the train_step method follows Keras' train_step conventions. This is optional, but it allows you to change the behavior of the train step and still use Keras' Model.compile and Model.fit methods.
model = CustomTraining(
    vocab_size=len(ids_from_chars.get_vocabulary()),
    embedding_dim=embedding_dim,
    rnn_units=rnn_units)
model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
model.fit(dataset, epochs=1)
172/172 [==============================] - 15s 76ms/step - loss: 2.7008
<tensorflow.python.keras.callbacks.History at 0x7f7ddc2ef668>
Or if you need more control, you can write your own complete custom training loop:
import time
import tensorflow as tf
EPOCHS = 10
mean = tf.metrics.Mean()
for epoch in range(EPOCHS):
    start = time.time()
    # Clear the running average so each epoch reports its own mean loss.
    mean.reset_states()
    for (batch_n, (inp, target)) in enumerate(dataset):
        logs = model.train_step([inp, target])
        mean.update_state(logs['loss'])
        if batch_n % 50 == 0:
            template = 'Epoch {} Batch {} Loss {}'
            print(template.format(epoch + 1, batch_n, logs['loss']))
    # saving (checkpoint) the model every 5 epochs
    if (epoch + 1) % 5 == 0:
        model.save_weights(checkpoint_prefix.format(epoch=epoch))
    print()
    print('Epoch {} Loss: {:.4f}'.format(epoch + 1, mean.result().numpy()))
    print('Time taken for 1 epoch {} sec'.format(time.time() - start))
    print("_"*80)
# Save the final weights once the loop finishes.
model.save_weights(checkpoint_prefix.format(epoch=epoch))
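Two details of this loop are easy to miss. tf.metrics.Mean keeps a running average that mean.reset_states() clears at the start of each epoch, so the printed "Epoch N Loss" is the mean over that epoch's batches rather than the last batch's loss. And because checkpoint_prefix.format(epoch=epoch) receives the zero-based loop index, the periodic saves happen after epochs 5 and 10 (indices 4 and 9), and the final save after the loop reuses index 9. The pure-Python sketch below mimics both behaviors; `RunningMean` is a hypothetical stand-in for tf.metrics.Mean, and "ckpt_{epoch}" is an assumed value for checkpoint_prefix, which is defined earlier in the tutorial.

```python
class RunningMean:
    """Hypothetical stand-in for tf.metrics.Mean: an unweighted running average."""

    def __init__(self):
        self.reset_states()

    def reset_states(self):
        self.total, self.count = 0.0, 0

    def update_state(self, value):
        self.total += value
        self.count += 1

    def result(self):
        return self.total / self.count if self.count else 0.0


def saved_checkpoints(epochs, every=5, prefix="ckpt_{epoch}"):
    # `prefix` is an assumed value for checkpoint_prefix.
    names = []
    for epoch in range(epochs):
        if (epoch + 1) % every == 0:          # periodic save inside the loop
            names.append(prefix.format(epoch=epoch))
    names.append(prefix.format(epoch=epoch))  # final save after the loop
    return names


mean = RunningMean()
for batch_loss in [2.0, 1.5, 1.0]:  # illustrative values, not real training output
    mean.update_state(batch_loss)
print(mean.result())          # 1.5
print(saved_checkpoints(10))  # ['ckpt_4', 'ckpt_9', 'ckpt_9']
```

With EPOCHS = 10, the final save therefore writes to the same checkpoint name as the last periodic save.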
Epoch 1 Batch 0 Loss 2.1587090492248535
Epoch 1 Batch 50 Loss 2.015890121459961
Epoch 1 Batch 100 Loss 1.8815938234329224
Epoch 1 Batch 150 Loss 1.8182443380355835
Epoch 1 Loss: 1.9731
Time taken for 1 epoch 14.837902069091797 sec
________________________________________________________________________________
Epoch 2 Batch 0 Loss 1.801224946975708
Epoch 2 Batch 50 Loss 1.7420837879180908
Epoch 2 Batch 100 Loss 1.6781731843948364
Epoch 2 Batch 150 Loss 1.6495245695114136
Epoch 2 Loss: 1.6969
Time taken for 1 epoch 14.266141414642334 sec
________________________________________________________________________________
Epoch 3 Batch 0 Loss 1.6056054830551147
Epoch 3 Batch 50 Loss 1.557340145111084
Epoch 3 Batch 100 Loss 1.476023554801941
Epoch 3 Batch 150 Loss 1.5442909002304077
Epoch 3 Loss: 1.5393
Time taken for 1 epoch 14.16746735572815 sec
________________________________________________________________________________
Epoch 4 Batch 0 Loss 1.4498177766799927
Epoch 4 Batch 50 Loss 1.4632078409194946
Epoch 4 Batch 100 Loss 1.4920791387557983
Epoch 4 Batch 150 Loss 1.3913884162902832
Epoch 4 Loss: 1.4423
Time taken for 1 epoch 13.796212196350098 sec
________________________________________________________________________________
Epoch 5 Batch 0 Loss 1.3514453172683716
Epoch 5 Batch 50 Loss 1.3973870277404785
Epoch 5 Batch 100 Loss 1.3572988510131836
Epoch 5 Batch 150 Loss 1.395420789718628
Epoch 5 Loss: 1.3756
Time taken for 1 epoch 13.89029860496521 sec
________________________________________________________________________________
Epoch 6 Batch 0 Loss 1.3473232984542847
Epoch 6 Batch 50 Loss 1.2732903957366943
Epoch 6 Batch 100 Loss 1.3224128484725952
Epoch 6 Batch 150 Loss 1.303688645362854
Epoch 6 Loss: 1.3241
Time taken for 1 epoch 14.103551626205444 sec
________________________________________________________________________________
Epoch 7 Batch 0 Loss 1.2467150688171387
Epoch 7 Batch 50 Loss 1.3191982507705688
Epoch 7 Batch 100 Loss 1.301483154296875
Epoch 7 Batch 150 Loss 1.269212007522583
Epoch 7 Loss: 1.2787
Time taken for 1 epoch 13.593403339385986 sec
________________________________________________________________________________
Epoch 8 Batch 0 Loss 1.2092283964157104
Epoch 8 Batch 50 Loss 1.253741979598999
Epoch 8 Batch 100 Loss 1.1954262256622314
Epoch 8 Batch 150 Loss 1.2262048721313477
Epoch 8 Loss: 1.2386
Time taken for 1 epoch 13.393027544021606 sec
________________________________________________________________________________
Epoch 9 Batch 0 Loss 1.1826077699661255
Epoch 9 Batch 50 Loss 1.1799037456512451
Epoch 9 Batch 100 Loss 1.2302446365356445
Epoch 9 Batch 150 Loss 1.2093273401260376
Epoch 9 Loss: 1.1993
Time taken for 1 epoch 13.090997695922852 sec
________________________________________________________________________________
Epoch 10 Batch 0 Loss 1.1227319240570068
Epoch 10 Batch 50 Loss 1.1515312194824219
Epoch 10 Batch 100 Loss 1.1636242866516113
Epoch 10 Batch 150 Loss 1.210679054260254
Epoch 10 Loss: 1.1597
Time taken for 1 epoch 13.395586729049683 sec
________________________________________________________________________________