The TensorFlow Lite Model Maker library simplifies the process of adapting and converting a TensorFlow model to particular input data when deploying this model for on-device ML applications.
This notebook shows an end-to-end example that uses the Model Maker library to adapt and convert a commonly used text classification model to classify movie reviews on a mobile device. The text classification model classifies text into predefined categories. The inputs should be preprocessed text, and the outputs are the probabilities of the categories. The dataset used in this tutorial consists of positive and negative movie reviews.
Prerequisites
Install the required packages
To run this example, install the required packages, including the Model Maker package from the GitHub repo.
If you run this notebook on Colab, you may see an error message about tensorflowjs and tensorflow-hub version incompatibility. It is safe to ignore this error because tensorflowjs is not used in this workflow.
pip install -q tflite-model-maker
Import the required packages.
import numpy as np
import os
from tflite_model_maker import configs
from tflite_model_maker import ExportFormat
from tflite_model_maker import model_spec
from tflite_model_maker import text_classifier
from tflite_model_maker import TextClassifierDataLoader
import tensorflow as tf
assert tf.__version__.startswith('2')
tf.get_logger().setLevel('ERROR')
Download the sample training data.
In this tutorial, we will use SST-2 (Stanford Sentiment Treebank), which is one of the tasks in the GLUE benchmark. It contains 67,349 movie reviews for training and 872 movie reviews for testing. The dataset has two classes: positive and negative movie reviews.
data_dir = tf.keras.utils.get_file(
    fname='SST-2.zip',
    origin='https://dl.fbaipublicfiles.com/glue/data/SST-2.zip',
    extract=True)
data_dir = os.path.join(os.path.dirname(data_dir), 'SST-2')
Downloading data from https://dl.fbaipublicfiles.com/glue/data/SST-2.zip 7446528/7439277 [==============================] - 2s 0us/step
The SST-2 dataset is stored in TSV format. The only difference between TSV and CSV is that TSV uses a tab character (\t) as its delimiter instead of the comma (,) used in CSV.
Here are the first 5 lines of the training dataset. label=0 means negative, label=1 means positive.
sentence | label |
---|---|
hide new secretions from the parental units | 0 |
contains no wit , only labored gags | 0 |
that loves its characters and communicates something rather beautiful about human nature | 1 |
remains utterly satisfied to remain the same throughout | 0 |
on the worst revenge-of-the-nerds clichés the filmmakers could dredge up | 0 |
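As a quick illustration of the tab delimiter, the sketch below parses a few of the rows shown above with Python's standard csv module. The in-memory sample_tsv string is a stand-in for the real train.tsv file, and the variable names are illustrative only.

```python
import csv
import io

# Sample rows copied from the table above; a real run would open train.tsv instead.
sample_tsv = (
    "sentence\tlabel\n"
    "hide new secretions from the parental units \t0\n"
    "contains no wit , only labored gags \t0\n"
    "that loves its characters and communicates something rather beautiful about human nature \t1\n"
)

# delimiter="\t" is the only change needed compared with reading a CSV file.
reader = csv.DictReader(io.StringIO(sample_tsv), delimiter="\t")
rows = [(row["sentence"].strip(), int(row["label"])) for row in reader]

for sentence, label in rows:
    print(label, sentence)
```

Because the delimiter is a tab, the comma inside the second sentence does not split the field.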
Next, we will load the dataset into a Pandas dataframe, change the current label names (0 and 1) to more human-readable ones (negative and positive), and use them for model training.
import pandas as pd
def replace_label(original_file, new_file):
    # Load the original file into pandas. The separator must be '\t'
    # because the training data is stored in TSV format.
    df = pd.read_csv(original_file, sep='\t')
    # Define how we want to change the label names.
    label_map = {0: 'negative', 1: 'positive'}
    # Execute the label change.
    df.replace({'label': label_map}, inplace=True)
    # Write the updated dataset to a new file.
    df.to_csv(new_file)

# Replace the label names for both the training and test datasets, then write
# the updated CSV files to the current folder.
replace_label(os.path.join(data_dir, 'train.tsv'), 'train.csv')
replace_label(os.path.join(data_dir, 'dev.tsv'), 'dev.csv')
Quickstart
There are five steps to train a text classification model:
Step 1. Choose a text classification model architecture.
Here we use the average word embedding model architecture, which will produce a small and fast model with decent accuracy.
spec = model_spec.get('average_word_vec')
Model Maker also supports other model architectures such as BERT. If you are interested in learning about other architectures, see the Choose a model architecture for Text Classifier section below.
Step 2. Load the training and test data, then preprocess them according to a specific model_spec.
Model Maker can take input data in the CSV format. We will load the training and test datasets with the human-readable label names that were created earlier.
Each model architecture requires input data to be processed in a particular way. TextClassifierDataLoader reads the requirement from model_spec and automatically executes the necessary preprocessing.
train_data = TextClassifierDataLoader.from_csv(
    filename='train.csv',
    text_column='sentence',
    label_column='label',
    model_spec=spec,
    is_training=True)
test_data = TextClassifierDataLoader.from_csv(
    filename='dev.csv',
    text_column='sentence',
    label_column='label',
    model_spec=spec,
    is_training=False)
Step 3. Train the TensorFlow model with the training data.
The average word embedding model uses batch_size = 32 by default, so it takes 2104 steps to go through the 67,349 sentences in the training dataset. We will train the model for 10 epochs, which means going through the training dataset 10 times.
model = text_classifier.create(train_data, model_spec=spec, epochs=10)
Epoch 1/10
2104/2104 [==============================] - 7s 3ms/step - loss: 0.6830 - accuracy: 0.5595
Epoch 2/10
2104/2104 [==============================] - 6s 3ms/step - loss: 0.5781 - accuracy: 0.7091
Epoch 3/10
2104/2104 [==============================] - 6s 3ms/step - loss: 0.4452 - accuracy: 0.7967
Epoch 4/10
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3921 - accuracy: 0.8253
Epoch 5/10
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3665 - accuracy: 0.8409
Epoch 6/10
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3516 - accuracy: 0.8478
Epoch 7/10
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3397 - accuracy: 0.8542
Epoch 8/10
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3332 - accuracy: 0.8622
Epoch 9/10
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3261 - accuracy: 0.8644
Epoch 10/10
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3216 - accuracy: 0.8662
Step 4. Evaluate the model with the test data.
After training the text classification model using the sentences in the training dataset, we will use the remaining 872 sentences in the test dataset to evaluate how the model performs against new data it has never seen before.
As the default batch size is 32, it will take 28 steps to go through the 872 sentences in the test dataset.
loss, acc = model.evaluate(test_data)
28/28 [==============================] - 0s 2ms/step - loss: 0.5162 - accuracy: 0.8303
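The step counts reported in the training and evaluation logs follow from the dataset sizes and the batch size. The short calculation below reproduces them. That training drops the last incomplete batch while evaluation keeps it is an inference from these numbers, not something the tutorial states.

```python
import math

batch_size = 32
train_examples = 67349
test_examples = 872

# The training log shows 2104 steps, which matches integer division,
# i.e. the final partial batch is dropped (assumption inferred from the log).
train_steps = train_examples // batch_size
leftover = train_examples % batch_size

# The evaluation log shows 28 steps, which matches rounding up,
# i.e. the final partial batch is kept.
eval_steps = math.ceil(test_examples / batch_size)

print(train_steps, leftover, eval_steps)
```

Under that assumption, 21 training sentences per epoch do not fit into a full batch.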
Step 5. Export as a TensorFlow Lite model.
Let's export the text classification model that we have trained in the TensorFlow Lite format. We will specify the folder to export the model to.
You may see a warning that the vocab.txt file does not exist in the metadata. It can be safely ignored.
model.export(export_dir='average_word_vec')
Finished populating metadata and associated file to the model: average_word_vec/model.tflite The metadata json file has been saved to: average_word_vec/model.json The associated file that has been been packed to the model is: ['vocab.txt', 'labels.txt'] /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_lite_support/metadata/python/metadata.py:344: UserWarning: File, 'vocab.txt', does not exsit in the metadata. But packing it to tflite model is still allowed. "tflite model is still allowed.".format(f))
You can download the TensorFlow Lite model file using the left sidebar of Colab. Go into the average_word_vec folder that we specified in the export_dir parameter above, right-click on the model.tflite file and choose Download to download it to your local computer.
This model can be integrated into an Android or an iOS app using the NLClassifier API of the TensorFlow Lite Task Library.
See the TFLite Text Classification sample app for more details on how the model is used in a working app.
Note 1: Android Studio Model Binding does not support text classification yet so please use the TensorFlow Lite Task Library.
Note 2: There is a model.json file in the same folder as the TFLite model. It contains the JSON representation of the metadata bundled inside the TensorFlow Lite model. Model metadata helps the TFLite Task Library know what the model does and how to pre-process/post-process data for the model. You don't need to download the model.json file as it is only for informational purposes and its content is already inside the TFLite file.
Note 3: If you train a text classification model using MobileBERT or BERT-Base architecture, you will need to use BertNLClassifier API instead to integrate the trained model into a mobile app.
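The export log lists vocab.txt and labels.txt as files packed into the model. One way to inspect them is sketched below. It relies on the TensorFlow Lite metadata convention that associated files are appended to the model as a ZIP archive, which you should treat as an assumption to verify against your version. The helper name list_associated_files and the placeholder file contents are hypothetical.

```python
import io
import zipfile


def list_associated_files(model_path):
    """Return the names of the files packed into a TFLite model with metadata."""
    with zipfile.ZipFile(model_path) as archive:
        return archive.namelist()


# Stand-in for a real model: placeholder bytes followed by a ZIP archive.
# With a real export, pass 'average_word_vec/model.tflite' instead.
archive_bytes = io.BytesIO()
with zipfile.ZipFile(archive_bytes, "w") as archive:
    archive.writestr("vocab.txt", "placeholder vocabulary\n")
    archive.writestr("labels.txt", "placeholder labels\n")
fake_model = io.BytesIO(b"placeholder model bytes" + archive_bytes.getvalue())

packed_files = list_associated_files(fake_model)
print(packed_files)
```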
The following sections walk through the example step by step to show more details.
Choose a model architecture for Text Classifier
Each model_spec object represents a specific model for the text classifier. TensorFlow Lite Model Maker currently supports MobileBERT, averaging word embeddings and BERT-Base models.
Supported Model | Name of model_spec | Model Description | Model size |
---|---|---|---|
Averaging Word Embedding | 'average_word_vec' | Averaging text word embeddings with ReLU activation. | <1MB |
MobileBERT | 'mobilebert_classifier' | 4.3x smaller and 5.5x faster than BERT-Base while achieving competitive results, suitable for on-device applications. | 25MB w/ quantization, 100MB w/o quantization |
BERT-Base | 'bert_classifier' | Standard BERT model that is widely used in NLP tasks. | 300MB |
In the quick start, we have used the average word embedding model. Let's switch to MobileBERT to train a model with higher accuracy.
mb_spec = model_spec.get('mobilebert_classifier')
Load training data
You can upload your own dataset to work through this tutorial. Upload your dataset by using the left sidebar in Colab.
If you prefer not to upload your dataset to the cloud, you can also locally run the library by following the guide.
To keep it simple, we will reuse the SST-2 dataset downloaded earlier. Let's use the TextClassifierDataLoader.from_csv method to load the data.
Note that because we have changed the model architecture, we need to reload the training and test datasets to apply the new preprocessing logic.
train_data = TextClassifierDataLoader.from_csv(
    filename='train.csv',
    text_column='sentence',
    label_column='label',
    model_spec=mb_spec,
    is_training=True)
test_data = TextClassifierDataLoader.from_csv(
    filename='dev.csv',
    text_column='sentence',
    label_column='label',
    model_spec=mb_spec,
    is_training=False)
The Model Maker library also supports the from_folder() method to load data. It assumes that the text data of the same class are in the same subdirectory and that the subfolder name is the class name. Each text file contains one movie review sample. The class_labels parameter is used to specify which subfolders to load.
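To make the folder layout concrete, the sketch below builds a tiny two-class directory tree of the kind from_folder() is described as expecting. The folder name reviews, the file names and the sample texts are hypothetical, and the commented-out from_folder call at the end is an unverified illustration of how such a tree might be loaded.

```python
import tempfile
from pathlib import Path

# Hypothetical two-class layout: one subfolder per class, one review per text file.
samples = {
    "negative": ["contains no wit , only labored gags"],
    "positive": ["that loves its characters and communicates something rather beautiful about human nature"],
}

root = Path(tempfile.mkdtemp()) / "reviews"
for class_name, reviews in samples.items():
    class_dir = root / class_name
    class_dir.mkdir(parents=True)
    for index, review in enumerate(reviews):
        (class_dir / f"{index}.txt").write_text(review, encoding="utf-8")

layout = sorted(path.relative_to(root).as_posix() for path in root.rglob("*.txt"))
print(layout)

# With this layout in place, a call along these lines should load it (not run here):
# data = TextClassifierDataLoader.from_folder(
#     str(root), model_spec=mb_spec, class_labels=['negative', 'positive'])
```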
Train a TensorFlow Model
Train a text classification model using the training data.
model = text_classifier.create(train_data, model_spec=mb_spec, epochs=3)
Epoch 1/3
1403/1403 [==============================] - 309s 183ms/step - loss: 0.6354 - test_accuracy: 0.7444
Epoch 2/3
1403/1403 [==============================] - 244s 174ms/step - loss: 0.1467 - test_accuracy: 0.9465
Epoch 3/3
1403/1403 [==============================] - 245s 174ms/step - loss: 0.0833 - test_accuracy: 0.9727
Examine the detailed model structure.
model.summary()
Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_word_ids (InputLayer)     [(None, 128)]        0
__________________________________________________________________________________________________
input_mask (InputLayer)         [(None, 128)]        0
__________________________________________________________________________________________________
input_type_ids (InputLayer)     [(None, 128)]        0
__________________________________________________________________________________________________
hub_keras_layer_v1v2 (HubKerasL (None, 512)          24581888    input_word_ids[0][0]
                                                                 input_mask[0][0]
                                                                 input_type_ids[0][0]
__________________________________________________________________________________________________
dropout_1 (Dropout)             (None, 512)          0           hub_keras_layer_v1v2[0][0]
__________________________________________________________________________________________________
output (Dense)                  (None, 2)            1026        dropout_1[0][0]
==================================================================================================
Total params: 24,582,914
Trainable params: 24,582,914
Non-trainable params: 0
__________________________________________________________________________________________________
Evaluate the model
Evaluate the model that we have just trained using the test data and measure the loss and accuracy value.
loss, acc = model.evaluate(test_data)
28/28 [==============================] - 7s 44ms/step - loss: 0.3724 - test_accuracy: 0.9106
Quantize the model
In many on-device ML applications, the model size is an important factor. Therefore, it is recommended that you quantize the model to make it smaller and potentially faster to run. Model Maker automatically applies the recommended quantization scheme for each model architecture, but you can customize the quantization config as shown below.
config = configs.QuantizationConfig.create_dynamic_range_quantization(optimizations=[tf.lite.Optimize.OPTIMIZE_FOR_LATENCY])
config.experimental_new_quantizer = True
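To give a rough intuition for why quantization shrinks a model, the sketch below maps floating-point weights to 8-bit integers with a single scale factor and measures the rounding error. It is a generic illustration of symmetric 8-bit quantization, not the TensorFlow Lite converter's implementation. The function names and sample weights are hypothetical.

```python
def quantize_int8(weights):
    """Symmetric 8-bit quantization: map floats to integers in [-127, 127]."""
    largest = max(abs(w) for w in weights)
    # Fall back to a scale of 1.0 when every weight is zero to avoid dividing by zero.
    scale = largest / 127.0 if largest > 0 else 1.0
    quantized = [round(w / scale) for w in weights]
    return quantized, scale


def dequantize(quantized, scale):
    """Recover approximate float values from the stored integers."""
    return [q * scale for q in quantized]


weights = [0.5, -1.27, 0.031, 1.0]
quantized, scale = quantize_int8(weights)
restored = dequantize(quantized, scale)
max_error = max(abs(w - r) for w, r in zip(weights, restored))
print(quantized, scale, max_error)
```

Each weight is stored in one byte instead of four, at the cost of a rounding error of at most half the scale. This is why the accuracy of the converted model is worth checking.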
Export as a TensorFlow Lite model
Convert the trained model to the TensorFlow Lite model format with metadata so that you can later use it in an on-device ML application. The label file and the vocab file are embedded in the metadata. The default TFLite filename is model.tflite.
model.export(export_dir='mobilebert/', quantization_config=config)
Finished populating metadata and associated file to the model: mobilebert/model.tflite The metadata json file has been saved to: mobilebert/model.json The associated file that has been been packed to the model is: ['vocab.txt', 'labels.txt']
The TensorFlow Lite model file can be integrated in a mobile app using the BertNLClassifier API in the TensorFlow Lite Task Library. Please note that this is different from the NLClassifier API used to integrate the text classification model trained with the average word vector model architecture.
The export formats can be one or a list of the following:
ExportFormat.TFLITE
ExportFormat.LABEL
ExportFormat.VOCAB
ExportFormat.SAVED_MODEL
By default, it exports only the TensorFlow Lite model file containing the model metadata. You can also choose to export other files related to the model for closer examination. For instance, you can export only the label file and the vocab file as follows:
model.export(export_dir='mobilebert/', export_format=[ExportFormat.LABEL, ExportFormat.VOCAB])
You can evaluate the TFLite model with the evaluate_tflite method to measure its accuracy. Converting the trained TensorFlow model to the TFLite format and applying quantization can affect its accuracy, so it is recommended to evaluate the TFLite model accuracy before deployment.
accuracy = model.evaluate_tflite('mobilebert/model.tflite', test_data)
print('TFLite model accuracy: ', accuracy)
TFLite model accuracy: {'accuracy': 0.9105504587155964}
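Because accuracy is the number of correct predictions divided by the number of test sentences, the reported value can be turned back into a count. The sketch below does that for the 872-sentence test set. It is plain arithmetic on the printed number, not a call into Model Maker.

```python
test_examples = 872
reported_accuracy = 0.9105504587155964  # value printed by evaluate_tflite above

# Accuracy = correct / total, so the number of correct predictions can be recovered.
correct = round(reported_accuracy * test_examples)
incorrect = test_examples - correct
print(correct, incorrect)
```

The count works out to 794 correct and 78 incorrect predictions, which matches the 0.9106 accuracy measured on the TensorFlow model before conversion.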
Advanced Usage
The create function is the driver function that the Model Maker library uses to create models. The model_spec parameter defines the model specification. The AverageWordVecModelSpec and BertClassifierModelSpec classes are currently supported. The create function consists of the following steps:
- Creates the model for the text classifier according to model_spec.
- Trains the classifier model. The default epochs and the default batch size are set by the default_training_epochs and default_batch_size variables in the model_spec object.
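The way the defaults interact with explicit arguments can be sketched as follows. This is not Model Maker's code: StandInSpec and resolve_training_args are hypothetical names, and the default values are chosen only to match the average word embedding logs in this tutorial. Only the attribute names default_training_epochs and default_batch_size come from the description above.

```python
from dataclasses import dataclass


@dataclass
class StandInSpec:
    """Hypothetical stand-in for a model_spec object; values are illustrative."""
    default_training_epochs: int = 3
    default_batch_size: int = 32


def resolve_training_args(spec, epochs=None, batch_size=None):
    """Use the caller's values when given, otherwise fall back to the spec's defaults."""
    if epochs is None:
        epochs = spec.default_training_epochs
    if batch_size is None:
        batch_size = spec.default_batch_size
    return epochs, batch_size


spec_defaults = resolve_training_args(StandInSpec())
overridden = resolve_training_args(StandInSpec(), epochs=10)
print(spec_defaults, overridden)
```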
This section covers advanced usage topics like adjusting the model and the training hyperparameters.
Customize the MobileBERT model hyperparameters
The model parameters you can adjust are:
- seq_len: Length of the sequence to feed into the model.
- initializer_range: The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
- trainable: Boolean that specifies whether the pre-trained layer is trainable.
The training pipeline parameters you can adjust are:
- model_dir: The location of the model checkpoint files. If not set, a temporary directory will be used.
- dropout_rate: The dropout rate.
- learning_rate: The initial learning rate for the Adam optimizer.
- tpu: TPU address to connect to.
For instance, you can set seq_len=256 (default is 128). This allows the model to classify longer text.
new_model_spec = model_spec.get('mobilebert_classifier')
new_model_spec.seq_len = 256
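The effect of seq_len can be pictured as fixing the number of token ids the model sees per example. The sketch below is a simplified illustration with a hypothetical helper, pad_or_truncate, and an assumed padding id of 0. The real preprocessing also adds special tokens and builds the mask inputs, which are omitted here.

```python
def pad_or_truncate(token_ids, seq_len, pad_id=0):
    """Return exactly seq_len ids: cut off extra tokens or pad the tail with pad_id."""
    clipped = token_ids[:seq_len]
    return clipped + [pad_id] * (seq_len - len(clipped))


review = list(range(1, 201))  # a hypothetical review of 200 token ids
short = pad_or_truncate(review, 128)
longer = pad_or_truncate(review, 256)

# With seq_len=128 the last 72 tokens are lost; with seq_len=256 all 200 are kept.
print(len(short), short[-1], len(longer), longer[199], longer[200])
```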
Customize the average word embedding model hyperparameters
You can adjust model settings such as the wordvec_dim and seq_len variables in the AverageWordVecModelSpec class.
For example, you can train the model with a larger value of wordvec_dim. Note that you must construct a new model_spec if you modify the model.
new_model_spec = model_spec.AverageWordVecModelSpec(wordvec_dim=32)
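To show what wordvec_dim controls, the sketch below averages per-token embedding vectors into a single sentence vector of length wordvec_dim. It is a simplified illustration of the averaging idea only: the tiny random embedding table, the token ids and the function name average_word_vectors are hypothetical, and the actual model learns its embeddings and applies further layers.

```python
import random


def average_word_vectors(token_ids, embedding_table):
    """Average the embedding vectors of the given tokens, dimension by dimension."""
    vectors = [embedding_table[token_id] for token_id in token_ids]
    dim = len(vectors[0])
    return [sum(vector[i] for vector in vectors) / len(vectors) for i in range(dim)]


wordvec_dim = 32
vocab_size = 10  # hypothetical tiny vocabulary
rng = random.Random(0)
embedding_table = [
    [rng.uniform(-1, 1) for _ in range(wordvec_dim)] for _ in range(vocab_size)
]

sentence_vector = average_word_vectors([1, 4, 7], embedding_table)
print(len(sentence_vector))
```

A larger wordvec_dim gives each word more numbers to describe it, at the cost of a larger embedding table.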
Get the preprocessed data.
new_train_data = TextClassifierDataLoader.from_csv(
filename='train.csv',
text_column='sentence',
label_column='label',
model_spec=new_model_spec,
is_training=True)
Train the new model.
model = text_classifier.create(new_train_data, model_spec=new_model_spec)
Epoch 1/3
2104/2104 [==============================] - 8s 4ms/step - loss: 0.6846 - accuracy: 0.5581
Epoch 2/3
2104/2104 [==============================] - 6s 3ms/step - loss: 0.5710 - accuracy: 0.7086
Epoch 3/3
2104/2104 [==============================] - 6s 3ms/step - loss: 0.4352 - accuracy: 0.8030
Tune the training hyperparameters
You can also tune training hyperparameters such as epochs and batch_size, which affect the model accuracy. For instance:
- epochs: more epochs could achieve better accuracy, but may lead to overfitting.
- batch_size: the number of samples to use in one training step.
For example, you can train with more epochs.
model = text_classifier.create(new_train_data, model_spec=new_model_spec, epochs=20)
Epoch 1/20
2104/2104 [==============================] - 7s 3ms/step - loss: 0.6847 - accuracy: 0.5557
Epoch 2/20
2104/2104 [==============================] - 7s 3ms/step - loss: 0.5682 - accuracy: 0.7170
Epoch 3/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.4318 - accuracy: 0.8025
Epoch 4/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3855 - accuracy: 0.8306
Epoch 5/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3628 - accuracy: 0.8459
Epoch 6/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3494 - accuracy: 0.8529
Epoch 7/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3380 - accuracy: 0.8582
Epoch 8/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3311 - accuracy: 0.8628
Epoch 9/20
2104/2104 [==============================] - 7s 3ms/step - loss: 0.3231 - accuracy: 0.8667
Epoch 10/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3184 - accuracy: 0.8688
Epoch 11/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3149 - accuracy: 0.8716
Epoch 12/20
2104/2104 [==============================] - 7s 3ms/step - loss: 0.3111 - accuracy: 0.8732
Epoch 13/20
2104/2104 [==============================] - 7s 3ms/step - loss: 0.3067 - accuracy: 0.8725
Epoch 14/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3028 - accuracy: 0.8753
Epoch 15/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.3014 - accuracy: 0.8759
Epoch 16/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.2984 - accuracy: 0.8776
Epoch 17/20
2104/2104 [==============================] - 7s 3ms/step - loss: 0.2968 - accuracy: 0.8793
Epoch 18/20
2104/2104 [==============================] - 7s 3ms/step - loss: 0.2936 - accuracy: 0.8803
Epoch 19/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.2925 - accuracy: 0.8801
Epoch 20/20
2104/2104 [==============================] - 6s 3ms/step - loss: 0.2913 - accuracy: 0.8828
Evaluate the newly retrained model with 20 training epochs.
new_test_data = TextClassifierDataLoader.from_csv(
filename='dev.csv',
text_column='sentence',
label_column='label',
model_spec=new_model_spec,
is_training=False)
loss, accuracy = model.evaluate(new_test_data)
28/28 [==============================] - 0s 2ms/step - loss: 0.4962 - accuracy: 0.8337
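One way to read these results against the overfitting remark is to compare the gap between final training accuracy and test accuracy for the 10-epoch quickstart run and this 20-epoch run. The numbers below are copied from the logs in this tutorial. Keep in mind that the two runs also differ in wordvec_dim, so the comparison is only indicative.

```python
runs = {
    # epochs: (final training accuracy, test accuracy), copied from the logs
    10: (0.8662, 0.8303),
    20: (0.8828, 0.8337),
}

# A growing gap between training and test accuracy is a common sign of overfitting.
gaps = {epochs: round(train - test, 4) for epochs, (train, test) in runs.items()}
for epochs, gap in gaps.items():
    print(epochs, gap)
```

Training accuracy rose by about 1.7 points between the two runs while test accuracy rose by only about 0.3 points, so most of the extra training benefit did not carry over to unseen data.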
Change the Model Architecture
You can change the model by changing the model_spec. The following shows how to switch to the BERT-Base model.
Change the model_spec to the BERT-Base model for the text classifier.
spec = model_spec.get('bert_classifier')
The remaining steps are the same.