This text classification tutorial trains a recurrent neural network on the IMDB large movie review dataset for sentiment analysis.
Setup
import numpy as np
import tensorflow_datasets as tfds
import tensorflow as tf
tfds.disable_progress_bar()
Import matplotlib and create a helper function to plot graphs:
import matplotlib.pyplot as plt
def plot_graphs(history, metric):
  plt.plot(history.history[metric])
  plt.plot(history.history['val_'+metric], '')
  plt.xlabel("Epochs")
  plt.ylabel(metric)
  plt.legend([metric, 'val_'+metric])
Setup input pipeline
The IMDB large movie review dataset is a binary classification dataset—all the reviews have either a positive or negative sentiment.
Download the dataset using TFDS. See the loading text tutorial for details on how to load this sort of data manually.
dataset, info = tfds.load('imdb_reviews', with_info=True,
                          as_supervised=True)
train_dataset, test_dataset = dataset['train'], dataset['test']
train_dataset.element_spec
(TensorSpec(shape=(), dtype=tf.string, name=None), TensorSpec(shape=(), dtype=tf.int64, name=None))
Initially this returns a dataset of (text, label) pairs:
for example, label in train_dataset.take(1):
  print('text: ', example.numpy())
  print('label: ', label.numpy())
text: b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it." label: 0
Next, shuffle the data for training and create batches of these (text, label) pairs:
BUFFER_SIZE = 10000
BATCH_SIZE = 64
train_dataset = train_dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
test_dataset = test_dataset.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
for example, label in train_dataset.take(1):
  print('texts: ', example.numpy()[:3])
  print()
  print('labels: ', label.numpy()[:3])
texts: [b'Watching beautiful women sneaking around, playing cops and robbers is one of the most delightful guilty pleasures the medium film lets me enjoy. So The House on Carroll Street was not entirely a waste of time, although the story is contrived and the screenplay uninspired and somewhat irritating.<br /><br />There are many allusions to different Hitchcock pictures, not least the choice of Kelly McGillis in the starring role. She is dressed up as Grace Kelly, and she is not far off the mark. Not at all. But her character is not convincing. The way she is introduced to the audience, she should be someone with political convictions and a purpose in life. After all the movie deals with a clearly defined time period, true events and a specific issue. But the story degenerates within the first minutes into a sorry run-off-the-mill crime story with unbelievable coincidences, high predictability and a set of two dimensional characters. This is all the more regrettable, as the performances of the actors are good, as are the photography and the set design.<br /><br />The finale in Central Station, New York is breath taking. It starts in the subterranean section and then moves up to the roof. The movie can be praised for its good use of architecture.' b'A group of people are invited to there high school reunion, but after they arrive they discover it to be a scam by an old classmate they played an almost fatal prank on. Now, he seeks to get revenge on all those that hurt him by sealing all the exits and cutting off all telephone lines.<br /><br />Dark slasher film with an unexceptional premise. Bringing it up a notch are a few good performances, some rather creative death scenes, plenty of excitement & scares, some humor and an original ending.<br /><br />Unrated for Extreme Violence, Graphic Nudity, Sexual Situations, Profanity and Drug Use.' b'The short that starts this film is the true footage of a guy named Gary, apparently it was taken randomly in the parking lot of a television station where Gary works in the town of Beaver. Gary is a little "different"; he is an impersonator and drives an old Chevy named Farrah (after Fawcett). Lo and behold the filmmaker gets a letter from Gary some time later inviting him to return to Beaver to get some footage of the local talent contest he has put together, including Gary\'s staggering performace as Olivia Newton Dawn. Oh, my. The two shorts that follow are Gary\'s story, the same one you just witnessed only the first is portrayed by Sean Penn and the second by Crispin Glover titled "The Orkly Kid." If you are in the mood for making fun of someone this is definitely the film to watch. I was doubled over with laughter through most of it, especially Crispins performance which could definitely stand on it\'s own. When it was over, I had to rewind the film to once again watch the real Gary and all his shining idiocy. Although Olivia was the focus, I would have liked to have seen one of the "fictitious" shorts take a jab at Gary\'s Barry Manilow impersonation, whic h was equally ridiculous.'] labels: [0 0 1]
Create the text encoder
The raw text loaded by tfds needs to be processed before it can be used in a model. The simplest way to process text for training is using the TextVectorization layer. This layer has many capabilities, but this tutorial sticks to the default behavior.
Create the layer, and pass the dataset's text to the layer's .adapt method:
VOCAB_SIZE = 1000
encoder = tf.keras.layers.TextVectorization(
    max_tokens=VOCAB_SIZE)
encoder.adapt(train_dataset.map(lambda text, label: text))
The .adapt method sets the layer's vocabulary. Here are the first 20 tokens. After the padding and unknown tokens, they're sorted by frequency:
vocab = np.array(encoder.get_vocabulary())
vocab[:20]
array(['', '[UNK]', 'the', 'and', 'a', 'of', 'to', 'is', 'in', 'it', 'i', 'this', 'that', 'br', 'was', 'as', 'for', 'with', 'movie', 'but'], dtype='<U14')
Once the vocabulary is set, the layer can encode text into indices. The tensors of indices are 0-padded to the longest sequence in the batch (unless you set a fixed output_sequence_length):
encoded_example = encoder(example)[:3].numpy()
encoded_example
array([[147, 300, 362, ..., 0, 0, 0], [ 4, 579, 5, ..., 0, 0, 0], [ 2, 348, 12, ..., 0, 0, 0]])
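As an illustrative sketch of the fixed-length alternative mentioned above (not part of the original notebook; the length 250 and the name fixed_encoder are arbitrary choices), setting output_sequence_length pads or truncates every review to the same length instead of padding to the longest sequence in the batch:
fixed_encoder = tf.keras.layers.TextVectorization(
    max_tokens=VOCAB_SIZE,
    output_sequence_length=250)
fixed_encoder.adapt(train_dataset.map(lambda text, label: text))
# Every row now has exactly 250 token indices, padded or truncated as needed.
print(fixed_encoder(example).shape)  # (batch_size, 250)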
With the default settings, the process is not completely reversible. There are two main reasons for that:
- The default value for the TextVectorization layer's standardize argument is "lower_and_strip_punctuation".
- The limited vocabulary size and lack of character-based fallback results in some unknown tokens.
for n in range(3):
  print("Original: ", example[n].numpy())
  print("Round-trip: ", " ".join(vocab[encoded_example[n]]))
  print()
Original: b'Watching beautiful women sneaking around, playing cops and robbers is one of the most delightful guilty pleasures the medium film lets me enjoy. So The House on Carroll Street was not entirely a waste of time, although the story is contrived and the screenplay uninspired and somewhat irritating.<br /><br />There are many allusions to different Hitchcock pictures, not least the choice of Kelly McGillis in the starring role. She is dressed up as Grace Kelly, and she is not far off the mark. Not at all. But her character is not convincing. The way she is introduced to the audience, she should be someone with political convictions and a purpose in life. After all the movie deals with a clearly defined time period, true events and a specific issue. But the story degenerates within the first minutes into a sorry run-off-the-mill crime story with unbelievable coincidences, high predictability and a set of two dimensional characters. This is all the more regrettable, as the performances of the actors are good, as are the photography and the set design.<br /><br />The finale in Central Station, New York is breath taking. It starts in the subterranean section and then moves up to the roof. The movie can be praised for its good use of architecture.' Round-trip: watching beautiful women [UNK] around playing [UNK] and [UNK] is one of the most [UNK] [UNK] [UNK] the [UNK] film lets me enjoy so the house on [UNK] street was not [UNK] a waste of time although the story is [UNK] and the screenplay [UNK] and somewhat [UNK] br there are many [UNK] to different [UNK] [UNK] not least the [UNK] of [UNK] [UNK] in the [UNK] role she is [UNK] up as [UNK] [UNK] and she is not far off the mark not at all but her character is not [UNK] the way she is [UNK] to the audience she should be someone with political [UNK] and a [UNK] in life after all the movie [UNK] with a clearly [UNK] time period true events and a [UNK] [UNK] but the story [UNK] within the first minutes into a sorry [UNK] crime story with [UNK] [UNK] high [UNK] and a set of two [UNK] characters this is all the more [UNK] as the performances of the actors are good as are the [UNK] and the set [UNK] br the [UNK] in [UNK] [UNK] new york is [UNK] taking it starts in the [UNK] [UNK] and then [UNK] up to the [UNK] the movie can be [UNK] for its good use of [UNK] Original: b'A group of people are invited to there high school reunion, but after they arrive they discover it to be a scam by an old classmate they played an almost fatal prank on. Now, he seeks to get revenge on all those that hurt him by sealing all the exits and cutting off all telephone lines.<br /><br />Dark slasher film with an unexceptional premise. Bringing it up a notch are a few good performances, some rather creative death scenes, plenty of excitement & scares, some humor and an original ending.<br /><br />Unrated for Extreme Violence, Graphic Nudity, Sexual Situations, Profanity and Drug Use.' 
Round-trip: a group of people are [UNK] to there high school [UNK] but after they [UNK] they [UNK] it to be a [UNK] by an old [UNK] they played an almost [UNK] [UNK] on now he [UNK] to get [UNK] on all those that [UNK] him by [UNK] all the [UNK] and [UNK] off all [UNK] [UNK] br dark [UNK] film with an [UNK] premise [UNK] it up a [UNK] are a few good performances some rather [UNK] death scenes plenty of [UNK] [UNK] some humor and an original [UNK] br [UNK] for [UNK] violence [UNK] [UNK] sexual [UNK] [UNK] and [UNK] use Original: b'The short that starts this film is the true footage of a guy named Gary, apparently it was taken randomly in the parking lot of a television station where Gary works in the town of Beaver. Gary is a little "different"; he is an impersonator and drives an old Chevy named Farrah (after Fawcett). Lo and behold the filmmaker gets a letter from Gary some time later inviting him to return to Beaver to get some footage of the local talent contest he has put together, including Gary\'s staggering performace as Olivia Newton Dawn. Oh, my. The two shorts that follow are Gary\'s story, the same one you just witnessed only the first is portrayed by Sean Penn and the second by Crispin Glover titled "The Orkly Kid." If you are in the mood for making fun of someone this is definitely the film to watch. I was doubled over with laughter through most of it, especially Crispins performance which could definitely stand on it\'s own. When it was over, I had to rewind the film to once again watch the real Gary and all his shining idiocy. Although Olivia was the focus, I would have liked to have seen one of the "fictitious" shorts take a jab at Gary\'s Barry Manilow impersonation, whic h was equally ridiculous.' Round-trip: the short that starts this film is the true footage of a guy named [UNK] apparently it was taken [UNK] in the [UNK] lot of a television [UNK] where [UNK] works in the town of [UNK] [UNK] is a little different he is an [UNK] and [UNK] an old [UNK] named [UNK] after [UNK] [UNK] and [UNK] the [UNK] gets a [UNK] from [UNK] some time later [UNK] him to return to [UNK] to get some footage of the local talent [UNK] he has put together including [UNK] [UNK] [UNK] as [UNK] [UNK] [UNK] oh my the two [UNK] that follow are [UNK] story the same one you just [UNK] only the first is portrayed by [UNK] [UNK] and the second by [UNK] [UNK] [UNK] the [UNK] kid if you are in the [UNK] for making fun of someone this is definitely the film to watch i was [UNK] over with [UNK] through most of it especially [UNK] performance which could definitely stand on its own when it was over i had to [UNK] the film to once again watch the real [UNK] and all his [UNK] [UNK] although [UNK] was the [UNK] i would have liked to have seen one of the [UNK] [UNK] take a [UNK] at [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] was [UNK] ridiculous
Create the model
Above is a diagram of the model.
- This model can be built as a tf.keras.Sequential.
- The first layer is the encoder, which converts the text to a sequence of token indices.
- After the encoder is an embedding layer. An embedding layer stores one vector per word. When called, it converts the sequences of word indices to sequences of vectors. These vectors are trainable. After training (on enough data), words with similar meanings often have similar vectors. This index-lookup is much more efficient than the equivalent operation of passing a one-hot encoded vector through a tf.keras.layers.Dense layer (see the sketch after this list).
- A recurrent neural network (RNN) processes sequence input by iterating through the elements. RNNs pass the outputs from one timestep to their input on the next timestep.
- The tf.keras.layers.Bidirectional wrapper can also be used with an RNN layer. This propagates the input forward and backwards through the RNN layer and then concatenates the final output. The main advantage of a bidirectional RNN is that the signal from the beginning of the input doesn't need to be processed all the way through every timestep to affect the output. The main disadvantage is that you can't efficiently stream predictions as words are being added to the end.
- After the RNN has converted the sequence to a single vector, the two tf.keras.layers.Dense layers do some final processing and convert this vector representation to a single logit as the classification output.
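To make the efficiency claim about the embedding lookup concrete, here is a small illustrative sketch (not part of the model in this tutorial; the tiny vocabulary size and dimensions are arbitrary) showing that the lookup produces the same result as multiplying a one-hot vector by the layer's weight matrix:
tiny_embedding = tf.keras.layers.Embedding(input_dim=5, output_dim=3)
token_ids = tf.constant([2, 4])            # two token indices
looked_up = tiny_embedding(token_ids)      # direct index lookup, shape (2, 3)
# The same result via a one-hot encoding and a matrix multiply:
one_hot = tf.one_hot(token_ids, depth=5)                    # shape (2, 5)
multiplied = tf.matmul(one_hot, tiny_embedding.embeddings)  # shape (2, 3)
print(np.allclose(looked_up.numpy(), multiplied.numpy()))   # True
The direct lookup avoids materializing the large, mostly-zero one-hot matrix, which is why it scales well to realistic vocabulary sizes.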
The code to implement this is below:
model = tf.keras.Sequential([
    encoder,
    tf.keras.layers.Embedding(
        input_dim=len(encoder.get_vocabulary()),
        output_dim=64,
        # Use masking to handle the variable sequence lengths
        mask_zero=True),
    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64)),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(1)
])
Note that a Keras Sequential model is used here since all the layers in the model only have a single input and produce a single output. If you want to use a stateful RNN layer, you might want to build your model with the Keras functional API or model subclassing so that you can retrieve and reuse the RNN layer states. Check the Keras RNN guide for more details.
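As a rough, hedged sketch of that alternative (not used in the rest of this tutorial), a simplified, single-direction version built with the functional API could expose the final LSTM states as extra outputs; the names functional_model and state_model are illustrative only:
inputs = tf.keras.Input(shape=(), dtype=tf.string)
embedded = tf.keras.layers.Embedding(
    input_dim=len(encoder.get_vocabulary()), output_dim=64, mask_zero=True)(encoder(inputs))
# return_state=True also returns the final hidden and cell states.
lstm_output, state_h, state_c = tf.keras.layers.LSTM(64, return_state=True)(embedded)
hidden = tf.keras.layers.Dense(64, activation='relu')(lstm_output)
logit = tf.keras.layers.Dense(1)(hidden)
functional_model = tf.keras.Model(inputs, logit)
# A second model over the same graph retrieves the states for reuse.
state_model = tf.keras.Model(inputs, [state_h, state_c])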
The embedding layer uses masking to handle the varying sequence lengths. All the layers after the Embedding support masking:
print([layer.supports_masking for layer in model.layers])
[False, True, True, True, True]
To confirm that this works as expected, evaluate a sentence twice. First, alone so there's no padding to mask:
# predict on a sample text without padding.
sample_text = ('The movie was cool. The animation and the graphics '
               'were out of this world. I would recommend this movie.')
predictions = model.predict(np.array([sample_text]))
print(predictions[0])
1/1 [==============================] - 3s 3s/step
[0.00856274]
Now, evaluate it again in a batch with a longer sentence. The result should be identical:
# predict on a sample text with padding
padding = "the " * 2000
predictions = model.predict(np.array([sample_text, padding]))
print(predictions[0])
1/1 [==============================] - 0s 86ms/step
[0.00856275]
Compile the Keras model to configure the training process:
model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              optimizer=tf.keras.optimizers.Adam(1e-4),
              metrics=['accuracy'])
Train the model
history = model.fit(train_dataset, epochs=10,
                    validation_data=test_dataset,
                    validation_steps=30)
Epoch 1/10
391/391 [==============================] - 43s 88ms/step - loss: 0.6566 - accuracy: 0.5580 - val_loss: 0.5489 - val_accuracy: 0.7505
Epoch 2/10
391/391 [==============================] - 21s 54ms/step - loss: 0.4354 - accuracy: 0.7937 - val_loss: 0.3724 - val_accuracy: 0.8234
Epoch 3/10
391/391 [==============================] - 22s 54ms/step - loss: 0.3451 - accuracy: 0.8468 - val_loss: 0.3403 - val_accuracy: 0.8521
Epoch 4/10
391/391 [==============================] - 21s 52ms/step - loss: 0.3224 - accuracy: 0.8601 - val_loss: 0.3332 - val_accuracy: 0.8573
Epoch 5/10
391/391 [==============================] - 21s 52ms/step - loss: 0.3168 - accuracy: 0.8623 - val_loss: 0.3291 - val_accuracy: 0.8620
Epoch 6/10
391/391 [==============================] - 21s 52ms/step - loss: 0.3088 - accuracy: 0.8658 - val_loss: 0.3370 - val_accuracy: 0.8615
Epoch 7/10
391/391 [==============================] - 22s 52ms/step - loss: 0.3060 - accuracy: 0.8692 - val_loss: 0.3271 - val_accuracy: 0.8448
Epoch 8/10
391/391 [==============================] - 21s 52ms/step - loss: 0.3033 - accuracy: 0.8714 - val_loss: 0.3249 - val_accuracy: 0.8583
Epoch 9/10
391/391 [==============================] - 21s 51ms/step - loss: 0.3017 - accuracy: 0.8695 - val_loss: 0.3293 - val_accuracy: 0.8385
Epoch 10/10
391/391 [==============================] - 21s 52ms/step - loss: 0.2995 - accuracy: 0.8717 - val_loss: 0.3217 - val_accuracy: 0.8630
test_loss, test_acc = model.evaluate(test_dataset)
print('Test Loss:', test_loss)
print('Test Accuracy:', test_acc)
391/391 [==============================] - 9s 23ms/step - loss: 0.3167 - accuracy: 0.8624
Test Loss: 0.3167201280593872
Test Accuracy: 0.8623600006103516
plt.figure(figsize=(16, 8))
plt.subplot(1, 2, 1)
plot_graphs(history, 'accuracy')
plt.ylim(None, 1)
plt.subplot(1, 2, 2)
plot_graphs(history, 'loss')
plt.ylim(0, None)
Run a prediction on a new sentence:
If the prediction is >= 0.0, it is positive; otherwise, it is negative.
sample_text = ('The movie was cool. The animation and the graphics '
               'were out of this world. I would recommend this movie.')
predictions = model.predict(np.array([sample_text]))
1/1 [==============================] - 2s 2s/step
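As a small illustrative addition (not in the original notebook), the raw logit can be mapped to a label with the rule above:
# Positive if the logit is >= 0.0, negative otherwise.
print('Positive' if predictions[0][0] >= 0.0 else 'Negative')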
Stack two or more LSTM layers
Keras recurrent layers have two available modes that are controlled by the return_sequences constructor argument:
- If False, it returns only the last output for each input sequence (a 2D tensor of shape (batch_size, output_features)). This is the default, used in the previous model.
- If True, the full sequence of successive outputs for each timestep is returned (a 3D tensor of shape (batch_size, timesteps, output_features)), as shown in the shape check below.
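A minimal sketch of the resulting output shapes (the dimensions here are arbitrary illustrative values, not part of the tutorial's model):
dummy_input = tf.random.uniform((32, 10, 8))  # (batch_size, timesteps, features)
print(tf.keras.layers.LSTM(4)(dummy_input).shape)                         # (32, 4)
print(tf.keras.layers.LSTM(4, return_sequences=True)(dummy_input).shape)  # (32, 10, 4)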
Here is what the flow of information looks like with return_sequences=True:
The interesting thing about using an RNN with return_sequences=True is that the output still has 3 axes, like the input, so it can be passed to another RNN layer, like this:
model = tf.keras.Sequential([
    encoder,
    tf.keras.layers.Embedding(len(encoder.get_vocabulary()), 64, mask_zero=True),
    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64, return_sequences=True)),
    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(32)),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(1)
])
model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              optimizer=tf.keras.optimizers.Adam(1e-4),
              metrics=['accuracy'])
history = model.fit(train_dataset, epochs=10,
                    validation_data=test_dataset,
                    validation_steps=30)
Epoch 1/10
391/391 [==============================] - 66s 131ms/step - loss: 0.6284 - accuracy: 0.5935 - val_loss: 0.4341 - val_accuracy: 0.8031
Epoch 2/10
391/391 [==============================] - 40s 101ms/step - loss: 0.3818 - accuracy: 0.8336 - val_loss: 0.3429 - val_accuracy: 0.8474
Epoch 3/10
391/391 [==============================] - 40s 100ms/step - loss: 0.3369 - accuracy: 0.8557 - val_loss: 0.3489 - val_accuracy: 0.8458
Epoch 4/10
391/391 [==============================] - 40s 101ms/step - loss: 0.3265 - accuracy: 0.8590 - val_loss: 0.3239 - val_accuracy: 0.8589
Epoch 5/10
391/391 [==============================] - 39s 99ms/step - loss: 0.3123 - accuracy: 0.8678 - val_loss: 0.3265 - val_accuracy: 0.8500
Epoch 6/10
391/391 [==============================] - 40s 100ms/step - loss: 0.3072 - accuracy: 0.8690 - val_loss: 0.3242 - val_accuracy: 0.8604
Epoch 7/10
391/391 [==============================] - 40s 100ms/step - loss: 0.3060 - accuracy: 0.8673 - val_loss: 0.3211 - val_accuracy: 0.8464
Epoch 8/10
391/391 [==============================] - 40s 100ms/step - loss: 0.3011 - accuracy: 0.8724 - val_loss: 0.3169 - val_accuracy: 0.8531
Epoch 9/10
391/391 [==============================] - 39s 99ms/step - loss: 0.2973 - accuracy: 0.8717 - val_loss: 0.3248 - val_accuracy: 0.8635
Epoch 10/10
391/391 [==============================] - 40s 100ms/step - loss: 0.2953 - accuracy: 0.8734 - val_loss: 0.3242 - val_accuracy: 0.8672
test_loss, test_acc = model.evaluate(test_dataset)
print('Test Loss:', test_loss)
print('Test Accuracy:', test_acc)
391/391 [==============================] - 17s 42ms/step - loss: 0.3255 - accuracy: 0.8652
Test Loss: 0.325457364320755
Test Accuracy: 0.8652399778366089
# predict on a sample text without padding.
sample_text = ('The movie was not good. The animation and the graphics '
               'were terrible. I would not recommend this movie.')
predictions = model.predict(np.array([sample_text]))
print(predictions)
1/1 [==============================] - 4s 4s/step
[[-1.6299357]]
plt.figure(figsize=(16, 6))
plt.subplot(1, 2, 1)
plot_graphs(history, 'accuracy')
plt.subplot(1, 2, 2)
plot_graphs(history, 'loss')
Check out other existing recurrent layers such as GRU layers.
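For example, here is a hedged sketch (not trained or evaluated in this tutorial; gru_model is an illustrative name) of the same stacked architecture with tf.keras.layers.GRU in place of LSTM:
gru_model = tf.keras.Sequential([
    encoder,
    tf.keras.layers.Embedding(len(encoder.get_vocabulary()), 64, mask_zero=True),
    tf.keras.layers.Bidirectional(tf.keras.layers.GRU(64, return_sequences=True)),
    tf.keras.layers.Bidirectional(tf.keras.layers.GRU(32)),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(1)
])
gru_model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
                  optimizer=tf.keras.optimizers.Adam(1e-4),
                  metrics=['accuracy'])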
If you're interested in building custom RNNs, see the Keras RNN Guide.