Text classification with an RNN

This text classification tutorial trains a recurrent neural network on the IMDB large movie review dataset for sentiment analysis.

Setup

import numpy as np

import tensorflow_datasets as tfds
import tensorflow as tf

tfds.disable_progress_bar()

Import matplotlib and create a helper function to plot graphs:

import matplotlib.pyplot as plt


def plot_graphs(history, metric):
  plt.plot(history.history[metric])
  plt.plot(history.history['val_'+metric], '')
  plt.xlabel("Epochs")
  plt.ylabel(metric)
  plt.legend([metric, 'val_'+metric])

Setup input pipeline

The IMDB large movie review dataset is a binary classification dataset—all the reviews have either a positive or negative sentiment.

Download the dataset using TFDS. See the loading text tutorial for details on how to load this sort of data manually.

dataset, info = tfds.load('imdb_reviews', with_info=True,
                          as_supervised=True)
train_dataset, test_dataset = dataset['train'], dataset['test']

train_dataset.element_spec
(TensorSpec(shape=(), dtype=tf.string, name=None),
 TensorSpec(shape=(), dtype=tf.int64, name=None))

Initially this returns a dataset of (text, label) pairs:

for example, label in train_dataset.take(1):
  print('text: ', example.numpy())
  print('label: ', label.numpy())
text:  b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it."
label:  0

Next shuffle the data for training and create batches of these (text, label) pairs:

BUFFER_SIZE = 10000
BATCH_SIZE = 64
train_dataset = train_dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
test_dataset = test_dataset.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
for example, label in train_dataset.take(1):
  print('texts: ', example.numpy()[:3])
  print()
  print('labels: ', label.numpy()[:3])
texts:  [b"I watched this on the tube last night. The actor's involved first caught my attention. The first scenes were attention getters. Some funny some sad. Good character development. I felt that the latter third of the film diverged. If it was not for the early part of the movie I would have stopped watching. I kept watching wanting to how how it tied together.<br /><br />Unfortunately I feel that it never happened. I especially did not like the extend period that several of the character were talking yiddish (?). Was that the other shoe?<br /><br />Would I recommend? No, I think not. As other reviewers mention much of the slang is dated (60's jive) but it was not too distracting. The ending totally turned me off."
 b'Jason Bourne sits in a dusty room in with blood on his hands, trying to make sense of what he\'s just done. Meanwhile, a CIA chief in NYC outlines the agency\'s response to what\'s just happened on screen. An American flag stands proudly on the centre of his desk in the foreground of the shot, but as he speaks, it slips out of focus as his plan veers into morally dubious territory, as if it doesn\'t want to be associated with the course of action the government man decides is necessary in the interests of national security.<br /><br />This shot effectively captures the mood of the film. As well as portraying Bourne\'s quest to find out how he became Jason Bourne, Ultimatum is also an examination of the human costs of the measures taken to protect us in the interests of stability and security.<br /><br />It is also probably the best film you\'ll see in the cinema this year. <br /><br />It\'s just so intense. Bourne says to Simon Ross (Considine) "This isn\'t some newspaper story, this is real" and in the audience you almost believe him. The camera shakes, but remains steady enough for you to see everything and feel like you\'re there with Bourne as he tries to elude his pursuers, and the performances are so good that these guys seem as though they are the characters they\'re portraying, instead of just being actors performing well-written roles. The action scenes are so brutally fast-paced and well choreographed that they seem instinctive instead of planned to the minutest movement; the stunt-work is nothing short of amazing.<br /><br />The pacing is just incredible. It keeps driving forward towards its conclusion, but not so fast that it leaves you struggling to piece together the plot; the script delivers the information you need as quickly and clearly as possible before moving on to the next tense action set-piece. While they\'re often simple (the Waterloo sequence is essentially just a man on a phone being watched by a man on a phone) they\'re charged with such dramatic intensity that you can\'t take your eyes off them. The film is just so focused on powering forwards that you can\'t help being swept along by it.<br /><br />With its intense action set-pieces, brilliantly paced storyline, and intelligent examination of the decisions made in the name of national security, the Bourne series is one that accurately captures the ambiguities of our age. Ultimatum is its peak.'
 b'Mirage (1990) is a very rare horror/chiller from 1990, released here in the UK on the "New World Video" label.<br /><br />It\'s a desert based horror film about a group of young friends who are partying for the weekend, only to be killed off one by one by an unknown force who drives a menacing black truck!!! This film has some creepy scenes, and some gore here and there, but i have to say that the acting was so lame, even by low budget standards! But the film was oddly addictive and i liked it, and i never fell asleep or turned it off, which is always a good sign! I nearly gave this movie 6/10, but seeing as it built up steam along the way, had some good moments of gore and suspense, had some good nudity, and the fact that the blonde in the main female role was a hottie too! i\'ll give it 7/10.']

labels:  [0 1 1]

Create the text encoder

The raw text loaded by tfds needs to be processed before it can be used in a model. The simplest way to process text for training is using the experimental.preprocessing.TextVectorization layer. This layer has many capabilities, but this tutorial sticks to the default behavior.

Create the layer, and pass the dataset's text to the layer's .adapt method:

VOCAB_SIZE = 1000
encoder = tf.keras.layers.experimental.preprocessing.TextVectorization(
    max_tokens=VOCAB_SIZE)
encoder.adapt(train_dataset.map(lambda text, label: text))

The .adapt method sets the layer's vocabulary. Here are the first 20 tokens. After the padding and unknown tokens they're sorted by frequency:

vocab = np.array(encoder.get_vocabulary())
vocab[:20]
array(['', '[UNK]', 'the', 'and', 'a', 'of', 'to', 'is', 'in', 'it', 'i',
       'this', 'that', 'br', 'was', 'as', 'for', 'with', 'movie', 'but'],
      dtype='<U14')

Once the vocabulary is set, the layer can encode text into indices. The tensors of indices are 0-padded to the longest sequence in the batch (unless you set a fixed output_sequence_length):

encoded_example = encoder(example)[:3].numpy()
encoded_example
array([[ 10, 284,  11, ...,   0,   0,   0],
       [  1,   1,   1, ...,   0,   0,   0],
       [  1,   1,   7, ...,   0,   0,   0]])

With the default settings, the process is not completely reversible. There are two main reasons for that:

  1. The default value for preprocessing.TextVectorization's standardize argument is "lower_and_strip_punctuation".
  2. The limited vocabulary size and the lack of a character-based fallback result in some unknown tokens.

for n in range(3):
  print("Original: ", example[n].numpy())
  print("Round-trip: ", " ".join(vocab[encoded_example[n]]))
  print()
Original:  b"I watched this on the tube last night. The actor's involved first caught my attention. The first scenes were attention getters. Some funny some sad. Good character development. I felt that the latter third of the film diverged. If it was not for the early part of the movie I would have stopped watching. I kept watching wanting to how how it tied together.<br /><br />Unfortunately I feel that it never happened. I especially did not like the extend period that several of the character were talking yiddish (?). Was that the other shoe?<br /><br />Would I recommend? No, I think not. As other reviewers mention much of the slang is dated (60's jive) but it was not too distracting. The ending totally turned me off."
Round-trip:  i watched this on the [UNK] last night the actors involved first [UNK] my attention the first scenes were attention [UNK] some funny some sad good character development i felt that the [UNK] third of the film [UNK] if it was not for the early part of the movie i would have [UNK] watching i kept watching [UNK] to how how it [UNK] [UNK] br unfortunately i feel that it never happened i especially did not like the [UNK] period that several of the character were talking [UNK] was that the other [UNK] br would i recommend no i think not as other [UNK] mention much of the [UNK] is [UNK] [UNK] [UNK] but it was not too [UNK] the ending totally turned me off                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                               

Original:  b'Jason Bourne sits in a dusty room in with blood on his hands, trying to make sense of what he\'s just done. Meanwhile, a CIA chief in NYC outlines the agency\'s response to what\'s just happened on screen. An American flag stands proudly on the centre of his desk in the foreground of the shot, but as he speaks, it slips out of focus as his plan veers into morally dubious territory, as if it doesn\'t want to be associated with the course of action the government man decides is necessary in the interests of national security.<br /><br />This shot effectively captures the mood of the film. As well as portraying Bourne\'s quest to find out how he became Jason Bourne, Ultimatum is also an examination of the human costs of the measures taken to protect us in the interests of stability and security.<br /><br />It is also probably the best film you\'ll see in the cinema this year. <br /><br />It\'s just so intense. Bourne says to Simon Ross (Considine) "This isn\'t some newspaper story, this is real" and in the audience you almost believe him. The camera shakes, but remains steady enough for you to see everything and feel like you\'re there with Bourne as he tries to elude his pursuers, and the performances are so good that these guys seem as though they are the characters they\'re portraying, instead of just being actors performing well-written roles. The action scenes are so brutally fast-paced and well choreographed that they seem instinctive instead of planned to the minutest movement; the stunt-work is nothing short of amazing.<br /><br />The pacing is just incredible. It keeps driving forward towards its conclusion, but not so fast that it leaves you struggling to piece together the plot; the script delivers the information you need as quickly and clearly as possible before moving on to the next tense action set-piece. While they\'re often simple (the Waterloo sequence is essentially just a man on a phone being watched by a man on a phone) they\'re charged with such dramatic intensity that you can\'t take your eyes off them. The film is just so focused on powering forwards that you can\'t help being swept along by it.<br /><br />With its intense action set-pieces, brilliantly paced storyline, and intelligent examination of the decisions made in the name of national security, the Bourne series is one that accurately captures the ambiguities of our age. Ultimatum is its peak.'
Round-trip:  [UNK] [UNK] [UNK] in a [UNK] room in with blood on his hands trying to make sense of what hes just done [UNK] a [UNK] [UNK] in [UNK] [UNK] the [UNK] [UNK] to whats just happened on screen an american [UNK] [UNK] [UNK] on the [UNK] of his [UNK] in the [UNK] of the shot but as he [UNK] it [UNK] out of [UNK] as his [UNK] [UNK] into [UNK] [UNK] [UNK] as if it doesnt want to be [UNK] with the course of action the [UNK] man [UNK] is [UNK] in the [UNK] of [UNK] [UNK] br this shot [UNK] [UNK] the [UNK] of the film as well as [UNK] [UNK] [UNK] to find out how he became [UNK] [UNK] [UNK] is also an [UNK] of the human [UNK] of the [UNK] taken to [UNK] us in the [UNK] of [UNK] and [UNK] br it is also probably the best film youll see in the cinema this year br br its just so [UNK] [UNK] says to [UNK] [UNK] [UNK] this isnt some [UNK] story this is real and in the audience you almost believe him the camera [UNK] but [UNK] [UNK] enough for you to see everything and feel like youre there with [UNK] as he tries to [UNK] his [UNK] and the performances are so good that these guys seem as though they are the characters theyre [UNK] instead of just being actors [UNK] [UNK] roles the action scenes are so [UNK] [UNK] and well [UNK] that they seem [UNK] instead of [UNK] to the [UNK] [UNK] the [UNK] is nothing short of [UNK] br the [UNK] is just [UNK] it keeps [UNK] forward towards its [UNK] but not so fast that it leaves you [UNK] to piece together the plot the script [UNK] the [UNK] you need as quickly and clearly as possible before moving on to the next [UNK] action [UNK] while theyre often simple the [UNK] sequence is [UNK] just a man on a [UNK] being watched by a man on a [UNK] theyre [UNK] with such dramatic [UNK] that you cant take your eyes off them the film is just so [UNK] on [UNK] [UNK] that you cant help being [UNK] along by itbr br with its [UNK] action [UNK] [UNK] [UNK] storyline and [UNK] [UNK] of the [UNK] made in the name of [UNK] [UNK] the [UNK] series is one that [UNK] [UNK] the [UNK] of our age [UNK] is its [UNK]                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  

Original:  b'Mirage (1990) is a very rare horror/chiller from 1990, released here in the UK on the "New World Video" label.<br /><br />It\'s a desert based horror film about a group of young friends who are partying for the weekend, only to be killed off one by one by an unknown force who drives a menacing black truck!!! This film has some creepy scenes, and some gore here and there, but i have to say that the acting was so lame, even by low budget standards! But the film was oddly addictive and i liked it, and i never fell asleep or turned it off, which is always a good sign! I nearly gave this movie 6/10, but seeing as it built up steam along the way, had some good moments of gore and suspense, had some good nudity, and the fact that the blonde in the main female role was a hottie too! i\'ll give it 7/10.'
Round-trip:  [UNK] [UNK] is a very [UNK] [UNK] from [UNK] released here in the [UNK] on the new world video [UNK] br its a [UNK] based horror film about a group of young friends who are [UNK] for the [UNK] only to be killed off one by one by an [UNK] [UNK] who [UNK] a [UNK] black [UNK] this film has some creepy scenes and some gore here and there but i have to say that the acting was so lame even by low budget [UNK] but the film was [UNK] [UNK] and i liked it and i never [UNK] [UNK] or turned it off which is always a good [UNK] i nearly gave this movie [UNK] but seeing as it [UNK] up [UNK] along the way had some good moments of gore and suspense had some good [UNK] and the fact that the [UNK] in the main female role was a [UNK] too ill give it [UNK]
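
If the lossy round trip matters for your application, the encoder can be configured to keep more of the original text. The following is only a sketch and is not used in the rest of this tutorial: it assumes the same train_dataset as above and simply uses a larger vocabulary and standardize=None.

# A less lossy configuration (sketch only): keep casing/punctuation and use a
# larger vocabulary so fewer words map to [UNK].
verbose_encoder = tf.keras.layers.experimental.preprocessing.TextVectorization(
    max_tokens=10000,
    standardize=None)
verbose_encoder.adapt(train_dataset.map(lambda text, label: text))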

Create the model

A drawing of the information flow in the model

Above is a diagram of the model.

  1. This model can be built as a tf.keras.Sequential.

  2. The first layer is the encoder, which converts the text to a sequence of token indices.

  3. After the encoder is an embedding layer. An embedding layer stores one vector per word. When called, it converts the sequences of word indices to sequences of vectors. These vectors are trainable. After training (on enough data), words with similar meanings often have similar vectors.

    This index-lookup is much more efficient than the equivalent operation of passing a one-hot encoded vector through a tf.keras.layers.Dense layer (a short sketch of this equivalence follows the list).

  4. A recurrent neural network (RNN) processes sequence input by iterating through the elements. RNNs pass the outputs from one timestep to their input on the next timestep.

    The tf.keras.layers.Bidirectional wrapper can also be used with an RNN layer. This propagates the input forward and backwards through the RNN layer and then concatenates the final output.

    • The main advantage of a bidirectional RNN is that the signal from the beginning of the input doesn't need to be processed all the way through every timestep to affect the output.

    • The main disadvantage of a bidirectional RNN is that you can't efficiently stream predictions as words are being added to the end.

  5. After the RNN has converted the sequence to a single vector, the two tf.keras.layers.Dense layers do some final processing and convert this vector representation to a single logit as the classification output.
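
As a side note on step 3, here is a minimal standalone sketch (an illustration only, not part of the tutorial's model) showing why an embedding lookup gives the same result as multiplying a one-hot vector by a weight matrix, just without ever materializing the one-hot vectors:

embedding = tf.keras.layers.Embedding(input_dim=5, output_dim=3)
token_ids = tf.constant([1, 4, 2])                  # three token indices
lookup = embedding(token_ids)                       # shape (3, 3): one vector per token

one_hot = tf.one_hot(token_ids, depth=5)            # shape (3, 5)
matmul = tf.matmul(one_hot, embedding.embeddings)   # same values as `lookup`, but slower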

The code to implement this is below:

model = tf.keras.Sequential([
    encoder,
    tf.keras.layers.Embedding(
        input_dim=len(encoder.get_vocabulary()),
        output_dim=64,
        # Use masking to handle the variable sequence lengths
        mask_zero=True),
    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64)),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(1)
])

Note that a Keras Sequential model is used here since all the layers in the model only have a single input and produce a single output. If you want to use a stateful RNN layer, you might want to build your model with the Keras functional API or model subclassing so that you can retrieve and reuse the RNN layer states. Check the Keras RNN guide for more details.
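
For reference, here is a rough sketch (an illustration, not the model trained in this tutorial) of the same architecture built with the functional API, where return_state=True exposes the final LSTM states so they could be retrieved and reused:

inputs = tf.keras.Input(shape=(1,), dtype=tf.string)
x = encoder(inputs)
x = tf.keras.layers.Embedding(len(encoder.get_vocabulary()), 64, mask_zero=True)(x)
# return_state=True returns the final hidden/cell states of both directions
# in addition to the output, which is what a stateful setup would reuse.
rnn_out, f_h, f_c, b_h, b_c = tf.keras.layers.Bidirectional(
    tf.keras.layers.LSTM(64, return_state=True))(x)
x = tf.keras.layers.Dense(64, activation='relu')(rnn_out)
outputs = tf.keras.layers.Dense(1)(x)
functional_model = tf.keras.Model(inputs, outputs)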

The embedding layer uses masking to handle the varying sequence lengths. All the layers after the Embedding layer support masking:

print([layer.supports_masking for layer in model.layers])
[False, True, True, True, True]

To confirm that this works as expected, evaluate a sentence twice. First, alone so there's no padding to mask:

# predict on a sample text without padding.

sample_text = ('The movie was cool. The animation and the graphics '
               'were out of this world. I would recommend this movie.')
predictions = model.predict(np.array([sample_text]))
print(predictions[0])
[0.0060511]

Now, evaluate it again in a batch with a longer sentence. The result should be identical:

# predict on a sample text with padding

padding = "the " * 2000
predictions = model.predict(np.array([sample_text, padding]))
print(predictions[0])
[0.00605109]

Compile the Keras model to configure the training process:

model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              optimizer=tf.keras.optimizers.Adam(1e-4),
              metrics=['accuracy'])

Train the model

history = model.fit(train_dataset, epochs=10,
                    validation_data=test_dataset,
                    validation_steps=30)
Epoch 1/10
391/391 [==============================] - 41s 87ms/step - loss: 0.6417 - accuracy: 0.5760 - val_loss: 0.4992 - val_accuracy: 0.7354
Epoch 2/10
391/391 [==============================] - 33s 82ms/step - loss: 0.4060 - accuracy: 0.8121 - val_loss: 0.3681 - val_accuracy: 0.8344
Epoch 3/10
391/391 [==============================] - 32s 81ms/step - loss: 0.3429 - accuracy: 0.8481 - val_loss: 0.3519 - val_accuracy: 0.8266
Epoch 4/10
391/391 [==============================] - 32s 81ms/step - loss: 0.3232 - accuracy: 0.8597 - val_loss: 0.3432 - val_accuracy: 0.8354
Epoch 5/10
391/391 [==============================] - 32s 80ms/step - loss: 0.3171 - accuracy: 0.8608 - val_loss: 0.3440 - val_accuracy: 0.8630
Epoch 6/10
391/391 [==============================] - 32s 79ms/step - loss: 0.3104 - accuracy: 0.8657 - val_loss: 0.3327 - val_accuracy: 0.8484
Epoch 7/10
391/391 [==============================] - 32s 80ms/step - loss: 0.3051 - accuracy: 0.8683 - val_loss: 0.3519 - val_accuracy: 0.8333
Epoch 8/10
391/391 [==============================] - 32s 80ms/step - loss: 0.3058 - accuracy: 0.8681 - val_loss: 0.3214 - val_accuracy: 0.8510
Epoch 9/10
391/391 [==============================] - 32s 81ms/step - loss: 0.2996 - accuracy: 0.8721 - val_loss: 0.3328 - val_accuracy: 0.8458
Epoch 10/10
391/391 [==============================] - 32s 80ms/step - loss: 0.3004 - accuracy: 0.8698 - val_loss: 0.3207 - val_accuracy: 0.8568
test_loss, test_acc = model.evaluate(test_dataset)

print('Test Loss:', test_loss)
print('Test Accuracy:', test_acc)
391/391 [==============================] - 15s 37ms/step - loss: 0.3183 - accuracy: 0.8599
Test Loss: 0.3183472752571106
Test Accuracy: 0.8598799705505371
plt.figure(figsize=(16, 8))
plt.subplot(1, 2, 1)
plot_graphs(history, 'accuracy')
plt.ylim(None, 1)
plt.subplot(1, 2, 2)
plot_graphs(history, 'loss')
plt.ylim(0, None)
(0.0, 0.65884320884943)

Plots of training and validation accuracy and loss over the training epochs

Run a prediction on a new sentence:

If the prediction is >= 0.0 it is positive, otherwise it is negative.

sample_text = ('The movie was cool. The animation and the graphics '
               'were out of this world. I would recommend this movie.')
predictions = model.predict(np.array([sample_text]))
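
Since the model outputs a single logit, you can optionally map it to a probability with a sigmoid. This is a small illustrative check, not part of the original cell:

# A logit >= 0.0 corresponds to a sigmoid probability >= 0.5, i.e. a positive prediction.
print(predictions[0], tf.sigmoid(predictions[0]).numpy())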

Stack two or more LSTM layers

Keras recurrent layers have two available modes that are controlled by the return_sequences constructor argument:

  • If False, it returns only the last output for each input sequence (a 2D tensor of shape (batch_size, output_features)). This is the default, used in the previous model.

  • If True, the full sequences of successive outputs for each timestep are returned (a 3D tensor of shape (batch_size, timesteps, output_features)), as shown in the shape sketch after this list.
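
To make the shape difference concrete, here is a minimal standalone sketch with random data (an illustration only):

x = tf.random.normal([64, 10, 8])   # (batch_size, timesteps, features)

last_only = tf.keras.layers.LSTM(4)(x)                         # shape (64, 4)
full_seq = tf.keras.layers.LSTM(4, return_sequences=True)(x)   # shape (64, 10, 4)
print(last_only.shape, full_seq.shape)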

Here is what the flow of information looks like with return_sequences=True:

A diagram of the information flow through two stacked bidirectional RNN layers

The interesting thing about using an RNN with return_sequences=True is that the output still has 3 axes, like the input, so it can be passed to another RNN layer, like this:

model = tf.keras.Sequential([
    encoder,
    tf.keras.layers.Embedding(len(encoder.get_vocabulary()), 64, mask_zero=True),
    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64,  return_sequences=True)),
    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(32)),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(1)
])
model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              optimizer=tf.keras.optimizers.Adam(1e-4),
              metrics=['accuracy'])
history = model.fit(train_dataset, epochs=10,
                    validation_data=test_dataset,
                    validation_steps=30)
Epoch 1/10
391/391 [==============================] - 73s 148ms/step - loss: 0.6193 - accuracy: 0.5947 - val_loss: 0.4267 - val_accuracy: 0.8266
Epoch 2/10
391/391 [==============================] - 54s 138ms/step - loss: 0.3748 - accuracy: 0.8359 - val_loss: 0.3740 - val_accuracy: 0.8526
Epoch 3/10
391/391 [==============================] - 54s 137ms/step - loss: 0.3385 - accuracy: 0.8553 - val_loss: 0.3397 - val_accuracy: 0.8385
Epoch 4/10
391/391 [==============================] - 53s 134ms/step - loss: 0.3167 - accuracy: 0.8651 - val_loss: 0.3255 - val_accuracy: 0.8589
Epoch 5/10
391/391 [==============================] - 53s 133ms/step - loss: 0.3118 - accuracy: 0.8660 - val_loss: 0.3213 - val_accuracy: 0.8641
Epoch 6/10
391/391 [==============================] - 53s 134ms/step - loss: 0.3070 - accuracy: 0.8679 - val_loss: 0.3294 - val_accuracy: 0.8661
Epoch 7/10
391/391 [==============================] - 54s 137ms/step - loss: 0.3044 - accuracy: 0.8725 - val_loss: 0.3184 - val_accuracy: 0.8604
Epoch 8/10
391/391 [==============================] - 54s 138ms/step - loss: 0.3024 - accuracy: 0.8716 - val_loss: 0.3419 - val_accuracy: 0.8365
Epoch 9/10
391/391 [==============================] - 55s 138ms/step - loss: 0.2999 - accuracy: 0.8720 - val_loss: 0.3199 - val_accuracy: 0.8620
Epoch 10/10
391/391 [==============================] - 54s 136ms/step - loss: 0.2938 - accuracy: 0.8750 - val_loss: 0.3202 - val_accuracy: 0.8469
test_loss, test_acc = model.evaluate(test_dataset)

print('Test Loss:', test_loss)
print('Test Accuracy:', test_acc)
391/391 [==============================] - 25s 63ms/step - loss: 0.3171 - accuracy: 0.8503
Test Loss: 0.31706351041793823
Test Accuracy: 0.8503199815750122
# predict on a sample text without padding.

sample_text = ('The movie was not good. The animation and the graphics '
               'were terrible. I would not recommend this movie.')
predictions = model.predict(np.array([sample_text]))
print(predictions)
[[-2.0174394]]
plt.figure(figsize=(16, 6))
plt.subplot(1, 2, 1)
plot_graphs(history, 'accuracy')
plt.subplot(1, 2, 2)
plot_graphs(history, 'loss')

Plots of training and validation accuracy and loss over the training epochs

Check out other recurrent layers, such as GRU layers.
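
For example, swapping the LSTM for a GRU in the first model is a one-line change. This is only a sketch, assuming the same encoder and training setup as above:

gru_model = tf.keras.Sequential([
    encoder,
    tf.keras.layers.Embedding(len(encoder.get_vocabulary()), 64, mask_zero=True),
    tf.keras.layers.Bidirectional(tf.keras.layers.GRU(64)),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(1)
])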

If you're interested in building custom RNNs, see the Keras RNN guide.