Text classification with TensorFlow Lite Model Maker

The TensorFlow Lite Model Maker library simplifies the process of adapting and converting a TensorFlow model to particular input data when deploying this model for on-device ML applications.

This notebook shows an end-to-end example that uses the Model Maker library to adapt and convert a commonly used text classification model so it can classify movie reviews on a mobile device. The text classification model classifies text into predefined categories. The inputs should be preprocessed text and the outputs are the probabilities of the categories. The dataset used in this tutorial consists of positive and negative movie reviews.
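The "probabilities of the categories" in the model's output come from a final softmax layer. As a minimal sketch (the logit values here are hypothetical, not from the actual model), this is how raw class scores become probabilities:

```python
import numpy as np

def softmax(logits):
    """Convert raw model scores (logits) into class probabilities."""
    exps = np.exp(logits - np.max(logits))  # subtract the max for numerical stability
    return exps / exps.sum()

# Hypothetical raw scores for the two classes (negative, positive)
probs = softmax(np.array([0.3, 2.1]))
print(probs)  # two probabilities that sum to 1; the larger logit wins
```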

Prerequisites

Install the required packages

To run this example, install the required packages, including the Model Maker package from the GitHub repo.

pip install -q git+https://github.com/tensorflow/examples.git#egg=tensorflow-examples[model_maker]

Import the required packages.

import numpy as np
import os

import tensorflow as tf
assert tf.__version__.startswith('2')

from tensorflow_examples.lite.model_maker.core.data_util.text_dataloader import TextClassifierDataLoader
from tensorflow_examples.lite.model_maker.core.task import model_spec
from tensorflow_examples.lite.model_maker.core.task import text_classifier
from tensorflow_examples.lite.model_maker.core.task.configs import QuantizationConfig

Get the data path

Download the dataset for this tutorial.

data_dir = tf.keras.utils.get_file(
      fname='SST-2.zip',
      origin='https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8',
      extract=True)
data_dir = os.path.join(os.path.dirname(data_dir), 'SST-2')
Downloading data from https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8
7446528/7439277 [==============================] - 0s 0us/step

You can also upload your own dataset to work through this tutorial. Upload your dataset by using the left sidebar in Colab.

If you prefer not to upload your dataset to the cloud, you can also run the library locally by following the guide.

End-to-End Workflow

This workflow consists of five steps as outlined below:

Step 1. Choose a model specification that represents a text classification model.

This tutorial uses MobileBERT as an example.

spec = model_spec.get('mobilebert_classifier')

Step 2. Load train and test data specific to an on-device ML app and preprocess the data according to a specific model_spec.

train_data = TextClassifierDataLoader.from_csv(
      filename=os.path.join(data_dir, 'train.tsv'),
      text_column='sentence',
      label_column='label',
      model_spec=spec,
      delimiter='\t',
      is_training=True)
test_data = TextClassifierDataLoader.from_csv(
      filename=os.path.join(data_dir, 'dev.tsv'),
      text_column='sentence',
      label_column='label',
      model_spec=spec,
      delimiter='\t',
      is_training=False)

Step 3. Customize the TensorFlow model.

model = text_classifier.create(train_data, model_spec=spec)
INFO:tensorflow:Retraining the models...

Warning:tensorflow:Automatic model reloading for interrupted job was removed from the `ModelCheckpoint` callback in multi-worker mode, please use the `keras.callbacks.experimental.BackupAndRestore` callback instead. See this tutorial for details: https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras#backupandrestore_callback.

Epoch 1/3
   1/1403 [..............................] - ETA: 19:23:35 - loss: 5.6609 - test_accuracy: 0.4583WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/summary_ops_v2.py:1371: stop (from tensorflow.python.eager.profiler) is deprecated and will be removed after 2020-07-01.
Instructions for updating:
use `tf.profiler.experimental.stop` instead.

1403/1403 [==============================] - 291s 207ms/step - loss: 0.9926 - test_accuracy: 0.6721
Epoch 2/3
1403/1403 [==============================] - 242s 172ms/step - loss: 0.2191 - test_accuracy: 0.9135
Epoch 3/3
1403/1403 [==============================] - 242s 173ms/step - loss: 0.1521 - test_accuracy: 0.9440

Step 4. Evaluate the model.

loss, acc = model.evaluate(test_data)
28/28 [==============================] - 6s 232ms/step - loss: 0.2862 - test_accuracy: 0.9014

Step 5. Export as a TensorFlow Lite model.

Since MobileBERT is too big for on-device applications, use dynamic range quantization on the model to compress it by almost 4x with minimal performance degradation.
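The "almost 4x" compression comes from storing each 32-bit float weight as an 8-bit integer. As a rough numpy sketch of the idea behind dynamic range quantization (the weight values are made up for illustration; the real converter does this per tensor inside TensorFlow Lite):

```python
import numpy as np

# A hypothetical float32 weight tensor
weights = np.array([0.12, -0.5, 0.33, -0.07, 0.49], dtype=np.float32)

# Dynamic range quantization maps floats to int8 using a per-tensor scale
scale = np.abs(weights).max() / 127.0
quantized = np.round(weights / scale).astype(np.int8)   # 1 byte per weight
dequantized = quantized.astype(np.float32) * scale      # recovered at inference time

# int8 storage is 4x smaller than float32, at the cost of small rounding error
print(weights.nbytes / quantized.nbytes)  # 4.0
```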

config = QuantizationConfig.create_dynamic_range_quantization(optimizations=[tf.lite.Optimize.OPTIMIZE_FOR_LATENCY])
config._experimental_new_quantizer = True
model.export(export_dir='mobilebert/', quantization_config=config)
INFO:tensorflow:Saving labels in mobilebert/labels.txt.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

INFO:tensorflow:Assets written to: /tmp/tmpf2vfeos1/saved_model/assets

INFO:tensorflow:Saved vocabulary in mobilebert/vocab.

You can also download the model using the left sidebar in Colab.

After completing these five steps, you can use the TensorFlow Lite model file and label file in on-device applications, such as the text classification reference app.

The following sections walk through the example step by step to show more detail.

Choose a model_spec that Represents a Model for the Text Classifier

Each model_spec object represents a specific model for the text classifier. TensorFlow Lite Model Maker currently supports MobileBERT, averaging word embeddings, and BERT-Base (https://arxiv.org/pdf/1810.04805.pdf) models.

Supported Model            Name of model_spec         Model Description
MobileBERT                 'mobilebert_classifier'    4.3x smaller and 5.5x faster than BERT-Base while achieving competitive results; suitable for on-device applications.
BERT-Base                  'bert_classifier'          Standard BERT model that is widely used in NLP tasks.
Averaging word embedding   'average_word_vec'         Averages text word embeddings with RELU activation.

This tutorial uses average_word_vec, a smaller model that you can retrain multiple times to demonstrate the process.

spec = model_spec.get('average_word_vec')

Load Input Data Specific to an On-device ML App

SST-2 (the Stanford Sentiment Treebank) is one of the tasks in the GLUE benchmark. It contains 67,349 movie reviews for training and 872 movie reviews for validation. The dataset has two classes: positive and negative movie reviews.

Download the archived version of the dataset and extract it.

data_dir = tf.keras.utils.get_file(
      fname='SST-2.zip',
      origin='https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8',
      extract=True)
data_dir = os.path.join(os.path.dirname(data_dir), 'SST-2')

The SST-2 dataset has train.tsv for training and dev.tsv for validation. The files have the following format:

sentence label
it 's a charming and often affecting journey . 1
unflinchingly bleak and desperate 0

A positive review is labeled 1 and a negative review is labeled 0.
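Because the files are plain tab-delimited text with a header row, you can inspect them with Python's csv module. This sketch parses an in-memory sample in the same format (so it does not depend on the download above):

```python
import csv
import io

# A few rows in the same format as SST-2's train.tsv (tab-delimited, header row)
sample_tsv = (
    "sentence\tlabel\n"
    "it 's a charming and often affecting journey .\t1\n"
    "unflinchingly bleak and desperate\t0\n"
)

reader = csv.DictReader(io.StringIO(sample_tsv), delimiter='\t')
rows = list(reader)
for row in rows:
    sentiment = 'positive' if row['label'] == '1' else 'negative'
    print(f"{sentiment}: {row['sentence']}")
```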

Use the TextClassifierDataLoader.from_csv method to load the data.

train_data = TextClassifierDataLoader.from_csv(
      filename=os.path.join(data_dir, 'train.tsv'),
      text_column='sentence',
      label_column='label',
      model_spec=spec,
      delimiter='\t',
      is_training=True)
test_data = TextClassifierDataLoader.from_csv(
      filename=os.path.join(data_dir, 'dev.tsv'),
      text_column='sentence',
      label_column='label',
      model_spec=spec,
      delimiter='\t',
      is_training=False)
INFO:tensorflow:Saved vocabulary in /tmp/tmpb4_wdf3d/e72d242a17446c5dc91aa41e181ce914_vocab.

The Model Maker library also supports the from_folder() method for loading data. It assumes that text files of the same class are in the same subdirectory and that the subfolder name is the class name. Each text file contains one movie review sample. The class_labels parameter specifies which subfolders to load.
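For example, the directory layout that from_folder() expects looks like the following sketch, which builds a tiny hypothetical two-class dataset (the actual loading call is shown commented out, since it requires the Model Maker library and a real dataset):

```python
import os
import tempfile

# Build a tiny folder-per-class dataset: the subfolder name is the class label
root = tempfile.mkdtemp()
samples = {
    'pos': ["a charming and often affecting journey ."],
    'neg': ["unflinchingly bleak and desperate"],
}
for label, reviews in samples.items():
    os.makedirs(os.path.join(root, label))
    for i, text in enumerate(reviews):
        with open(os.path.join(root, label, f'{i}.txt'), 'w') as f:
            f.write(text)

class_labels = sorted(os.listdir(root))
print(class_labels)  # ['neg', 'pos']

# With this layout, loading would look like (not run here):
# data = TextClassifierDataLoader.from_folder(root, class_labels=class_labels)
```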

Customize the TensorFlow Model

Create a custom text classifier model based on the loaded data.

model = text_classifier.create(train_data, model_spec=spec, epochs=10)
INFO:tensorflow:Retraining the models...

Epoch 1/10
2104/2104 [==============================] - 7s 3ms/step - loss: 0.6864 - accuracy: 0.5541
Epoch 2/10
2104/2104 [==============================] - 6s 3ms/step - loss: 0.6171 - accuracy: 0.6663
Epoch 3/10
2104/2104 [==============================] - 6s 3ms/step - loss: 0.4733 - accuracy: 0.7768
Epoch 4/10
2104/2104 [==============================] - 6s 3ms/step - loss: 0.4087 - accuracy: 0.8170
Epoch 5/10
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3797 - accuracy: 0.8344
Epoch 6/10
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3636 - accuracy: 0.8459
Epoch 7/10
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3516 - accuracy: 0.8550
Epoch 8/10
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3415 - accuracy: 0.8584
Epoch 9/10
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3350 - accuracy: 0.8628
Epoch 10/10
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3280 - accuracy: 0.8656

Examine the detailed model structure.

model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
embedding (Embedding)        (None, 256, 16)           160048    
_________________________________________________________________
global_average_pooling1d (Gl (None, 16)                0         
_________________________________________________________________
dense (Dense)                (None, 16)                272       
_________________________________________________________________
dropout_1 (Dropout)          (None, 16)                0         
_________________________________________________________________
dense_1 (Dense)              (None, 2)                 34        
=================================================================
Total params: 160,354
Trainable params: 160,354
Non-trainable params: 0
_________________________________________________________________
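The parameter counts in the summary can be checked by hand: the embedding layer's 160,048 parameters imply a vocabulary of 10,003 tokens at 16 dimensions each, and each dense layer contributes a weight matrix plus a bias vector:

```python
wordvec_dim = 16
vocab_size = 160048 // wordvec_dim           # 10003 tokens, from the summary above

embedding = vocab_size * wordvec_dim          # 160048
dense = wordvec_dim * 16 + 16                 # 16x16 weights + 16 biases = 272
dense_1 = wordvec_dim * 2 + 2                 # two output classes: 34
total = embedding + dense + dense_1
print(total)  # 160354, matching "Total params" above
```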

Evaluate the Customized Model

Evaluate the model on the test data to get its loss and accuracy.

loss, acc = model.evaluate(test_data)
28/28 [==============================] - 0s 6ms/step - loss: 0.5239 - accuracy: 0.8314

Export as a TensorFlow Lite Model

Convert the trained model to the TensorFlow Lite model format so you can later use it in an on-device ML application. Save the text labels in a label file and the vocabulary in a vocab file. The default TFLite filename is model.tflite, the default label filename is labels.txt, and the default vocab filename is vocab.

model.export(export_dir='average_word_vec/')
INFO:tensorflow:Saving labels in average_word_vec/labels.txt.

INFO:tensorflow:Assets written to: /tmp/tmpy38mxo35/assets

INFO:tensorflow:Saved vocabulary in average_word_vec/vocab.

The TensorFlow Lite model file and label file can be used in the text classification reference app by adding model.tflite, text_label.txt, and vocab.txt to the assets directory. Do not forget to also update the filenames in the app's code.

You can evaluate the TensorFlow Lite model with the evaluate_tflite method.

model.evaluate_tflite('average_word_vec/model.tflite', test_data)
INFO:tensorflow:Processing example: #0
[[   1   12    8    4  300    5  145 1500  622    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0]]

{'accuracy': 0.8314220183486238}

Advanced Usage

The create function is the driver function that the Model Maker library uses to create models. The model_spec parameter defines the model specification; the AverageWordVecModelSpec and BertClassifierModelSpec classes are currently supported. The create function consists of the following steps:

  1. Creates the model for the text classifier according to model_spec.
  2. Trains the classifier model. The default epochs and the default batch size are set by the default_training_epochs and default_batch_size variables in the model_spec object.

This section covers advanced usage topics like adjusting the model and the training hyperparameters.

Adjust the model

You can adjust the model infrastructure, such as the wordvec_dim and seq_len variables in the AverageWordVecModelSpec class.

For example, you can train the model with a larger value of wordvec_dim. Note that you must construct a new model_spec object if you modify the model.

new_model_spec = model_spec.AverageWordVecModelSpec(wordvec_dim=32)

Get the preprocessed data.

new_train_data = TextClassifierDataLoader.from_csv(
      filename=os.path.join(data_dir, 'train.tsv'),
      text_column='sentence',
      label_column='label',
      model_spec=new_model_spec,
      delimiter='\t',
      is_training=True)
INFO:tensorflow:Saved vocabulary in /tmp/tmp5211dmkl/997a3ae24c002f2be0c24669784bb1ce_vocab.

Train the new model.

model = text_classifier.create(new_train_data, model_spec=new_model_spec)
INFO:tensorflow:Retraining the models...

Epoch 1/3
2104/2104 [==============================] - 7s 3ms/step - loss: 0.6829 - accuracy: 0.5583
Epoch 2/3
2104/2104 [==============================] - 7s 3ms/step - loss: 0.5239 - accuracy: 0.7455
Epoch 3/3
2104/2104 [==============================] - 7s 4ms/step - loss: 0.4071 - accuracy: 0.8165

You can also adjust the MobileBERT model.

The model parameters you can adjust are:

  • seq_len: Length of the sequence to feed into the model.
  • initializer_range: The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  • trainable: Boolean that specifies whether the pre-trained layer is trainable.

The training pipeline parameters you can adjust are:

  • model_dir: The location of the model checkpoint files. If not set, a temporary directory will be used.
  • dropout_rate: The dropout rate.
  • learning_rate: The initial learning rate for the Adam optimizer.
  • tpu: TPU address to connect to.

For instance, you can set seq_len=256 (the default is 128), which allows the model to classify longer text.

new_model_spec = model_spec.get('mobilebert_classifier')
new_model_spec.seq_len = 256

Tune the training hyperparameters

You can also tune training hyperparameters like epochs and batch_size, which affect model accuracy. For instance,

  • epochs: more epochs could achieve better accuracy, but may lead to overfitting.
  • batch_size: the number of samples to use in one training step.
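The batch size also determines how many steps make up one epoch. As a rough illustration (assuming a default batch size of 32 for average_word_vec, and dropping the final partial batch), SST-2's 67,349 training examples yield the step count seen in the progress bars above:

```python
num_train_examples = 67349   # SST-2 training set size
batch_size = 32              # assumed default batch size for average_word_vec

steps_per_epoch = num_train_examples // batch_size  # drops the last partial batch
print(steps_per_epoch)  # 2104
```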

For example, you can train with more epochs.

model = text_classifier.create(train_data, model_spec=spec, epochs=20)
INFO:tensorflow:Retraining the models...

Epoch 1/20
2104/2104 [==============================] - 7s 3ms/step - loss: 0.6857 - accuracy: 0.5551
Epoch 2/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.6241 - accuracy: 0.6597
Epoch 3/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.4934 - accuracy: 0.7734
Epoch 4/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.4280 - accuracy: 0.8071
Epoch 5/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3973 - accuracy: 0.8259
Epoch 6/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3777 - accuracy: 0.8350
Epoch 7/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3626 - accuracy: 0.8432
Epoch 8/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3507 - accuracy: 0.8515
Epoch 9/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3441 - accuracy: 0.8560
Epoch 10/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3362 - accuracy: 0.8593
Epoch 11/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3300 - accuracy: 0.8632
Epoch 12/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3237 - accuracy: 0.8667
Epoch 13/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3217 - accuracy: 0.8663
Epoch 14/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3143 - accuracy: 0.8712
Epoch 15/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3138 - accuracy: 0.8718
Epoch 16/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3097 - accuracy: 0.8741
Epoch 17/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3076 - accuracy: 0.8771
Epoch 18/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3069 - accuracy: 0.8768
Epoch 19/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3042 - accuracy: 0.8767
Epoch 20/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3013 - accuracy: 0.8797

Evaluate the newly retrained model with 20 training epochs.

loss, accuracy = model.evaluate(test_data)
28/28 [==============================] - 0s 6ms/step - loss: 0.5095 - accuracy: 0.8268

Change the Model Architecture

You can change the model by changing the model_spec. The following shows how to switch to the BERT-Base model.

Change the model_spec to BERT-Base model for the text classifier.

spec = model_spec.get('bert_classifier')
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

The remaining steps are the same.