This text classification tutorial trains a recurrent neural network on the IMDB large movie review dataset for sentiment analysis.
Setup
import numpy as np
import tensorflow_datasets as tfds
import tensorflow as tf
tfds.disable_progress_bar()
Import matplotlib and create a helper function to plot graphs:
import matplotlib.pyplot as plt

def plot_graphs(history, metric):
  plt.plot(history.history[metric])
  plt.plot(history.history['val_'+metric], '')
  plt.xlabel("Epochs")
  plt.ylabel(metric)
  plt.legend([metric, 'val_'+metric])
Setup input pipeline
The IMDB large movie review dataset is a binary classification dataset—all the reviews have either a positive or negative sentiment.
Download the dataset using TFDS. See the loading text tutorial for details on how to load this sort of data manually.
dataset, info = tfds.load('imdb_reviews', with_info=True,
                          as_supervised=True)
train_dataset, test_dataset = dataset['train'], dataset['test']
train_dataset.element_spec
(TensorSpec(shape=(), dtype=tf.string, name=None), TensorSpec(shape=(), dtype=tf.int64, name=None))
Initially this returns a dataset of (text, label) pairs:
for example, label in train_dataset.take(1):
  print('text: ', example.numpy())
  print('label: ', label.numpy())
text: b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it." label: 0
Next shuffle the data for training and create batches of these (text, label)
pairs:
BUFFER_SIZE = 10000
BATCH_SIZE = 64
train_dataset = train_dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
test_dataset = test_dataset.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
for example, label in train_dataset.take(1):
  print('texts: ', example.numpy()[:3])
  print()
  print('labels: ', label.numpy()[:3])
texts: [b'I was raised in a "very Christian" household since birth. I was saved before I saw this movie and the rest of the series and was forced to watch it in a youth group at my church. This movie was highly disturbing. I saw it when I was about 12 years old and literally had nightmares about it for years. I used to lay awake in bed and listen for the sounds of my mom\'s footsteps upstairs. If I didn\'t hear her footsteps, I would sneak upstairs to make sure she hadn\'t been raptured. I used to pray so hard every night for salvation because I was terrified of Jesus forgetting me. This is definitely not something I will show to my kids until they are much older, if at all. It took me years to shake the fear that this movie gave me.' b"Many of the American people would say...What??? to my opening comment. Yes I know that my comparison is without doubts an insult for the fans of the Master Akira Kurosawa, but if you analyze this movie, my comment is right. We have the peasant who goes to the town searching for help against a band of grasshoppers who wants to steal the harvest of the village. The great difference is the way that the story takes. Our samurais, a band of circus performers as in the original are a very complex mixture of personalities but at the end are what the village needs, HEROES. Please watch again this incredible movie (the Seven Samurai, obviously) and find another movies who has stolen the story and tried to get the same magic effect than the Masterpiece of Akira Kurosawa. A tip is The 13th Warrior with Antonio Banderas, Michael Crichton copied the story to wrote his Best seller's, but he didn't found the third foot of the cat." b'Overshadowed by "Braveheart" released the same year, the two costume dramas beg comparison. I admit my bias against Mel Gibson, yet I maintain a rational preference for "Rob Roy." Both "Braveheart" and "Rob Roy" compellingly depict Scots history in bloody, romantic fashion. "Braveheart" is an epic paean to individual honor and courage and a fine revenge fantasy. It\'s also melodramatic, anachronistic and maudlin. Note its cornball usage of slow motion filming. Its violence is both ugly and glorious. It is the latter quality which makes it more appealing to the adolescent mindset. While "Braveheart" surpasses "Rob Roy" in sheer levels of carnage (not to mention its indulgent running time), the latter film is ultimately more mature and satisfying. Its action is more understated, yet more surprising and clever. Its sex is less showy, yet more erotic. "Rob Roy" also has a better realized romantic interest. Its dialog attempts to approximate the poetry of the period. Its rotted teeth in the mouths of the actors attempt to approximate the dentistry of the era. And Tim Roth is a superlative villain. Also recommended: "The Last of the Mohicans" and "The Patriot." You may find the latter more akin to "Braveheart" with its emphasis on blood lust, with the former more similar to "Rob Roy" in tone. All the of the aforementioned movies merit their R ratings for violence.'] labels: [0 1 1]
Create the text encoder
The raw text loaded by tfds needs to be processed before it can be used in a model. The simplest way to process text for training is using the TextVectorization layer. This layer has many capabilities, but this tutorial sticks to the default behavior.
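For reference, here is a hedged sketch of a non-default configuration; none of these settings are used in this tutorial, and the 'multi_hot' output mode assumes a reasonably recent TensorFlow release:

# Sketch only: a TextVectorization layer that keeps unigrams and bigrams and
# emits multi-hot vectors instead of integer token indices.
bigram_encoder = tf.keras.layers.TextVectorization(
    max_tokens=1000,
    ngrams=2,
    output_mode='multi_hot')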
Create the layer, and pass the dataset's text to the layer's .adapt method:
VOCAB_SIZE = 1000
encoder = tf.keras.layers.TextVectorization(
    max_tokens=VOCAB_SIZE)
encoder.adapt(train_dataset.map(lambda text, label: text))
The .adapt method sets the layer's vocabulary. Here are the first 20 tokens. After the padding and unknown tokens they're sorted by frequency:
vocab = np.array(encoder.get_vocabulary())
vocab[:20]
array(['', '[UNK]', 'the', 'and', 'a', 'of', 'to', 'is', 'in', 'it', 'i', 'this', 'that', 'br', 'was', 'as', 'for', 'with', 'movie', 'but'], dtype='<U14')
Once the vocabulary is set, the layer can encode text into indices. The tensors of indices are 0-padded to the longest sequence in the batch (unless you set a fixed output_sequence_length):
encoded_example = encoder(example)[:3].numpy()
encoded_example
array([[ 10, 14, 1, ..., 0, 0, 0], [106, 5, 2, ..., 0, 0, 0], [ 1, 33, 1, ..., 0, 0, 0]])
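If you would rather have fixed-length sequences, a minimal sketch with output_sequence_length set looks like this (an option this tutorial does not use; the length 250 is an arbitrary choice):

# Sketch only: pad or truncate every review to exactly 250 tokens.
fixed_length_encoder = tf.keras.layers.TextVectorization(
    max_tokens=VOCAB_SIZE,
    output_sequence_length=250)
fixed_length_encoder.adapt(train_dataset.map(lambda text, label: text))
print(fixed_length_encoder(example).shape)  # (batch_size, 250)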
With the default settings, the process is not completely reversible. There are two main reasons for that:
- The default value for TextVectorization's standardize argument is "lower_and_strip_punctuation".
- The limited vocabulary size and lack of character-based fallback results in some unknown tokens.
for n in range(3):
  print("Original: ", example[n].numpy())
  print("Round-trip: ", " ".join(vocab[encoded_example[n]]))
  print()
Original: b'I was raised in a "very Christian" household since birth. I was saved before I saw this movie and the rest of the series and was forced to watch it in a youth group at my church. This movie was highly disturbing. I saw it when I was about 12 years old and literally had nightmares about it for years. I used to lay awake in bed and listen for the sounds of my mom\'s footsteps upstairs. If I didn\'t hear her footsteps, I would sneak upstairs to make sure she hadn\'t been raptured. I used to pray so hard every night for salvation because I was terrified of Jesus forgetting me. This is definitely not something I will show to my kids until they are much older, if at all. It took me years to shake the fear that this movie gave me.' Round-trip: i was [UNK] in a very [UNK] [UNK] since [UNK] i was [UNK] before i saw this movie and the rest of the series and was forced to watch it in a [UNK] group at my [UNK] this movie was highly [UNK] i saw it when i was about [UNK] years old and [UNK] had [UNK] about it for years i used to [UNK] [UNK] in [UNK] and [UNK] for the sounds of my [UNK] [UNK] [UNK] if i didnt hear her [UNK] i would [UNK] [UNK] to make sure she [UNK] been [UNK] i used to [UNK] so hard every night for [UNK] because i was [UNK] of [UNK] [UNK] me this is definitely not something i will show to my kids until they are much older if at all it took me years to [UNK] the [UNK] that this movie gave me Original: b"Many of the American people would say...What??? to my opening comment. Yes I know that my comparison is without doubts an insult for the fans of the Master Akira Kurosawa, but if you analyze this movie, my comment is right. We have the peasant who goes to the town searching for help against a band of grasshoppers who wants to steal the harvest of the village. The great difference is the way that the story takes. Our samurais, a band of circus performers as in the original are a very complex mixture of personalities but at the end are what the village needs, HEROES. Please watch again this incredible movie (the Seven Samurai, obviously) and find another movies who has stolen the story and tried to get the same magic effect than the Masterpiece of Akira Kurosawa. A tip is The 13th Warrior with Antonio Banderas, Michael Crichton copied the story to wrote his Best seller's, but he didn't found the third foot of the cat." Round-trip: many of the american people would [UNK] to my opening comment yes i know that my [UNK] is without [UNK] an [UNK] for the fans of the [UNK] [UNK] [UNK] but if you [UNK] this movie my comment is right we have the [UNK] who goes to the town [UNK] for help against a [UNK] of [UNK] who wants to [UNK] the [UNK] of the [UNK] the great [UNK] is the way that the story takes our [UNK] a [UNK] of [UNK] [UNK] as in the original are a very [UNK] [UNK] of [UNK] but at the end are what the [UNK] needs [UNK] please watch again this [UNK] movie the [UNK] [UNK] obviously and find another movies who has [UNK] the story and tried to get the same [UNK] effect than the [UNK] of [UNK] [UNK] a [UNK] is the [UNK] [UNK] with [UNK] [UNK] michael [UNK] [UNK] the story to [UNK] his best [UNK] but he didnt found the third [UNK] of the [UNK] Original: b'Overshadowed by "Braveheart" released the same year, the two costume dramas beg comparison. I admit my bias against Mel Gibson, yet I maintain a rational preference for "Rob Roy." Both "Braveheart" and "Rob Roy" compellingly depict Scots history in bloody, romantic fashion. 
"Braveheart" is an epic paean to individual honor and courage and a fine revenge fantasy. It\'s also melodramatic, anachronistic and maudlin. Note its cornball usage of slow motion filming. Its violence is both ugly and glorious. It is the latter quality which makes it more appealing to the adolescent mindset. While "Braveheart" surpasses "Rob Roy" in sheer levels of carnage (not to mention its indulgent running time), the latter film is ultimately more mature and satisfying. Its action is more understated, yet more surprising and clever. Its sex is less showy, yet more erotic. "Rob Roy" also has a better realized romantic interest. Its dialog attempts to approximate the poetry of the period. Its rotted teeth in the mouths of the actors attempt to approximate the dentistry of the era. And Tim Roth is a superlative villain. Also recommended: "The Last of the Mohicans" and "The Patriot." You may find the latter more akin to "Braveheart" with its emphasis on blood lust, with the former more similar to "Rob Roy" in tone. All the of the aforementioned movies merit their R ratings for violence.' Round-trip: [UNK] by [UNK] released the same year the two [UNK] [UNK] [UNK] [UNK] i admit my [UNK] against [UNK] [UNK] yet i [UNK] a [UNK] [UNK] for [UNK] [UNK] both [UNK] and [UNK] [UNK] [UNK] [UNK] [UNK] history in [UNK] romantic [UNK] [UNK] is an [UNK] [UNK] to [UNK] [UNK] and [UNK] and a fine [UNK] fantasy its also [UNK] [UNK] and [UNK] note its [UNK] [UNK] of slow [UNK] [UNK] its violence is both [UNK] and [UNK] it is the [UNK] quality which makes it more [UNK] to the [UNK] [UNK] while [UNK] [UNK] [UNK] [UNK] in [UNK] [UNK] of [UNK] not to mention its [UNK] running time the [UNK] film is [UNK] more [UNK] and [UNK] its action is more [UNK] yet more [UNK] and [UNK] its sex is less [UNK] yet more [UNK] [UNK] [UNK] also has a better [UNK] romantic interest its dialog attempts to [UNK] the [UNK] of the period its [UNK] [UNK] in the [UNK] of the actors attempt to [UNK] the [UNK] of the [UNK] and [UNK] [UNK] is a [UNK] [UNK] also [UNK] the last of the [UNK] and the [UNK] you may find the [UNK] more [UNK] to [UNK] with its [UNK] on blood [UNK] with the [UNK] more similar to [UNK] [UNK] in [UNK] all the of the [UNK] movies [UNK] their [UNK] [UNK] for violence
Create the model
Above is a diagram of the model.
This model can be built as a tf.keras.Sequential.
- The first layer is the encoder, which converts the text to a sequence of token indices.
- After the encoder is an embedding layer. An embedding layer stores one vector per word. When called, it converts the sequences of word indices to sequences of vectors. These vectors are trainable. After training (on enough data), words with similar meanings often have similar vectors.
  This index-lookup is much more efficient than the equivalent operation of passing a one-hot encoded vector through a tf.keras.layers.Dense layer (a small numeric check follows this list).
- A recurrent neural network (RNN) processes sequence input by iterating through the elements. RNNs pass the outputs from one timestep to their input on the next timestep.
  The tf.keras.layers.Bidirectional wrapper can also be used with an RNN layer. This propagates the input forward and backwards through the RNN layer and then concatenates the final output.
  - The main advantage of a bidirectional RNN is that the signal from the beginning of the input doesn't need to be processed all the way through every timestep to affect the output.
  - The main disadvantage of a bidirectional RNN is that you can't efficiently stream predictions as words are being added to the end.
- After the RNN has converted the sequence to a single vector, the two tf.keras.layers.Dense layers do some final processing and convert from this vector representation to a single logit as the classification output.
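As an aside on the index-lookup point above, here is a small sketch (illustration only, not part of the model; the layer sizes are arbitrary) showing that an embedding lookup matches multiplying a one-hot vector by the embedding matrix:

# Sketch: an Embedding lookup is the same computation as one-hot x weights,
# without ever materializing the one-hot vectors.
demo_embedding = tf.keras.layers.Embedding(input_dim=1000, output_dim=64)
ids = tf.constant([[3, 7, 0]])
by_lookup = demo_embedding(ids)                       # builds the layer's weights
one_hot = tf.one_hot(ids, depth=1000)
by_matmul = tf.tensordot(one_hot, demo_embedding.embeddings, axes=1)
print(tf.reduce_max(tf.abs(by_lookup - by_matmul)).numpy())  # ~0.0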
The code to implement this is below:
model = tf.keras.Sequential([
    encoder,
    tf.keras.layers.Embedding(
        input_dim=len(encoder.get_vocabulary()),
        output_dim=64,
        # Use masking to handle the variable sequence lengths
        mask_zero=True),
    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64)),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(1)
])
Note that a Keras sequential model is used here since all the layers in the model only have a single input and produce a single output. If you want to use a stateful RNN layer, you might want to build your model with the Keras functional API or model subclassing so that you can retrieve and reuse the RNN layer states. Check the Keras RNN guide for more details.
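For example, a minimal sketch of the same architecture built with the functional API (layer sizes carried over from the model above; this sketch does not actually wire up any states) might look like this:

# Sketch only: the equivalent model with the Keras functional API, which keeps
# a handle on every layer and tensor so RNN states could be retrieved later.
inputs = tf.keras.Input(shape=(), dtype=tf.string)
x = encoder(inputs)
x = tf.keras.layers.Embedding(
    input_dim=len(encoder.get_vocabulary()), output_dim=64, mask_zero=True)(x)
x = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64))(x)
x = tf.keras.layers.Dense(64, activation='relu')(x)
outputs = tf.keras.layers.Dense(1)(x)
functional_model = tf.keras.Model(inputs, outputs)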
The embedding layer uses masking to handle the varying sequence lengths. All the layers after the Embedding layer support masking:
print([layer.supports_masking for layer in model.layers])
[False, True, True, True, True]
To confirm that this works as expected, evaluate a sentence twice. First, alone so there's no padding to mask:
# predict on a sample text without padding.
sample_text = ('The movie was cool. The animation and the graphics '
               'were out of this world. I would recommend this movie.')
predictions = model.predict(np.array([sample_text]))
print(predictions[0])
[-0.00106721]
Now, evaluate it again in a batch with a longer sentence. The result should be identical:
# predict on a sample text with padding
padding = "the " * 2000
predictions = model.predict(np.array([sample_text, padding]))
print(predictions[0])
[-0.0010672]
Compile the Keras model to configure the training process:
model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              optimizer=tf.keras.optimizers.Adam(1e-4),
              metrics=['accuracy'])
Train the model
history = model.fit(train_dataset, epochs=10,
                    validation_data=test_dataset,
                    validation_steps=30)
Epoch 1/10
391/391 [==============================] - 43s 86ms/step - loss: 0.6528 - accuracy: 0.5573 - val_loss: 0.4968 - val_accuracy: 0.7349
Epoch 2/10
391/391 [==============================] - 30s 76ms/step - loss: 0.4547 - accuracy: 0.7858 - val_loss: 0.4449 - val_accuracy: 0.8245
Epoch 3/10
391/391 [==============================] - 32s 80ms/step - loss: 0.4006 - accuracy: 0.8239 - val_loss: 0.3702 - val_accuracy: 0.8438
Epoch 4/10
391/391 [==============================] - 31s 76ms/step - loss: 0.3441 - accuracy: 0.8528 - val_loss: 0.3450 - val_accuracy: 0.8557
Epoch 5/10
391/391 [==============================] - 31s 75ms/step - loss: 0.3254 - accuracy: 0.8614 - val_loss: 0.3418 - val_accuracy: 0.8625
Epoch 6/10
391/391 [==============================] - 30s 74ms/step - loss: 0.3153 - accuracy: 0.8668 - val_loss: 0.3265 - val_accuracy: 0.8552
Epoch 7/10
391/391 [==============================] - 31s 75ms/step - loss: 0.3075 - accuracy: 0.8700 - val_loss: 0.3228 - val_accuracy: 0.8620
Epoch 8/10
391/391 [==============================] - 30s 74ms/step - loss: 0.3022 - accuracy: 0.8713 - val_loss: 0.3259 - val_accuracy: 0.8542
Epoch 9/10
391/391 [==============================] - 30s 74ms/step - loss: 0.3008 - accuracy: 0.8727 - val_loss: 0.3189 - val_accuracy: 0.8630
Epoch 10/10
391/391 [==============================] - 30s 75ms/step - loss: 0.2972 - accuracy: 0.8736 - val_loss: 0.3240 - val_accuracy: 0.8641
test_loss, test_acc = model.evaluate(test_dataset)
print('Test Loss:', test_loss)
print('Test Accuracy:', test_acc)
391/391 [==============================] - 14s 35ms/step - loss: 0.3219 - accuracy: 0.8642
Test Loss: 0.32192152738571167
Test Accuracy: 0.8641999959945679
plt.figure(figsize=(16, 8))
plt.subplot(1, 2, 1)
plot_graphs(history, 'accuracy')
plt.ylim(None, 1)
plt.subplot(1, 2, 2)
plot_graphs(history, 'loss')
plt.ylim(0, None)
(0.0, 0.6705824166536332)
Run a prediction on a new sentence:
If the prediction is >= 0.0, it is positive; otherwise it is negative.
sample_text = ('The movie was cool. The animation and the graphics '
               'were out of this world. I would recommend this movie.')
predictions = model.predict(np.array([sample_text]))
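To make the >= 0.0 rule concrete, here is a small follow-on sketch (not part of the original cell) that turns the logit into a probability and a label:

# Sketch: the model outputs a single logit; sigmoid maps it to [0, 1],
# so logit >= 0.0 corresponds to probability >= 0.5, i.e. "positive".
logit = predictions[0][0]
print('probability:', tf.sigmoid(logit).numpy())
print('sentiment:', 'positive' if logit >= 0.0 else 'negative')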
Stack two or more LSTM layers
Keras recurrent layers have two available modes that are controlled by the return_sequences constructor argument:
- If False it returns only the last output for each input sequence (a 2D tensor of shape (batch_size, output_features)). This is the default, used in the previous model.
- If True the full sequences of successive outputs for each timestep are returned (a 3D tensor of shape (batch_size, timesteps, output_features)), as illustrated in the brief shape check after this list.
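Here is that shape check (a standalone sketch with made-up dimensions, not the tutorial's model) comparing the two modes:

# Sketch: compare the output shapes of an LSTM with and without
# return_sequences, using random inputs of shape (batch, timesteps, features).
demo_inputs = tf.random.normal([64, 100, 32])
last_only = tf.keras.layers.LSTM(16)(demo_inputs)                       # default mode
full_sequence = tf.keras.layers.LSTM(16, return_sequences=True)(demo_inputs)
print(last_only.shape)      # (64, 16)
print(full_sequence.shape)  # (64, 100, 16)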
Here is what the flow of information looks like with return_sequences=True:
The interesting thing about using an RNN with return_sequences=True is that the output still has three axes, like the input, so it can be passed to another RNN layer, like this:
model = tf.keras.Sequential([
    encoder,
    tf.keras.layers.Embedding(len(encoder.get_vocabulary()), 64, mask_zero=True),
    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64, return_sequences=True)),
    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(32)),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(1)
])
model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              optimizer=tf.keras.optimizers.Adam(1e-4),
              metrics=['accuracy'])

history = model.fit(train_dataset, epochs=10,
                    validation_data=test_dataset,
                    validation_steps=30)
Epoch 1/10
391/391 [==============================] - 71s 145ms/step - loss: 0.6166 - accuracy: 0.6006 - val_loss: 0.4270 - val_accuracy: 0.8125
Epoch 2/10
391/391 [==============================] - 52s 131ms/step - loss: 0.3837 - accuracy: 0.8346 - val_loss: 0.3605 - val_accuracy: 0.8417
Epoch 3/10
391/391 [==============================] - 54s 134ms/step - loss: 0.3345 - accuracy: 0.8572 - val_loss: 0.3385 - val_accuracy: 0.8396
Epoch 4/10
391/391 [==============================] - 52s 130ms/step - loss: 0.3204 - accuracy: 0.8643 - val_loss: 0.3282 - val_accuracy: 0.8568
Epoch 5/10
391/391 [==============================] - 52s 131ms/step - loss: 0.3103 - accuracy: 0.8678 - val_loss: 0.3253 - val_accuracy: 0.8547
Epoch 6/10
391/391 [==============================] - 52s 130ms/step - loss: 0.3068 - accuracy: 0.8693 - val_loss: 0.3233 - val_accuracy: 0.8536
Epoch 7/10
391/391 [==============================] - 52s 131ms/step - loss: 0.3034 - accuracy: 0.8683 - val_loss: 0.3234 - val_accuracy: 0.8641
Epoch 8/10
391/391 [==============================] - 53s 133ms/step - loss: 0.2982 - accuracy: 0.8724 - val_loss: 0.3151 - val_accuracy: 0.8599
Epoch 9/10
391/391 [==============================] - 52s 130ms/step - loss: 0.2938 - accuracy: 0.8734 - val_loss: 0.3211 - val_accuracy: 0.8510
Epoch 10/10
391/391 [==============================] - 51s 130ms/step - loss: 0.2947 - accuracy: 0.8740 - val_loss: 0.3337 - val_accuracy: 0.8651
test_loss, test_acc = model.evaluate(test_dataset)
print('Test Loss:', test_loss)
print('Test Accuracy:', test_acc)
391/391 [==============================] - 24s 61ms/step - loss: 0.3354 - accuracy: 0.8635
Test Loss: 0.33541229367256165
Test Accuracy: 0.8635200262069702
# predict on a sample text without padding.
sample_text = ('The movie was not good. The animation and the graphics '
               'were terrible. I would not recommend this movie.')
predictions = model.predict(np.array([sample_text]))
print(predictions)
[[-1.8224874]]
plt.figure(figsize=(16, 6))
plt.subplot(1, 2, 1)
plot_graphs(history, 'accuracy')
plt.subplot(1, 2, 2)
plot_graphs(history, 'loss')
Check out other existing recurrent layers such as GRU layers.
If you're interested in building custom RNNs, see the Keras RNN Guide.
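To follow up on the GRU suggestion, here is a hedged sketch of the first model with the LSTM swapped for a GRU (the layer sizes are carried over from above as an assumption):

# Sketch only: the first model's architecture with a GRU in place of the LSTM.
gru_model = tf.keras.Sequential([
    encoder,
    tf.keras.layers.Embedding(len(encoder.get_vocabulary()), 64, mask_zero=True),
    tf.keras.layers.Bidirectional(tf.keras.layers.GRU(64)),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(1)
])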