Join us at TensorFlow World, Oct 28-31. Use code TF20 for 20% off select passes. Register now

Text classification with an RNN

View on TensorFlow.org Run in Google Colab View source on GitHub Download notebook

This text classification tutorial trains a recurrent neural network on the IMDB large movie review dataset for sentiment analysis.

from __future__ import absolute_import, division, print_function, unicode_literals

try:
  # %tensorflow_version only exists in Colab.
  %tensorflow_version 2.x
except Exception:
  pass
import tensorflow_datasets as tfds
import tensorflow as tf

Import matplotlib and create a helper function to plot graphs:

import matplotlib.pyplot as plt


def plot_graphs(history, string):
  """Plot a training metric and its validation counterpart against epochs.

  Args:
    history: Object whose `history` attribute maps metric names to
      per-epoch value sequences (e.g. the return value of `model.fit`).
    string: Name of the metric to plot; the series stored under
      `'val_' + string` is drawn on the same axes.
  """
  val_key = 'val_'+string
  for series_key in (string, val_key):
    plt.plot(history.history[series_key])
  plt.xlabel("Epochs")
  plt.ylabel(string)
  plt.legend([string, val_key])
  plt.show()

Setup input pipeline

The IMDB large movie review dataset is a binary classification dataset—all the reviews have either a positive or negative sentiment.

Download the dataset using TFDS. The dataset comes with an inbuilt subword tokenizer.

# Load the IMDB reviews in their 8k-subword encoding. `info` is the dataset
# metadata object; the train/test splits are looked up by name.
dataset, info = tfds.load(
    'imdb_reviews/subwords8k',
    as_supervised=True,
    with_info=True)
train_dataset = dataset['train']
test_dataset = dataset['test']
Downloading and preparing dataset imdb_reviews (80.23 MiB) to /home/kbuilder/tensorflow_datasets/imdb_reviews/subwords8k/0.1.0...

HBox(children=(IntProgress(value=1, bar_style='info', description='Dl Completed...', max=1, style=ProgressStyl…
HBox(children=(IntProgress(value=1, bar_style='info', description='Dl Size...', max=1, style=ProgressStyle(des…




HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


HBox(children=(IntProgress(value=0, description='Shuffling...', max=10, style=ProgressStyle(description_width=…
WARNING: Logging before flag parsing goes to stderr.
W0813 07:00:39.302064 139794992858880 deprecation.py:323] From /home/kbuilder/.local/lib/python3.5/site-packages/tensorflow_datasets/core/file_format_adapter.py:209: tf_record_iterator (from tensorflow.python.lib.io.tf_record) is deprecated and will be removed in a future version.
Instructions for updating:
Use eager execution and: 
`tf.data.TFRecordDataset(path)`

HBox(children=(IntProgress(value=1, bar_style='info', description='Reading...', max=1, style=ProgressStyle(des…
HBox(children=(IntProgress(value=0, description='Writing...', max=2500, style=ProgressStyle(description_width=…
HBox(children=(IntProgress(value=1, bar_style='info', description='Reading...', max=1, style=ProgressStyle(des…
HBox(children=(IntProgress(value=0, description='Writing...', max=2500, style=ProgressStyle(description_width=…
HBox(children=(IntProgress(value=1, bar_style='info', description='Reading...', max=1, style=ProgressStyle(des…
HBox(children=(IntProgress(value=0, description='Writing...', max=2500, style=ProgressStyle(description_width=…
HBox(children=(IntProgress(value=1, bar_style='info', description='Reading...', max=1, style=ProgressStyle(des…
HBox(children=(IntProgress(value=0, description='Writing...', max=2500, style=ProgressStyle(description_width=…
HBox(children=(IntProgress(value=1, bar_style='info', description='Reading...', max=1, style=ProgressStyle(des…
HBox(children=(IntProgress(value=0, description='Writing...', max=2500, style=ProgressStyle(description_width=…
HBox(children=(IntProgress(value=1, bar_style='info', description='Reading...', max=1, style=ProgressStyle(des…
HBox(children=(IntProgress(value=0, description='Writing...', max=2500, style=ProgressStyle(description_width=…
HBox(children=(IntProgress(value=1, bar_style='info', description='Reading...', max=1, style=ProgressStyle(des…
HBox(children=(IntProgress(value=0, description='Writing...', max=2500, style=ProgressStyle(description_width=…
HBox(children=(IntProgress(value=1, bar_style='info', description='Reading...', max=1, style=ProgressStyle(des…
HBox(children=(IntProgress(value=0, description='Writing...', max=2500, style=ProgressStyle(description_width=…
HBox(children=(IntProgress(value=1, bar_style='info', description='Reading...', max=1, style=ProgressStyle(des…
HBox(children=(IntProgress(value=0, description='Writing...', max=2500, style=ProgressStyle(description_width=…
HBox(children=(IntProgress(value=1, bar_style='info', description='Reading...', max=1, style=ProgressStyle(des…
HBox(children=(IntProgress(value=0, description='Writing...', max=2500, style=ProgressStyle(description_width=…


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


HBox(children=(IntProgress(value=0, description='Shuffling...', max=10, style=ProgressStyle(description_width=…
HBox(children=(IntProgress(value=1, bar_style='info', description='Reading...', max=1, style=ProgressStyle(des…
HBox(children=(IntProgress(value=0, description='Writing...', max=2500, style=ProgressStyle(description_width=…
HBox(children=(IntProgress(value=1, bar_style='info', description='Reading...', max=1, style=ProgressStyle(des…
HBox(children=(IntProgress(value=0, description='Writing...', max=2500, style=ProgressStyle(description_width=…
HBox(children=(IntProgress(value=1, bar_style='info', description='Reading...', max=1, style=ProgressStyle(des…
HBox(children=(IntProgress(value=0, description='Writing...', max=2500, style=ProgressStyle(description_width=…
HBox(children=(IntProgress(value=1, bar_style='info', description='Reading...', max=1, style=ProgressStyle(des…
HBox(children=(IntProgress(value=0, description='Writing...', max=2500, style=ProgressStyle(description_width=…
HBox(children=(IntProgress(value=1, bar_style='info', description='Reading...', max=1, style=ProgressStyle(des…
HBox(children=(IntProgress(value=0, description='Writing...', max=2500, style=ProgressStyle(description_width=…
HBox(children=(IntProgress(value=1, bar_style='info', description='Reading...', max=1, style=ProgressStyle(des…
HBox(children=(IntProgress(value=0, description='Writing...', max=2500, style=ProgressStyle(description_width=…
HBox(children=(IntProgress(value=1, bar_style='info', description='Reading...', max=1, style=ProgressStyle(des…
HBox(children=(IntProgress(value=0, description='Writing...', max=2500, style=ProgressStyle(description_width=…
HBox(children=(IntProgress(value=1, bar_style='info', description='Reading...', max=1, style=ProgressStyle(des…
HBox(children=(IntProgress(value=0, description='Writing...', max=2500, style=ProgressStyle(description_width=…
HBox(children=(IntProgress(value=1, bar_style='info', description='Reading...', max=1, style=ProgressStyle(des…
HBox(children=(IntProgress(value=0, description='Writing...', max=2500, style=ProgressStyle(description_width=…
HBox(children=(IntProgress(value=1, bar_style='info', description='Reading...', max=1, style=ProgressStyle(des…
HBox(children=(IntProgress(value=0, description='Writing...', max=2500, style=ProgressStyle(description_width=…


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


HBox(children=(IntProgress(value=0, description='Shuffling...', max=20, style=ProgressStyle(description_width=…
HBox(children=(IntProgress(value=1, bar_style='info', description='Reading...', max=1, style=ProgressStyle(des…
HBox(children=(IntProgress(value=0, description='Writing...', max=2500, style=ProgressStyle(description_width=…
HBox(children=(IntProgress(value=1, bar_style='info', description='Reading...', max=1, style=ProgressStyle(des…
HBox(children=(IntProgress(value=0, description='Writing...', max=2500, style=ProgressStyle(description_width=…
HBox(children=(IntProgress(value=1, bar_style='info', description='Reading...', max=1, style=ProgressStyle(des…
HBox(children=(IntProgress(value=0, description='Writing...', max=2500, style=ProgressStyle(description_width=…
HBox(children=(IntProgress(value=1, bar_style='info', description='Reading...', max=1, style=ProgressStyle(des…
HBox(children=(IntProgress(value=0, description='Writing...', max=2500, style=ProgressStyle(description_width=…
HBox(children=(IntProgress(value=1, bar_style='info', description='Reading...', max=1, style=ProgressStyle(des…
HBox(children=(IntProgress(value=0, description='Writing...', max=2500, style=ProgressStyle(description_width=…
HBox(children=(IntProgress(value=1, bar_style='info', description='Reading...', max=1, style=ProgressStyle(des…
HBox(children=(IntProgress(value=0, description='Writing...', max=2500, style=ProgressStyle(description_width=…
HBox(children=(IntProgress(value=1, bar_style='info', description='Reading...', max=1, style=ProgressStyle(des…
HBox(children=(IntProgress(value=0, description='Writing...', max=2500, style=ProgressStyle(description_width=…
HBox(children=(IntProgress(value=1, bar_style='info', description='Reading...', max=1, style=ProgressStyle(des…
HBox(children=(IntProgress(value=0, description='Writing...', max=2500, style=ProgressStyle(description_width=…
HBox(children=(IntProgress(value=1, bar_style='info', description='Reading...', max=1, style=ProgressStyle(des…
HBox(children=(IntProgress(value=0, description='Writing...', max=2500, style=ProgressStyle(description_width=…
HBox(children=(IntProgress(value=1, bar_style='info', description='Reading...', max=1, style=ProgressStyle(des…
HBox(children=(IntProgress(value=0, description='Writing...', max=2500, style=ProgressStyle(description_width=…
HBox(children=(IntProgress(value=1, bar_style='info', description='Reading...', max=1, style=ProgressStyle(des…
HBox(children=(IntProgress(value=0, description='Writing...', max=2500, style=ProgressStyle(description_width=…
HBox(children=(IntProgress(value=1, bar_style='info', description='Reading...', max=1, style=ProgressStyle(des…
HBox(children=(IntProgress(value=0, description='Writing...', max=2500, style=ProgressStyle(description_width=…
HBox(children=(IntProgress(value=1, bar_style='info', description='Reading...', max=1, style=ProgressStyle(des…
HBox(children=(IntProgress(value=0, description='Writing...', max=2500, style=ProgressStyle(description_width=…
HBox(children=(IntProgress(value=1, bar_style='info', description='Reading...', max=1, style=ProgressStyle(des…
HBox(children=(IntProgress(value=0, description='Writing...', max=2500, style=ProgressStyle(description_width=…
HBox(children=(IntProgress(value=1, bar_style='info', description='Reading...', max=1, style=ProgressStyle(des…
HBox(children=(IntProgress(value=0, description='Writing...', max=2500, style=ProgressStyle(description_width=…
HBox(children=(IntProgress(value=1, bar_style='info', description='Reading...', max=1, style=ProgressStyle(des…
HBox(children=(IntProgress(value=0, description='Writing...', max=2500, style=ProgressStyle(description_width=…
HBox(children=(IntProgress(value=1, bar_style='info', description='Reading...', max=1, style=ProgressStyle(des…
HBox(children=(IntProgress(value=0, description='Writing...', max=2500, style=ProgressStyle(description_width=…
HBox(children=(IntProgress(value=1, bar_style='info', description='Reading...', max=1, style=ProgressStyle(des…
HBox(children=(IntProgress(value=0, description='Writing...', max=2500, style=ProgressStyle(description_width=…
HBox(children=(IntProgress(value=1, bar_style='info', description='Reading...', max=1, style=ProgressStyle(des…
HBox(children=(IntProgress(value=0, description='Writing...', max=2500, style=ProgressStyle(description_width=…
HBox(children=(IntProgress(value=1, bar_style='info', description='Reading...', max=1, style=ProgressStyle(des…
HBox(children=(IntProgress(value=0, description='Writing...', max=2500, style=ProgressStyle(description_width=…
Dataset imdb_reviews downloaded and prepared to /home/kbuilder/tensorflow_datasets/imdb_reviews/subwords8k/0.1.0. Subsequent calls will reuse this data.

W0813 07:02:21.999113 139794992858880 dataset_builder.py:439] Warning: Setting shuffle_files=True because split=TRAIN and shuffle_files=None. This behavior will be deprecated on 2019-08-06, at which point shuffle_files=False will be the default for all splits.

As this is a subwords tokenizer, it can be passed any string and the tokenizer will tokenize it.

# The text encoder is stored on the dataset metadata's 'text' feature.
tokenizer = info.features['text'].encoder
print('Vocabulary size: {}'.format(tokenizer.vocab_size))
Vocabulary size: 8185
# Round-trip a sample string through the tokenizer: encode to ids, decode
# back to text, and check that nothing was lost along the way.
sample_string = 'TensorFlow is cool.'

tokenized_string = tokenizer.encode(sample_string)
print('Tokenized string is {}'.format(tokenized_string))

original_string = tokenizer.decode(tokenized_string)
print('The original string: {}'.format(original_string))

assert original_string == sample_string
Tokenized string is [6307, 2327, 4043, 4265, 9, 2724, 7975]
The original string: TensorFlow is cool.

The tokenizer encodes the string by breaking it into subwords if the word is not in its dictionary.

# Show the text piece that each individual token id decodes to.
for token_id in tokenized_string:
  print('{} ----> {}'.format(token_id, tokenizer.decode([token_id])))
6307 ----> Ten
2327 ----> sor
4043 ----> Fl
4265 ----> ow 
9 ----> is 
2724 ----> cool
7975 ----> .
BUFFER_SIZE = 10000
BATCH_SIZE = 64

# Shuffle only the training split, then batch both splits with padding.
# `Dataset.output_shapes` is a deprecated V1-only property that TF 2.x
# datasets do not expose, so the element shapes are read through the
# `tf.compat.v1.data.get_output_shapes` helper instead.
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
train_dataset = train_dataset.padded_batch(
    BATCH_SIZE, tf.compat.v1.data.get_output_shapes(train_dataset))

test_dataset = test_dataset.padded_batch(
    BATCH_SIZE, tf.compat.v1.data.get_output_shapes(test_dataset))

Create the model

Build a tf.keras.Sequential model and start with an embedding layer. An embedding layer stores one vector per word. When called, it converts the sequences of word indices to sequences of vectors. These vectors are trainable. After training (on enough data), words with similar meanings often have similar vectors.

This index-lookup is much more efficient than the equivalent operation of passing a one-hot encoded vector through a tf.keras.layers.Dense layer.

A recurrent neural network (RNN) processes sequence input by iterating through the elements. RNNs pass the outputs from one timestep to their input on the next timestep.

The tf.keras.layers.Bidirectional wrapper can also be used with an RNN layer. This propagates the input forward and backwards through the RNN layer and then concatenates the output. This helps the RNN to learn long range dependencies.

# Embedding -> bidirectional LSTM -> dense ReLU layer -> one sigmoid unit.
model = tf.keras.Sequential()
model.add(tf.keras.layers.Embedding(tokenizer.vocab_size, 64))
model.add(tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64)))
model.add(tf.keras.layers.Dense(64, activation='relu'))
model.add(tf.keras.layers.Dense(1, activation='sigmoid'))

Compile the Keras model to configure the training process:

# Binary cross-entropy pairs with the single sigmoid output; accuracy is
# tracked as the reported metric.
model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['accuracy'])

Train the model

# Train for 10 epochs, using the test split as validation data.
history = model.fit(
    train_dataset,
    validation_data=test_dataset,
    epochs=10)
Epoch 1/10

W0813 07:02:26.823703 139794992858880 deprecation.py:323] From /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow/python/ops/math_grad.py:1250: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where

391/391 [==============================] - 336s 858ms/step - loss: 0.5780 - accuracy: 0.6907 - val_loss: 0.0000e+00 - val_accuracy: 0.0000e+00
Epoch 2/10
391/391 [==============================] - 126s 322ms/step - loss: 0.3769 - accuracy: 0.8449 - val_loss: 0.3847 - val_accuracy: 0.8410
Epoch 3/10
391/391 [==============================] - 93s 237ms/step - loss: 0.2908 - accuracy: 0.8882 - val_loss: 0.3959 - val_accuracy: 0.8292
Epoch 4/10
391/391 [==============================] - 95s 244ms/step - loss: 0.2279 - accuracy: 0.9152 - val_loss: 0.4133 - val_accuracy: 0.8310
Epoch 5/10
391/391 [==============================] - 84s 215ms/step - loss: 0.2112 - accuracy: 0.9246 - val_loss: 0.5196 - val_accuracy: 0.8291
Epoch 6/10
391/391 [==============================] - 80s 205ms/step - loss: 0.1731 - accuracy: 0.9381 - val_loss: 0.8120 - val_accuracy: 0.7032
Epoch 7/10
391/391 [==============================] - 81s 206ms/step - loss: 0.3342 - accuracy: 0.8620 - val_loss: 0.5436 - val_accuracy: 0.8412
Epoch 8/10
391/391 [==============================] - 75s 191ms/step - loss: 0.2024 - accuracy: 0.9196 - val_loss: 0.5575 - val_accuracy: 0.7104
Epoch 9/10
391/391 [==============================] - 76s 194ms/step - loss: 0.2000 - accuracy: 0.9260 - val_loss: 0.5066 - val_accuracy: 0.8311
Epoch 10/10
391/391 [==============================] - 72s 183ms/step - loss: 0.1234 - accuracy: 0.9587 - val_loss: 0.6259 - val_accuracy: 0.8187
# Evaluate on the test split and report loss and accuracy, one per line.
test_loss, test_acc = model.evaluate(test_dataset)

print('Test Loss: {}'.format(test_loss),
      'Test Accuracy: {}'.format(test_acc),
      sep='\n')
    391/Unknown - 19s 50ms/step - loss: 0.6259 - accuracy: 0.8187Test Loss: 0.6259316261238455
Test Accuracy: 0.8186799883842468

The above model does not mask the padding applied to the sequences. This can lead to skew if the model is trained on padded sequences and tested on un-padded sequences. Ideally the model would learn to ignore the padding, but as you can see below it does have a small effect on the output.

If the prediction is >= 0.5, it is positive; otherwise, it is negative.

def pad_to_size(vec, size):
  """Return a new list holding `vec` right-padded with zeros to `size`.

  The input sequence is left untouched, so callers can keep using the
  un-padded values afterwards.

  Args:
    vec: Sequence of values (e.g. token ids) to pad.
    size: Desired minimum length of the result.

  Returns:
    A new list of length `max(len(vec), size)`. When `vec` is already at
    least `size` long, an unmodified copy is returned (no truncation).
  """
  # A negative repeat count yields an empty list, so over-long inputs pass
  # through unchanged.
  zeros = [0] * (size - len(vec))
  return list(vec) + zeros
def sample_predict(sentence, pad):
  """Run the current `model` on a single sentence.

  Args:
    sentence: Text to classify. It is encoded with the module-level
      `tokenizer`.
    pad: When truthy, the encoded sentence is zero-padded to 64 tokens
      before prediction.

  Returns:
    The value returned by `model.predict` for a batch containing only
    this sentence.
  """
  # Encode the argument itself rather than the `sample_pred_text` global,
  # so the function honours whatever text the caller passes in.
  tokenized_sentence = tokenizer.encode(sentence)

  if pad:
    tokenized_sentence = pad_to_size(tokenized_sentence, 64)

  # Add a leading batch dimension of size 1.
  predictions = model.predict(tf.expand_dims(tokenized_sentence, 0))

  return predictions
# Predict on a sample text without padding.
sample_pred_text = ('The movie was cool. The animation and the graphics '
                    'were out of this world. I would recommend this movie.')
predictions = sample_predict(sample_pred_text, pad=False)
print(predictions)
[[0.73671955]]
# Predict on the same sample text, this time zero-padded.
sample_pred_text = ('The movie was cool. The animation and the graphics '
                    'were out of this world. I would recommend this movie.')
predictions = sample_predict(sample_pred_text, pad=True)
print(predictions)
[[0.55615026]]
# Plot the 'accuracy' and 'val_accuracy' series recorded in `history`.
plot_graphs(history, 'accuracy')

png

# Plot the 'loss' and 'val_loss' series recorded in `history`.
plot_graphs(history, 'loss')

png

Stack two or more LSTM layers

Keras recurrent layers have two available modes that are controlled by the return_sequences constructor argument:

  • With return_sequences=True, return the full sequence of successive outputs for each timestep (a 3D tensor of shape (batch_size, timesteps, output_features)).
  • With return_sequences=False (the default), return only the last output for each input sequence (a 2D tensor of shape (batch_size, output_features)).
# Two stacked bidirectional LSTMs. The first is built with
# return_sequences=True so the second receives an output per timestep.
model = tf.keras.Sequential()
model.add(tf.keras.layers.Embedding(tokenizer.vocab_size, 64))
model.add(tf.keras.layers.Bidirectional(
    tf.keras.layers.LSTM(64, return_sequences=True)))
model.add(tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(32)))
model.add(tf.keras.layers.Dense(64, activation='relu'))
model.add(tf.keras.layers.Dense(1, activation='sigmoid'))

model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['accuracy'])

# Train for 10 epochs, using the test split as validation data.
history = model.fit(
    train_dataset,
    validation_data=test_dataset,
    epochs=10)
Epoch 1/10
391/391 [==============================] - 606s 2s/step - loss: 0.5953 - accuracy: 0.6839 - val_loss: 0.0000e+00 - val_accuracy: 0.0000e+00
Epoch 2/10
391/391 [==============================] - 226s 578ms/step - loss: 0.5264 - accuracy: 0.7454 - val_loss: 0.4971 - val_accuracy: 0.7708
Epoch 3/10
391/391 [==============================] - 174s 445ms/step - loss: 0.4931 - accuracy: 0.7726 - val_loss: 0.4993 - val_accuracy: 0.7750
Epoch 4/10
391/391 [==============================] - 162s 415ms/step - loss: 0.3759 - accuracy: 0.8433 - val_loss: 0.3994 - val_accuracy: 0.8286
Epoch 5/10
391/391 [==============================] - 154s 394ms/step - loss: 0.2852 - accuracy: 0.8898 - val_loss: 0.3661 - val_accuracy: 0.8440
Epoch 6/10
391/391 [==============================] - 137s 349ms/step - loss: 0.2313 - accuracy: 0.9134 - val_loss: 0.3700 - val_accuracy: 0.8483
Epoch 7/10
391/391 [==============================] - 140s 357ms/step - loss: 0.1855 - accuracy: 0.9359 - val_loss: 0.3362 - val_accuracy: 0.8681
Epoch 8/10
391/391 [==============================] - 136s 348ms/step - loss: 0.1466 - accuracy: 0.9526 - val_loss: 0.3894 - val_accuracy: 0.8588
Epoch 9/10
391/391 [==============================] - 133s 340ms/step - loss: 0.1068 - accuracy: 0.9677 - val_loss: 0.4080 - val_accuracy: 0.8648
Epoch 10/10
391/391 [==============================] - 130s 332ms/step - loss: 0.0845 - accuracy: 0.9750 - val_loss: 0.4726 - val_accuracy: 0.8665
# Evaluate the stacked model on the test split and report both numbers.
test_loss, test_acc = model.evaluate(test_dataset)

print('Test Loss: {}'.format(test_loss),
      'Test Accuracy: {}'.format(test_acc),
      sep='\n')
    391/Unknown - 35s 89ms/step - loss: 0.4726 - accuracy: 0.8665Test Loss: 0.4725629985904145
Test Accuracy: 0.8664799928665161
# Predict on a negative sample text without padding.
sample_pred_text = ('The movie was not good. The animation and the graphics '
                    'were terrible. I would not recommend this movie.')
predictions = sample_predict(sample_pred_text, pad=False)
print(predictions)
[[0.00248089]]
# Predict on the same negative sample text, this time zero-padded.
sample_pred_text = ('The movie was not good. The animation and the graphics '
                    'were terrible. I would not recommend this movie.')
predictions = sample_predict(sample_pred_text, pad=True)
print(predictions)
[[0.00362942]]
# Plot the 'accuracy' and 'val_accuracy' series recorded in `history`.
plot_graphs(history, 'accuracy')

png

# Plot the 'loss' and 'val_loss' series recorded in `history`.
plot_graphs(history, 'loss')

png

Check out other existing recurrent layers such as GRU layers.