Text classification with TensorFlow Lite Model Maker

The TensorFlow Lite Model Maker library simplifies the process of adapting and converting a TensorFlow model to particular input data when deploying this model for on-device ML applications.

This notebook shows an end-to-end example that uses the Model Maker library to adapt and convert a commonly used text classification model to classify movie reviews on a mobile device. The text classification model classifies text into predefined categories. The input is preprocessed text and the output is the probability of each category. The dataset used in this tutorial consists of positive and negative movie reviews.

Prerequisites

Install the required packages

To run this example, install the required packages, including the Model Maker package from PyPI.

pip install tflite-model-maker

Import the required packages.

import numpy as np
import os

import tensorflow as tf
assert tf.__version__.startswith('2')

from tflite_model_maker import configs
from tflite_model_maker import ExportFormat
from tflite_model_maker import model_spec
from tflite_model_maker import text_classifier
from tflite_model_maker import TextClassifierDataLoader

Get the data path

Download the dataset for this tutorial.

data_dir = tf.keras.utils.get_file(
      fname='SST-2.zip',
      origin='https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8',
      extract=True)
data_dir = os.path.join(os.path.dirname(data_dir), 'SST-2')
Downloading data from https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8
7446528/7439277 [==============================] - 0s 0us/step

You can also upload your own dataset to work through this tutorial. Upload your dataset by using the left sidebar in Colab.

If you prefer not to upload your dataset to the cloud, you can also locally run the library by following the guide.

End-to-End Workflow

This workflow consists of five steps as outlined below:

Step 1. Choose a model specification that represents a text classification model.

This tutorial uses MobileBERT as an example.

spec = model_spec.get('mobilebert_classifier')

Step 2. Load train and test data specific to an on-device ML app and preprocess the data according to a specific model_spec.

train_data = TextClassifierDataLoader.from_csv(
      filename=os.path.join(data_dir, 'train.tsv'),
      text_column='sentence',
      label_column='label',
      model_spec=spec,
      delimiter='\t',
      is_training=True)
test_data = TextClassifierDataLoader.from_csv(
      filename=os.path.join(data_dir, 'dev.tsv'),
      text_column='sentence',
      label_column='label',
      model_spec=spec,
      delimiter='\t',
      is_training=False)

Step 3. Customize the TensorFlow model.

model = text_classifier.create(train_data, model_spec=spec)

Step 4. Evaluate the model.

loss, acc = model.evaluate(test_data)

Step 5. Export as a TensorFlow Lite model with metadata.

Since MobileBERT is too big for on-device applications, use dynamic range quantization on the model to compress it by almost 4x with minimal performance degradation.

config = configs.QuantizationConfig.create_dynamic_range_quantization(optimizations=[tf.lite.Optimize.OPTIMIZE_FOR_LATENCY])
config._experimental_new_quantizer = True
model.export(export_dir='mobilebert/', quantization_config=config)
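The "almost 4x" figure follows from storing weights as 8-bit integers instead of 32-bit floats. The sketch below is a toy illustration of that idea only, assuming a simple symmetric per-tensor scheme; the weight values are made up, and the converter's actual scheme may differ in detail.

```python
from array import array

# Hypothetical float32 weights for a single layer (illustrative values).
weights = array("f", [0.82, -1.27, 0.03, 0.49, -0.66, 1.10, -0.01, 0.25])

# Symmetric per-tensor quantization: map [-max_abs, max_abs] onto [-127, 127].
max_abs = max(abs(w) for w in weights)
scale = max_abs / 127.0
quantized = array("b", [round(w / scale) for w in weights])

# Dequantize to measure how much precision was lost.
restored = [q * scale for q in quantized]
max_error = max(abs(w - r) for w, r in zip(weights, restored))

# 4 bytes per float32 weight versus 1 byte per int8 weight.
ratio = (weights.itemsize * len(weights)) / (quantized.itemsize * len(quantized))
print(f"size ratio: {ratio:.1f}x, max abs error: {max_error:.4f}")
```

The weights themselves shrink by exactly 4x here. A real model file also carries data that is not quantized this way (for example the scale factors), which is presumably why the overall saving is slightly lower.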

You can also download the model using the left sidebar in Colab.

After completing the five steps above, you can use the TensorFlow Lite model file in on-device applications through the BertNLClassifier API in the TensorFlow Lite Task Library.

The following sections walk through the example step by step to show more detail.

Choose a model_spec that Represents a Model for Text Classifier

Each model_spec object represents a specific model for the text classifier. TensorFlow Lite Model Maker currently supports MobileBERT, averaging word embeddings and BERT-Base models.

Supported Model | Name of model_spec | Model Description
MobileBERT | 'mobilebert_classifier' | 4.3x smaller and 5.5x faster than BERT-Base while achieving competitive results, suitable for on-device applications.
BERT-Base | 'bert_classifier' | Standard BERT model that is widely used in NLP tasks.
Averaging word embedding | 'average_word_vec' | Averaging text word embeddings with ReLU activation.

The rest of this tutorial uses a smaller model, average_word_vec, which trains quickly enough that you can retrain it multiple times to demonstrate the process.

spec = model_spec.get('average_word_vec')
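To make "averaging word embeddings with ReLU activation" concrete, here is a minimal pure-Python sketch of that kind of classifier. It is not the library's implementation: the vocabulary, layer sizes, and helper names are hypothetical, and the weights are random rather than trained, so the printed probabilities carry no meaning.

```python
import math
import random

rng = random.Random(0)

# Hypothetical toy vocabulary and layer sizes.
vocab = {"<PAD>": 0, "<UNK>": 1, "charming": 2, "bleak": 3, "journey": 4}
embed_dim, hidden_dim, num_classes = 4, 3, 2

def rand_matrix(rows, cols):
    return [[rng.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]

embeddings = rand_matrix(len(vocab), embed_dim)   # one vector per vocabulary entry
w_hidden = rand_matrix(embed_dim, hidden_dim)
w_out = rand_matrix(hidden_dim, num_classes)

def matvec(vec, matrix):
    # Multiply a length-n vector by an n x m matrix, giving a length-m vector.
    return [sum(v * row[j] for v, row in zip(vec, matrix))
            for j in range(len(matrix[0]))]

def classify(tokens):
    ids = [vocab.get(token, vocab["<UNK>"]) for token in tokens]
    vectors = [embeddings[i] for i in ids]
    # Average the word vectors element-wise across all tokens.
    averaged = [sum(column) / len(vectors) for column in zip(*vectors)]
    # Dense layer followed by ReLU.
    hidden = [max(0.0, h) for h in matvec(averaged, w_hidden)]
    # Output layer followed by softmax to get class probabilities.
    logits = matvec(hidden, w_out)
    peak = max(logits)
    exps = [math.exp(logit - peak) for logit in logits]
    total = sum(exps)
    return [e / total for e in exps]

probs = classify("a charming journey".split())
print(probs)
```

Because the sentence is reduced to a single averaged vector, word order is ignored, which is part of what keeps this kind of model small and fast to retrain.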

Load Input Data Specific to an On-device ML App

The SST-2 (Stanford Sentiment Treebank) is one of the tasks in the GLUE benchmark. It contains 67,349 movie reviews for training and 872 movie reviews for validation. The dataset has two classes: positive and negative movie reviews.

Download the archived version of the dataset and extract it.

data_dir = tf.keras.utils.get_file(
      fname='SST-2.zip',
      origin='https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8',
      extract=True)
data_dir = os.path.join(os.path.dirname(data_dir), 'SST-2')

The SST-2 dataset has train.tsv for training and dev.tsv for validation. The files have the following format:

sentence | label
it 's a charming and often affecting journey . | 1
unflinchingly bleak and desperate | 0

A positive review is labeled 1 and a negative review is labeled 0.
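As a quick check of this format, the sketch below parses the two sample rows with Python's csv module. The in-memory string stands in for the first lines of train.tsv, and the variable names are illustrative only.

```python
import csv
import io

# In-memory stand-in for the start of train.tsv: tab-separated, with a header row.
sample_tsv = (
    "sentence\tlabel\n"
    "it 's a charming and often affecting journey .\t1\n"
    "unflinchingly bleak and desperate\t0\n"
)

reader = csv.DictReader(io.StringIO(sample_tsv), delimiter="\t")
rows = [(row["sentence"].strip(), int(row["label"])) for row in reader]

for sentence, label in rows:
    print("positive" if label == 1 else "negative", "|", sentence)
```

The column names 'sentence' and 'label' are the same ones passed as text_column and label_column when loading the data.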

Use the TextClassifierDataLoader.from_csv method to load the data.

train_data = TextClassifierDataLoader.from_csv(
      filename=os.path.join(data_dir, 'train.tsv'),
      text_column='sentence',
      label_column='label',
      model_spec=spec,
      delimiter='\t',
      is_training=True)
test_data = TextClassifierDataLoader.from_csv(
      filename=os.path.join(data_dir, 'dev.tsv'),
      text_column='sentence',
      label_column='label',
      model_spec=spec,
      delimiter='\t',
      is_training=False)

The Model Maker library also supports the from_folder() method to load data. It assumes that text data of the same class is stored in the same subdirectory and that the subfolder name is the class name. Each text file contains one movie review sample. The class_labels parameter specifies which subfolders to load.
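The folder layout described above can be sketched with the standard library. This example does not call from_folder(); it only builds a directory tree of the assumed shape in a temporary location and reads it back. The class names, file names, and review texts are hypothetical.

```python
import tempfile
from pathlib import Path

# Hypothetical class names and review texts, used only to illustrate the layout.
samples = {
    "pos": ["a charming and often affecting journey"],
    "neg": ["unflinchingly bleak and desperate"],
}

root = Path(tempfile.mkdtemp())
for class_name, reviews in samples.items():
    class_dir = root / class_name  # the subfolder name doubles as the class name
    class_dir.mkdir()
    for i, review in enumerate(reviews):
        # One review per text file.
        (class_dir / f"review_{i}.txt").write_text(review, encoding="utf-8")

# Recover (text, label) pairs by walking the tree, as a folder-based loader might.
pairs = sorted(
    (path.read_text(encoding="utf-8"), path.parent.name)
    for path in root.glob("*/*.txt")
)
class_names = sorted({label for _, label in pairs})
print(class_names)
```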

Customize the TensorFlow Model

Create a custom text classifier model based on the loaded data.

model = text_classifier.create(train_data, model_spec=spec, epochs=10)

Examine the detailed model structure.

model.summary()

Evaluate the Customized Model

Evaluate the model with the test data and get its loss and accuracy.

loss, acc = model.evaluate(test_data)
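For intuition about the two numbers returned, the sketch below computes accuracy and an average cross-entropy loss by hand. The predicted probabilities and labels are made up, and cross-entropy is an assumption about the loss used; the library's exact loss configuration is not shown in this tutorial.

```python
import math

# Hypothetical predicted probabilities [P(negative), P(positive)] and true labels.
predictions = [[0.1, 0.9], [0.8, 0.2], [0.4, 0.6], [0.3, 0.7]]
labels = [1, 0, 0, 1]

# Accuracy: fraction of examples whose most probable class matches the label.
correct = sum(
    1 for probs, label in zip(predictions, labels)
    if probs.index(max(probs)) == label
)
accuracy = correct / len(labels)

# Cross-entropy: average negative log-probability assigned to the true class.
loss = -sum(
    math.log(probs[label]) for probs, label in zip(predictions, labels)
) / len(labels)

print(f"loss={loss:.4f}, accuracy={accuracy:.2f}")
```

Accuracy only counts whether the top class is right, while the loss also penalizes low confidence in the correct class, so the two can move somewhat independently.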

Export as a TensorFlow Lite Model

Convert the existing model to TensorFlow Lite model format with metadata that you can later use in an on-device ML application. The label file and the vocab file are embedded in metadata. The default TFLite filename is model.tflite.

model.export(export_dir='average_word_vec/')

The TensorFlow Lite model file can be used in the text classification reference app through the NLClassifier API in the TensorFlow Lite Task Library.

The allowed export formats can be one or a list of the following:

  • ExportFormat.TFLITE
  • ExportFormat.LABEL
  • ExportFormat.VOCAB
  • ExportFormat.SAVED_MODEL

By default, export writes only the TensorFlow Lite model with metadata. You can also selectively export different files. For instance, export only the label file and the vocab file as follows:

model.export(export_dir='average_word_vec/', export_format=[ExportFormat.LABEL, ExportFormat.VOCAB])

You can evaluate the TFLite model with the evaluate_tflite method to get its accuracy.

accuracy = model.evaluate_tflite('average_word_vec/model.tflite', test_data)

Advanced Usage

The create function is the driver function that the Model Maker library uses to create models. The model_spec parameter defines the model specification. The AverageWordVecModelSpec and BertClassifierModelSpec classes are currently supported. The create function comprises the following steps:

  1. Creates the model for the text classifier according to model_spec.
  2. Trains the classifier model. The default epochs and the default batch size are set by the default_training_epochs and default_batch_size variables in the model_spec object.
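The fallback to spec-level defaults in step 2 can be pictured as follows. This is an illustration of the pattern, not the library's source: ToySpec and resolve_training_args are hypothetical names, and the default values are made up.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass
class ToySpec:
    """Hypothetical stand-in for a model_spec object; the values are made up."""
    default_training_epochs: int = 3
    default_batch_size: int = 32

def resolve_training_args(
    spec: ToySpec,
    epochs: Optional[int] = None,
    batch_size: Optional[int] = None,
) -> Tuple[int, int]:
    # Use the caller's value when given, otherwise fall back to the spec default.
    if epochs is None:
        epochs = spec.default_training_epochs
    if batch_size is None:
        batch_size = spec.default_batch_size
    return epochs, batch_size

print(resolve_training_args(ToySpec()))             # spec defaults
print(resolve_training_args(ToySpec(), epochs=10))  # explicit epochs, default batch size
```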

This section covers advanced usage topics like adjusting the model and the training hyperparameters.

Adjust the model

You can adjust model architecture parameters such as the wordvec_dim and seq_len variables in the AverageWordVecModelSpec class.

For example, you can train the model with a larger value of wordvec_dim. Note that you must construct a new model_spec if you modify the model.

new_model_spec = model_spec.AverageWordVecModelSpec(wordvec_dim=32)

Get the preprocessed data.

new_train_data = TextClassifierDataLoader.from_csv(
      filename=os.path.join(data_dir, 'train.tsv'),
      text_column='sentence',
      label_column='label',
      model_spec=new_model_spec,
      delimiter='\t',
      is_training=True)

Train the new model.

model = text_classifier.create(new_train_data, model_spec=new_model_spec)

You can also adjust the MobileBERT model.

The model parameters you can adjust are:

  • seq_len: Length of the sequence to feed into the model.
  • initializer_range: The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  • trainable: Boolean that specifies whether the pre-trained layer is trainable.
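To illustrate what initializer_range controls, the sketch below draws weights from a truncated normal distribution by rejection sampling: values more than two standard deviations from the mean are discarded and redrawn. The helper name, the matrix shape, and the value 0.02 are assumptions chosen for illustration.

```python
import random

def truncated_normal(stddev, mean=0.0, rng=random):
    """Draw one sample, redrawing until it lies within two standard deviations."""
    while True:
        x = rng.gauss(mean, stddev)
        if abs(x - mean) <= 2 * stddev:
            return x

rng = random.Random(0)
initializer_range = 0.02  # assumed value, for illustration only

# A small 3 x 4 weight matrix initialized from the truncated distribution.
weights = [
    [truncated_normal(initializer_range, rng=rng) for _ in range(4)]
    for _ in range(3)
]
for row in weights:
    print([round(w, 4) for w in row])
```

A larger initializer_range spreads the initial weights more widely; the truncation keeps any single weight from starting far out in the tail.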

The training pipeline parameters you can adjust are:

  • model_dir: The location of the model checkpoint files. If not set, a temporary directory will be used.
  • dropout_rate: The dropout rate.
  • learning_rate: The initial learning rate for the Adam optimizer.
  • tpu: TPU address to connect to.

For instance, you can set the seq_len=256 (default is 128). This allows the model to classify longer text.

new_model_spec = model_spec.get('mobilebert_classifier')
new_model_spec.seq_len = 256
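The effect of seq_len on longer text can be sketched with a simple pad-or-truncate helper. This is a simplified illustration: pad_or_truncate is a hypothetical function, the token IDs are made up, and real BERT preprocessing also reserves positions for special tokens, which is ignored here.

```python
def pad_or_truncate(token_ids, seq_len, pad_id=0):
    """Return exactly seq_len IDs: cut off the excess or right-pad with pad_id."""
    clipped = token_ids[:seq_len]
    return clipped + [pad_id] * (seq_len - len(clipped))

# A hypothetical review that tokenizes to 200 IDs (all non-zero).
long_review = list(range(1, 201))

at_128 = pad_or_truncate(long_review, 128)
at_256 = pad_or_truncate(long_review, 256)

kept_128 = sum(1 for token in at_128 if token != 0)
kept_256 = sum(1 for token in at_256 if token != 0)
print(kept_128, kept_256)
```

With seq_len=128 the last 72 tokens of this review are dropped, while seq_len=256 keeps all 200 and pads the rest. The trade-off is that longer sequences cost more memory and compute per example.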

Tune the training hyperparameters

You can also tune the training hyperparameters like epochs and batch_size that affect the model accuracy. For instance,

  • epochs: more epochs could achieve better accuracy, but may lead to overfitting.
  • batch_size: the number of samples to use in one training step.
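As a rough guide to how these two settings interact, the sketch below counts training steps for the 67,349 SST-2 training examples. The batch sizes are illustrative, and the calculation assumes a partial final batch is kept; a loader that drops it would run one fewer step per epoch.

```python
import math

num_train_examples = 67_349  # SST-2 training set size quoted earlier in this tutorial

def steps_per_epoch(num_examples, batch_size):
    # One step consumes one batch; the last batch may be smaller than batch_size.
    return math.ceil(num_examples / batch_size)

epochs = 20
for batch_size in (16, 32, 64):
    steps = steps_per_epoch(num_train_examples, batch_size)
    print(f"batch_size={batch_size}: {steps} steps/epoch, {steps * epochs} steps total")
```

Doubling the batch size roughly halves the number of weight updates per epoch, so epochs and batch_size are best tuned together rather than in isolation.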

For example, you can train with more epochs.

model = text_classifier.create(train_data, model_spec=spec, epochs=20)

Evaluate the newly retrained model with 20 training epochs.

loss, accuracy = model.evaluate(test_data)

Change the Model Architecture

You can change the model by changing the model_spec. For example, switch the text classifier to the BERT-Base model:

spec = model_spec.get('bert_classifier')

The remaining steps are the same.