
Text classification with an RNN


This text classification tutorial trains a recurrent neural network on the IMDB large movie review dataset for sentiment analysis.

Setup

import numpy as np

import tensorflow_datasets as tfds
import tensorflow as tf

tfds.disable_progress_bar()
2022-12-14 13:38:21.991446: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2022-12-14 13:38:21.991557: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2022-12-14 13:38:21.991567: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.

Import matplotlib and create a helper function to plot graphs:

import matplotlib.pyplot as plt


def plot_graphs(history, metric):
  plt.plot(history.history[metric])
  plt.plot(history.history['val_'+metric])
  plt.xlabel("Epochs")
  plt.ylabel(metric)
  plt.legend([metric, 'val_'+metric])

Setup input pipeline

The IMDB large movie review dataset is a binary classification dataset—all the reviews have either a positive or negative sentiment.

Download the dataset using TFDS. See the loading text tutorial for details on how to load this sort of data manually.

dataset, info = tfds.load('imdb_reviews', with_info=True,
                          as_supervised=True)
train_dataset, test_dataset = dataset['train'], dataset['test']

train_dataset.element_spec
(TensorSpec(shape=(), dtype=tf.string, name=None),
 TensorSpec(shape=(), dtype=tf.int64, name=None))

Initially this returns a dataset of (text, label) pairs:

for example, label in train_dataset.take(1):
  print('text: ', example.numpy())
  print('label: ', label.numpy())
text:  b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it."
label:  0

Next, shuffle the data for training and create batches of these (text, label) pairs:

BUFFER_SIZE = 10000
BATCH_SIZE = 64
train_dataset = train_dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
test_dataset = test_dataset.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
for example, label in train_dataset.take(1):
  print('texts: ', example.numpy()[:3])
  print()
  print('labels: ', label.numpy()[:3])
texts:  [b"As a huge fan of horror films, especially J-horror and also gore i thought Nekeddo bur\xc3\xa2ddo sounded pretty good. I researched the plot, read reviews, and even looked at some photos to make sure it seemed like a good gory and scary movie to watch before downloading it. So excited it had finished and ready to be scared and recoiling in horror at the amazing gore i was expecting i was terribly disappointed. The plot was ridiculous and didn't even make sense and left too much unexplained, the gore was hilarious rather then horrifying, and what was with the cartoon style sound effects ? The acting was probably the only thing mildly scary about it. I did not understand the cactus idea and the way the mothers husband disappeared in the middle of the sea after following a flashing light, they left both pretty unexplained, or perhaps i missed it as my mind couldn't understand what i was actually seeing. I appreciate the way it was supposed to be; shocking and a few scenes (the strange cannibalism and own mother kissing?)certainly were, i just think they went a little bit far and not even in a horrifying way, they made it to unconvincing which made it more believable to be a comedy rather than a horror in my opinion. However it is a very entertaining film and got a lot of laughs out of me and a couple of friends, but sadly we were expecting horror not comedy so its worth a watch for the entertainment value, but don't be expecting a dark, deeply scary and horrifying film; you'll just be disappointed. If this was a horror comedy/spoof i'd probably rate it about a nine, the climax being the weird scene when the husband climbed inside his wife's stomach and closed up her wounds, but as a horror sadly i gave it a one."
 b'"What happens when you give a homeless man \\(100,000?" As if by asking that question they are somehow morally absolved of what is eventually going to happen. The creators of "Reversal of Fortune" try to get their voyeuristic giggles while disguising their antics as some kind of responsible social experiment.<br /><br />They take Ted, a homeless man in Pasadena, and give him \\)100,000 to see if he will turn his life around. Then, with only the most cursory guidance and counseling, they let him go on his merry way.<br /><br />What are they trying to say? "Money can\'t buy you happiness?" "The homeless are homeless because they deserve to be?" Or how about, "Lift a man up - it\'s more fun to watch him fall from a greater altitude." They took a man with nothing to lose, gave him something to lose, and then watched him dump it all down the drain. That\'s supposed to be entertainment? They dress this sow up with some gloomy music and dramatic camera shots, but in the end it has all the moral high ground of car crash videos - only this time they engineered the car crashes and asked, "What happens when you take down a stop sign?"'
 b"'Mojo' is a story of fifties London, a world of budding rock stars, violence and forced homosexuality. 'Mojo' uses a technique for shooting the 1950s often seen in films that stresses the physical differences to our own time but also represents dialogue in a highly exaggerated fashion (owing much to the way that speech was represented in films made in that period); I have no idea if people actually spoke like this outside of the movies, but no films made today and set in contemporary times use such stylised language. It's as if the stilted discourse of 1950s screenwriters serves a common shorthand for a past that seems, in consequence, a very distant country indeed; and therefore stresses the particular, rather than the universal, in the story. 'Mojo' features a strong performance from Ian Hart and annoying ones from Aiden Gillan and Ewan Bremner, the latter still struggling to build a post-'Trainspotting' career; but feels like a period piece, a modern film incomprehensibly structured in an outdated idiom. Rather dull, actually."]

labels:  [0 0 0]

Create the text encoder

The raw text loaded by tfds needs to be processed before it can be used in a model. The simplest way to process text for training is using the TextVectorization layer. This layer has many capabilities, but this tutorial sticks to the default behavior.

Create the layer, and pass the dataset's text to the layer's .adapt method:

VOCAB_SIZE = 1000
encoder = tf.keras.layers.TextVectorization(
    max_tokens=VOCAB_SIZE)
encoder.adapt(train_dataset.map(lambda text, label: text))
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089

The .adapt method sets the layer's vocabulary. Here are the first 20 tokens. After the padding and unknown tokens they're sorted by frequency:

vocab = np.array(encoder.get_vocabulary())
vocab[:20]
array(['', '[UNK]', 'the', 'and', 'a', 'of', 'to', 'is', 'in', 'it', 'i',
       'this', 'that', 'br', 'was', 'as', 'for', 'with', 'movie', 'but'],
      dtype='<U14')

Once the vocabulary is set, the layer can encode text into indices. The tensors of indices are 0-padded to the longest sequence in the batch (unless you set a fixed output_sequence_length):

encoded_example = encoder(example)[:3].numpy()
encoded_example
array([[ 15,   4, 629, ...,   0,   0,   0],
       [ 49, 557,  51, ...,   0,   0,   0],
       [  1,   7,   4, ...,   0,   0,   0]])

With the default settings, the process is not completely reversible. There are two main reasons for that:

  1. The default value for tf.keras.layers.TextVectorization's standardize argument is "lower_and_strip_punctuation", which discards case and punctuation.
  2. The limited vocabulary size and lack of a character-based fallback results in some unknown tokens.

for n in range(3):
  print("Original: ", example[n].numpy())
  print("Round-trip: ", " ".join(vocab[encoded_example[n]]))
  print()
Original:  b"As a huge fan of horror films, especially J-horror and also gore i thought Nekeddo bur\xc3\xa2ddo sounded pretty good. I researched the plot, read reviews, and even looked at some photos to make sure it seemed like a good gory and scary movie to watch before downloading it. So excited it had finished and ready to be scared and recoiling in horror at the amazing gore i was expecting i was terribly disappointed. The plot was ridiculous and didn't even make sense and left too much unexplained, the gore was hilarious rather then horrifying, and what was with the cartoon style sound effects ? The acting was probably the only thing mildly scary about it. I did not understand the cactus idea and the way the mothers husband disappeared in the middle of the sea after following a flashing light, they left both pretty unexplained, or perhaps i missed it as my mind couldn't understand what i was actually seeing. I appreciate the way it was supposed to be; shocking and a few scenes (the strange cannibalism and own mother kissing?)certainly were, i just think they went a little bit far and not even in a horrifying way, they made it to unconvincing which made it more believable to be a comedy rather than a horror in my opinion. However it is a very entertaining film and got a lot of laughs out of me and a couple of friends, but sadly we were expecting horror not comedy so its worth a watch for the entertainment value, but don't be expecting a dark, deeply scary and horrifying film; you'll just be disappointed. If this was a horror comedy/spoof i'd probably rate it about a nine, the climax being the weird scene when the husband climbed inside his wife's stomach and closed up her wounds, but as a horror sadly i gave it a one."
Round-trip:  as a huge fan of horror films especially [UNK] and also gore i thought [UNK] [UNK] [UNK] pretty good i [UNK] the plot read reviews and even looked at some [UNK] to make sure it seemed like a good [UNK] and scary movie to watch before [UNK] it so [UNK] it had [UNK] and [UNK] to be [UNK] and [UNK] in horror at the amazing gore i was expecting i was [UNK] disappointed the plot was ridiculous and didnt even make sense and left too much [UNK] the gore was hilarious rather then [UNK] and what was with the [UNK] style sound effects the acting was probably the only thing [UNK] scary about it i did not understand the [UNK] idea and the way the [UNK] husband [UNK] in the middle of the [UNK] after [UNK] a [UNK] light they left both pretty [UNK] or perhaps i [UNK] it as my mind couldnt understand what i was actually seeing i [UNK] the way it was supposed to be [UNK] and a few scenes the strange [UNK] and own mother [UNK] were i just think they went a little bit far and not even in a [UNK] way they made it to [UNK] which made it more believable to be a comedy rather than a horror in my opinion however it is a very entertaining film and got a lot of laughs out of me and a couple of friends but [UNK] we were expecting horror not comedy so its worth a watch for the entertainment [UNK] but dont be expecting a dark [UNK] scary and [UNK] film youll just be disappointed if this was a horror [UNK] id probably [UNK] it about a [UNK] the [UNK] being the weird scene when the husband [UNK] inside his [UNK] [UNK] and [UNK] up her [UNK] but as a horror [UNK] i gave it a one                                                                                                                                                                                                                                                                                     

Original:  b'"What happens when you give a homeless man \\(100,000?" As if by asking that question they are somehow morally absolved of what is eventually going to happen. The creators of "Reversal of Fortune" try to get their voyeuristic giggles while disguising their antics as some kind of responsible social experiment.<br /><br />They take Ted, a homeless man in Pasadena, and give him \\)100,000 to see if he will turn his life around. Then, with only the most cursory guidance and counseling, they let him go on his merry way.<br /><br />What are they trying to say? "Money can\'t buy you happiness?" "The homeless are homeless because they deserve to be?" Or how about, "Lift a man up - it\'s more fun to watch him fall from a greater altitude." They took a man with nothing to lose, gave him something to lose, and then watched him dump it all down the drain. That\'s supposed to be entertainment? They dress this sow up with some gloomy music and dramatic camera shots, but in the end it has all the moral high ground of car crash videos - only this time they engineered the car crashes and asked, "What happens when you take down a stop sign?"'
Round-trip:  what happens when you give a [UNK] man [UNK] as if by [UNK] that question they are somehow [UNK] [UNK] of what is eventually going to happen the [UNK] of [UNK] of [UNK] try to get their [UNK] [UNK] while [UNK] their [UNK] as some kind of [UNK] [UNK] [UNK] br they take [UNK] a [UNK] man in [UNK] and give him [UNK] to see if he will turn his life around then with only the most [UNK] [UNK] and [UNK] they let him go on his [UNK] [UNK] br what are they trying to say money cant buy you [UNK] the [UNK] are [UNK] because they [UNK] to be or how about [UNK] a man up its more fun to watch him fall from a [UNK] [UNK] they took a man with nothing to [UNK] gave him something to [UNK] and then watched him [UNK] it all down the [UNK] thats supposed to be entertainment they [UNK] this [UNK] up with some [UNK] music and dramatic camera shots but in the end it has all the [UNK] high [UNK] of car [UNK] [UNK] only this time they [UNK] the car [UNK] and [UNK] what happens when you take down a stop [UNK]                                                                                                                                                                                                                                                                                                                                                                                                   

Original:  b"'Mojo' is a story of fifties London, a world of budding rock stars, violence and forced homosexuality. 'Mojo' uses a technique for shooting the 1950s often seen in films that stresses the physical differences to our own time but also represents dialogue in a highly exaggerated fashion (owing much to the way that speech was represented in films made in that period); I have no idea if people actually spoke like this outside of the movies, but no films made today and set in contemporary times use such stylised language. It's as if the stilted discourse of 1950s screenwriters serves a common shorthand for a past that seems, in consequence, a very distant country indeed; and therefore stresses the particular, rather than the universal, in the story. 'Mojo' features a strong performance from Ian Hart and annoying ones from Aiden Gillan and Ewan Bremner, the latter still struggling to build a post-'Trainspotting' career; but feels like a period piece, a modern film incomprehensibly structured in an outdated idiom. Rather dull, actually."
Round-trip:  [UNK] is a story of [UNK] [UNK] a world of [UNK] rock stars violence and forced [UNK] [UNK] [UNK] a [UNK] for [UNK] the [UNK] often seen in films that [UNK] the [UNK] [UNK] to our own time but also [UNK] dialogue in a highly [UNK] [UNK] [UNK] much to the way that [UNK] was [UNK] in films made in that period i have no idea if people actually [UNK] like this outside of the movies but no films made today and set in [UNK] times use such [UNK] [UNK] its as if the [UNK] [UNK] of [UNK] [UNK] [UNK] a [UNK] [UNK] for a past that seems in [UNK] a very [UNK] country indeed and [UNK] [UNK] the particular rather than the [UNK] in the story [UNK] features a strong performance from [UNK] [UNK] and annoying ones from [UNK] [UNK] and [UNK] [UNK] the [UNK] still [UNK] to [UNK] a [UNK] career but feels like a period piece a modern film [UNK] [UNK] in an [UNK] [UNK] rather dull actually

Create the model

[A drawing of the information flow in the model]

Above is a diagram of the model.

  1. This model can be built as a tf.keras.Sequential.

  2. The first layer is the encoder, which converts the text to a sequence of token indices.

  3. After the encoder is an embedding layer. An embedding layer stores one vector per word. When called, it converts the sequences of word indices to sequences of vectors. These vectors are trainable. After training (on enough data), words with similar meanings often have similar vectors.

    This index-lookup is much more efficient than the equivalent operation of passing a one-hot encoded vector through a tf.keras.layers.Dense layer (see the sketch after this list).

  4. A recurrent neural network (RNN) processes sequence input by iterating through the elements. RNNs pass the outputs from one timestep to their input on the next timestep.

    The tf.keras.layers.Bidirectional wrapper can also be used with an RNN layer. This propagates the input forward and backwards through the RNN layer and then concatenates the final output.

    • The main advantage of a bidirectional RNN is that the signal from the beginning of the input doesn't need to be processed all the way through every timestep to affect the output.

    • The main disadvantage of a bidirectional RNN is that you can't efficiently stream predictions as words are being added to the end.

  5. After the RNN has converted the sequence to a single vector, the two tf.keras.layers.Dense layers do some final processing and convert this vector representation to a single logit as the classification output.
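
Before building the full model, here is a small, optional sketch (the names are illustrative and it is not part of the model below) showing that the embedding lookup from step 3 gives the same result as multiplying a one-hot encoding by the embedding matrix, just without materializing the large one-hot tensor:

# Illustrative only: embedding lookup vs. one-hot matmul with the same weights.
demo_embedding = tf.keras.layers.Embedding(input_dim=1000, output_dim=64)
token_ids = tf.constant([[2, 7, 42]])          # a tiny batch of token indices
looked_up = demo_embedding(token_ids)          # shape: (1, 3, 64)

# Equivalent (but slower for large vocabularies): one-hot encode, then multiply
# by the embedding matrix, as a bias-free Dense layer would.
one_hot = tf.one_hot(token_ids, depth=1000)    # shape: (1, 3, 1000)
via_one_hot = tf.einsum('btv,ve->bte', one_hot, demo_embedding.embeddings)

print(np.abs(looked_up.numpy() - via_one_hot.numpy()).max())  # ~0.0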

The code to implement this is below:

model = tf.keras.Sequential([
    encoder,
    tf.keras.layers.Embedding(
        input_dim=len(encoder.get_vocabulary()),
        output_dim=64,
        # Use masking to handle the variable sequence lengths
        mask_zero=True),
    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64)),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(1)
])

Note that a Keras sequential model is used here since all the layers in the model have a single input and produce a single output. If you want to use a stateful RNN layer, you might want to build your model with the Keras functional API or model subclassing so that you can retrieve and reuse the RNN layer states. Check the Keras RNN guide for more details.
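
For reference, here is a minimal sketch (not used in the rest of this tutorial) of the same architecture built with the Keras functional API:

# Illustrative functional-API version of the model above.
inputs = tf.keras.Input(shape=(1,), dtype=tf.string)   # batches of raw strings
x = encoder(inputs)                                     # text -> token indices
x = tf.keras.layers.Embedding(
    input_dim=len(encoder.get_vocabulary()), output_dim=64, mask_zero=True)(x)
x = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64))(x)
x = tf.keras.layers.Dense(64, activation='relu')(x)
outputs = tf.keras.layers.Dense(1)(x)                   # a single logit
functional_model = tf.keras.Model(inputs, outputs)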

The embedding layer uses masking to handle the varying sequence lengths. All the layers after the Embedding layer support masking:

print([layer.supports_masking for layer in model.layers])
[False, True, True, True, True]

To confirm that this works as expected, evaluate a sentence twice. First, alone so there's no padding to mask:

# predict on a sample text without padding.

sample_text = ('The movie was cool. The animation and the graphics '
               'were out of this world. I would recommend this movie.')
predictions = model.predict(np.array([sample_text]))
print(predictions[0])
1/1 [==============================] - 3s 3s/step
[-0.00252648]

Now, evaluate it again in a batch with a longer sentence. The result should be identical:

# predict on a sample text with padding

padding = "the " * 2000
predictions = model.predict(np.array([sample_text, padding]))
print(predictions[0])
1/1 [==============================] - 0s 60ms/step
[-0.00252648]

Compile the Keras model to configure the training process:

model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              optimizer=tf.keras.optimizers.Adam(1e-4),
              metrics=['accuracy'])

Train the model

history = model.fit(train_dataset, epochs=10,
                    validation_data=test_dataset,
                    validation_steps=30)
Epoch 1/10
2022-12-14 13:38:42.916496: W tensorflow/core/common_runtime/type_inference.cc:339] Type inference failed. This indicates an invalid graph that escaped type checking. Error message: INVALID_ARGUMENT: expected compatible input types, but input 1:
type_id: TFT_OPTIONAL
args {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_TENSOR
    args {
      type_id: TFT_INT32
    }
  }
}
 is neither a subtype nor a supertype of the combined inputs preceding it:
type_id: TFT_OPTIONAL
args {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_TENSOR
    args {
      type_id: TFT_FLOAT
    }
  }
}

    while inferring type of node 'cond_40/output/_23'
391/391 [==============================] - 45s 94ms/step - loss: 0.6407 - accuracy: 0.5861 - val_loss: 0.5107 - val_accuracy: 0.7151
Epoch 2/10
391/391 [==============================] - 26s 64ms/step - loss: 0.4621 - accuracy: 0.7836 - val_loss: 0.4241 - val_accuracy: 0.8224
Epoch 3/10
391/391 [==============================] - 25s 62ms/step - loss: 0.3893 - accuracy: 0.8248 - val_loss: 0.3739 - val_accuracy: 0.8276
Epoch 4/10
391/391 [==============================] - 25s 62ms/step - loss: 0.3656 - accuracy: 0.8428 - val_loss: 0.3558 - val_accuracy: 0.8484
Epoch 5/10
391/391 [==============================] - 25s 64ms/step - loss: 0.3301 - accuracy: 0.8574 - val_loss: 0.3418 - val_accuracy: 0.8406
Epoch 6/10
391/391 [==============================] - 25s 63ms/step - loss: 0.3197 - accuracy: 0.8630 - val_loss: 0.3298 - val_accuracy: 0.8474
Epoch 7/10
391/391 [==============================] - 24s 61ms/step - loss: 0.3119 - accuracy: 0.8678 - val_loss: 0.3347 - val_accuracy: 0.8620
Epoch 8/10
391/391 [==============================] - 24s 61ms/step - loss: 0.3062 - accuracy: 0.8694 - val_loss: 0.3240 - val_accuracy: 0.8531
Epoch 9/10
391/391 [==============================] - 24s 61ms/step - loss: 0.3029 - accuracy: 0.8709 - val_loss: 0.3232 - val_accuracy: 0.8547
Epoch 10/10
391/391 [==============================] - 24s 60ms/step - loss: 0.2990 - accuracy: 0.8733 - val_loss: 0.3240 - val_accuracy: 0.8542
test_loss, test_acc = model.evaluate(test_dataset)

print('Test Loss:', test_loss)
print('Test Accuracy:', test_acc)
391/391 [==============================] - 10s 25ms/step - loss: 0.3163 - accuracy: 0.8582
Test Loss: 0.3162669241428375
Test Accuracy: 0.8581600189208984
plt.figure(figsize=(16, 8))
plt.subplot(1, 2, 1)
plot_graphs(history, 'accuracy')
plt.ylim(None, 1)
plt.subplot(1, 2, 2)
plot_graphs(history, 'loss')
plt.ylim(0, None)
(0.0, 0.6577653184533119)

[Plots of training and validation accuracy (left) and loss (right) over epochs]

Run a prediction on a new sentence:

If the prediction is >= 0.0, it is positive; otherwise, it is negative.

sample_text = ('The movie was cool. The animation and the graphics '
               'were out of this world. I would recommend this movie.')
predictions = model.predict(np.array([sample_text]))
1/1 [==============================] - 2s 2s/step
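
To turn the raw logit into a probability or a class label, you can, for example, apply a sigmoid or simply compare with zero. A minimal sketch (the variable names are illustrative):

logit = predictions[0, 0]                   # the model outputs one logit per example
probability = tf.sigmoid(logit).numpy()     # probability of a positive review
label = 'positive' if logit >= 0.0 else 'negative'
print(probability, label)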

Stack two or more LSTM layers

Keras recurrent layers have two available modes that are controlled by the return_sequences constructor argument:

  • If False, it returns only the last output for each input sequence (a 2D tensor of shape (batch_size, output_features)). This is the default, used in the previous model.

  • If True, the full sequence of outputs for each timestep is returned (a 3D tensor of shape (batch_size, timesteps, output_features)), as the sketch below shows.
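
To see the shapes concretely, here is a small sketch on random data (illustrative only, not part of the models in this tutorial):

demo_input = tf.random.normal([32, 10, 8])  # (batch_size, timesteps, features)

last_output = tf.keras.layers.LSTM(4)(demo_input)
print(last_output.shape)    # (32, 4)

all_outputs = tf.keras.layers.LSTM(4, return_sequences=True)(demo_input)
print(all_outputs.shape)    # (32, 10, 4)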

Here is what the flow of information looks like with return_sequences=True:

[Diagram of a layered bidirectional RNN with return_sequences=True]

The interesting thing about using an RNN with return_sequences=True is that the output still has three axes, like the input, so it can be passed to another RNN layer, like this:

model = tf.keras.Sequential([
    encoder,
    tf.keras.layers.Embedding(len(encoder.get_vocabulary()), 64, mask_zero=True),
    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64,  return_sequences=True)),
    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(32)),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(1)
])
model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              optimizer=tf.keras.optimizers.Adam(1e-4),
              metrics=['accuracy'])
history = model.fit(train_dataset, epochs=10,
                    validation_data=test_dataset,
                    validation_steps=30)
Epoch 1/10
391/391 [==============================] - 70s 140ms/step - loss: 0.6376 - accuracy: 0.5797 - val_loss: 0.4587 - val_accuracy: 0.7839
Epoch 2/10
391/391 [==============================] - 45s 114ms/step - loss: 0.4014 - accuracy: 0.8225 - val_loss: 0.3681 - val_accuracy: 0.8490
Epoch 3/10
391/391 [==============================] - 45s 113ms/step - loss: 0.3448 - accuracy: 0.8537 - val_loss: 0.3328 - val_accuracy: 0.8495
Epoch 4/10
391/391 [==============================] - 44s 112ms/step - loss: 0.3282 - accuracy: 0.8614 - val_loss: 0.3508 - val_accuracy: 0.8552
Epoch 5/10
391/391 [==============================] - 44s 112ms/step - loss: 0.3187 - accuracy: 0.8631 - val_loss: 0.3302 - val_accuracy: 0.8536
Epoch 6/10
391/391 [==============================] - 44s 112ms/step - loss: 0.3105 - accuracy: 0.8683 - val_loss: 0.3199 - val_accuracy: 0.8562
Epoch 7/10
391/391 [==============================] - 44s 111ms/step - loss: 0.3069 - accuracy: 0.8712 - val_loss: 0.3272 - val_accuracy: 0.8438
Epoch 8/10
391/391 [==============================] - 44s 112ms/step - loss: 0.3063 - accuracy: 0.8707 - val_loss: 0.3232 - val_accuracy: 0.8583
Epoch 9/10
391/391 [==============================] - 44s 112ms/step - loss: 0.3021 - accuracy: 0.8719 - val_loss: 0.3328 - val_accuracy: 0.8536
Epoch 10/10
391/391 [==============================] - 44s 112ms/step - loss: 0.3003 - accuracy: 0.8721 - val_loss: 0.3196 - val_accuracy: 0.8589
test_loss, test_acc = model.evaluate(test_dataset)

print('Test Loss:', test_loss)
print('Test Accuracy:', test_acc)
391/391 [==============================] - 18s 46ms/step - loss: 0.3163 - accuracy: 0.8614
Test Loss: 0.31628698110580444
Test Accuracy: 0.8614000082015991
# predict on a sample text without padding.

sample_text = ('The movie was not good. The animation and the graphics '
               'were terrible. I would not recommend this movie.')
predictions = model.predict(np.array([sample_text]))
print(predictions)
1/1 [==============================] - 4s 4s/step
[[-1.8099834]]
plt.figure(figsize=(16, 6))
plt.subplot(1, 2, 1)
plot_graphs(history, 'accuracy')
plt.subplot(1, 2, 2)
plot_graphs(history, 'loss')

[Plots of training and validation accuracy (left) and loss (right) over epochs]

Check out other existing recurrent layers such as GRU layers.
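
For example, here is a sketch of the first model with the LSTM swapped for a GRU (the other hyperparameters are kept the same and are not separately tuned):

# Illustrative GRU variant of the single-layer model above.
gru_model = tf.keras.Sequential([
    encoder,
    tf.keras.layers.Embedding(len(encoder.get_vocabulary()), 64, mask_zero=True),
    tf.keras.layers.Bidirectional(tf.keras.layers.GRU(64)),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(1)
])
gru_model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
                  optimizer=tf.keras.optimizers.Adam(1e-4),
                  metrics=['accuracy'])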

If you're interested in building custom RNNs, see the Keras RNN Guide.