Object Detection with TensorFlow Lite Model Maker


In this colab notebook, you'll learn how to use the TensorFlow Lite Model Maker library to train a custom object detection model capable of detecting salads within images on a mobile device.

The Model Maker library uses transfer learning to simplify the process of training a TensorFlow Lite model on a custom dataset. Retraining a TensorFlow Lite model with your own custom dataset reduces the amount of training data required and shortens the training time.

You'll use the publicly available Salads dataset, which was created from the Open Images Dataset V4.

Each image in the dataset contains objects labeled as one of the following classes:

  • Baked Goods
  • Cheese
  • Salad
  • Seafood
  • Tomato

The dataset contains bounding boxes specifying where each object is located, together with the object's label.

Here is an example image from the dataset:


Prerequisites

Install the required packages

Start by installing the required packages, including the Model Maker package and the pycocotools library you'll use for evaluation.

sudo apt -y install libportaudio2
pip install -q --use-deprecated=legacy-resolver tflite-model-maker
pip install -q pycocotools
pip install -q opencv-python-headless==4.1.2.30
pip uninstall -y tensorflow && pip install -q tensorflow==2.8.0

Import the required packages.

import numpy as np
import os

from tflite_model_maker.config import QuantizationConfig
from tflite_model_maker.config import ExportFormat
from tflite_model_maker import model_spec
from tflite_model_maker import object_detector

import tensorflow as tf
assert tf.__version__.startswith('2')

tf.get_logger().setLevel('ERROR')
from absl import logging
logging.set_verbosity(logging.ERROR)

Prepare the dataset

Here you'll use the same dataset as the AutoML quickstart.

The Salads dataset is available at: gs://cloud-ml-data/img/openimage/csv/salads_ml_use.csv.

It contains 175 images for training, 25 images for validation, and 25 images for testing. The dataset has five classes: Salad, Seafood, Tomato, Baked Goods, Cheese.


The dataset is provided in CSV format:

TRAINING,gs://cloud-ml-data/img/openimage/3/2520/3916261642_0a504acd60_o.jpg,Salad,0.0,0.0954,,,0.977,0.957,,
VALIDATION,gs://cloud-ml-data/img/openimage/3/2520/3916261642_0a504acd60_o.jpg,Seafood,0.0154,0.1538,,,1.0,0.802,,
TEST,gs://cloud-ml-data/img/openimage/3/2520/3916261642_0a504acd60_o.jpg,Tomato,0.0,0.655,,,0.231,0.839,,
  • Each row corresponds to an object localized inside a larger image, with each object specifically designated as test, train, or validation data. You'll learn more about what that means in a later stage in this notebook.
  • The three lines included here indicate three distinct objects located inside the same image available at gs://cloud-ml-data/img/openimage/3/2520/3916261642_0a504acd60_o.jpg.
  • Each row has a different label: Salad, Seafood, Tomato, etc.
  • Bounding boxes are specified for each image using the top left and bottom right vertices.
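
As a rough illustration of the row layout described above, the sketch below parses the three sample rows with Python's standard csv module. The helper parse_row and the column positions are assumptions read off the sample rows (top-left vertex in columns 3-4, bottom-right vertex in columns 7-8, empty columns ignored); it is not how Model Maker itself reads the file.

```python
import csv
import io

# The three sample rows quoted above.
SAMPLE_ROWS = """\
TRAINING,gs://cloud-ml-data/img/openimage/3/2520/3916261642_0a504acd60_o.jpg,Salad,0.0,0.0954,,,0.977,0.957,,
VALIDATION,gs://cloud-ml-data/img/openimage/3/2520/3916261642_0a504acd60_o.jpg,Seafood,0.0154,0.1538,,,1.0,0.802,,
TEST,gs://cloud-ml-data/img/openimage/3/2520/3916261642_0a504acd60_o.jpg,Tomato,0.0,0.655,,,0.231,0.839,,
"""


def parse_row(row):
    """Turn one CSV row into a dict (hypothetical helper, not a Model Maker API)."""
    return {
        "split": row[0],   # TRAINING, VALIDATION, or TEST
        "image": row[1],   # path or gs:// URL of the image
        "label": row[2],   # class name of the object
        # Assumed layout: top-left vertex, two unused columns, bottom-right vertex.
        "x_min": float(row[3]),
        "y_min": float(row[4]),
        "x_max": float(row[7]),
        "y_max": float(row[8]),
    }


annotations = [parse_row(r) for r in csv.reader(io.StringIO(SAMPLE_ROWS))]
for a in annotations:
    print(a["split"], a["label"], (a["x_min"], a["y_min"]), (a["x_max"], a["y_max"]))
```

All three parsed rows point at the same image, which matches the description of three distinct objects inside one picture.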

Here is a visualization of these three lines:


To learn more about how to prepare your own CSV file and the minimum requirements for creating a valid dataset, see the Preparing your training data guide.

If you are new to Google Cloud, you may wonder what the gs:// URLs mean. They are URLs of files stored on Google Cloud Storage (GCS). If you make your files on GCS public or authenticate your client, Model Maker can read those files just as it reads your local files.

However, you don't need to keep your images on Google Cloud to use Model Maker. You can use a local path in your CSV file and Model Maker will just work.

Quickstart

There are six steps to training an object detection model:

Step 1. Choose an object detection model architecture.

This tutorial uses the EfficientDet-Lite0 model. EfficientDet-Lite[0-4] are a family of mobile/IoT-friendly object detection models derived from the EfficientDet architecture.

Here is how the EfficientDet-Lite models compare to each other in performance.

Model architecture    Size (MB)*    Latency (ms)**    Average Precision***
EfficientDet-Lite0    4.4           37                25.69%
EfficientDet-Lite1    5.8           49                30.55%
EfficientDet-Lite2    7.2           69                33.97%
EfficientDet-Lite3    11.4          116               37.70%
EfficientDet-Lite4    19.9          260               41.96%

* Size of the integer quantized models.
** Latency measured on Pixel 4 using 4 threads on CPU.
*** Average Precision is the mAP (mean Average Precision) on the COCO 2017 validation dataset.

spec = model_spec.get('efficientdet_lite0')

Step 2. Load the dataset.

Model Maker takes input data in the CSV format. Use the object_detector.DataLoader.from_csv method to load the dataset and split it into training, validation, and test images.

  • Training images: These images are used to train the object detection model to recognize salad ingredients.
  • Validation images: These are images that the model didn't see during the training process. You'll use them to decide when you should stop the training, to avoid overfitting.
  • Test images: These images are used to evaluate the final model performance.

You can load the CSV file directly from Google Cloud Storage, but you don't need to keep your images on Google Cloud to use Model Maker. You can specify a local CSV file on your computer, and Model Maker will work just fine.

train_data, validation_data, test_data = object_detector.DataLoader.from_csv('gs://cloud-ml-data/img/openimage/csv/salads_ml_use.csv')

Step 3. Train the TensorFlow model with the training data.

  • The EfficientDet-Lite0 model uses epochs = 50 by default, which means it will go through the training dataset 50 times. You can look at the validation accuracy during training and stop early to avoid overfitting.
  • Set batch_size = 8 here so you will see that it takes 21 steps to go through the 175 images in the training dataset.
  • Set train_whole_model=True to fine-tune the whole model instead of just training the head layer to improve accuracy. The trade-off is that it may take longer to train the model.
model = object_detector.create(train_data, model_spec=spec, batch_size=8, train_whole_model=True, validation_data=validation_data)
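
The step count mentioned for batch_size = 8 can be checked with integer division. This sketch assumes the incomplete final batch is dropped, which is consistent with the 21 steps quoted above; steps_per_epoch is a hypothetical helper, not a Model Maker function.

```python
def steps_per_epoch(num_images, batch_size):
    """Number of full batches per epoch, assuming the partial last batch is dropped."""
    return num_images // batch_size


train_steps = steps_per_epoch(175, 8)   # 175 training images, batch size 8
total_steps = train_steps * 50          # 50 epochs by default

print(train_steps, total_steps)
```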

Step 4. Evaluate the model with the test data.

After training the object detection model using the images in the training dataset, use the remaining 25 images in the test dataset to evaluate how the model performs against new data it has never seen before.

As the default batch size is 64, it will take 1 step to go through the 25 images in the test dataset.

The evaluation metrics are the same as those used by COCO.

model.evaluate(test_data)

Step 5. Export as a TensorFlow Lite model.

Export the trained object detection model to the TensorFlow Lite format by specifying which folder you want to export the quantized model to. The default post-training quantization technique is full integer quantization.

model.export(export_dir='.')

Step 6. Evaluate the TensorFlow Lite model.

Several factors can affect the model accuracy when exporting to TFLite:

  • Quantization shrinks the model size by a factor of 4 at the expense of some accuracy.
  • The original TensorFlow model uses per-class non-max suppression (NMS) for post-processing, while the TFLite model uses global NMS, which is much faster but less accurate. In addition, the Keras model outputs at most 100 detections, while the TFLite model outputs at most 25.
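
To make the difference between per-class and global NMS concrete, here is a minimal pure-Python sketch of greedy NMS. It is a generic textbook version, not the implementation used by either the Keras or the TFLite model; the function names, the IoU threshold of 0.5, and the sample detections are all assumptions chosen for illustration.

```python
def iou(a, b):
    """Intersection over union of two (x_min, y_min, x_max, y_max) boxes."""
    ix = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    iy = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = ix * iy
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def global_nms(dets, iou_threshold=0.5):
    """Greedy NMS over (box, score, label) tuples, ignoring the labels."""
    kept = []
    for det in sorted(dets, key=lambda d: d[1], reverse=True):
        # Keep a detection only if it does not overlap a higher-scoring kept box.
        if all(iou(det[0], k[0]) <= iou_threshold for k in kept):
            kept.append(det)
    return kept


def per_class_nms(dets, iou_threshold=0.5):
    """Run the same greedy NMS separately for each label."""
    kept = []
    for label in sorted({d[2] for d in dets}):
        kept += global_nms([d for d in dets if d[2] == label], iou_threshold)
    return kept


# Made-up detections: two near-identical Salad boxes and an overlapping Tomato box.
detections = [
    ((0.10, 0.10, 0.60, 0.60), 0.90, "Salad"),
    ((0.12, 0.12, 0.62, 0.62), 0.80, "Tomato"),
    ((0.11, 0.10, 0.61, 0.60), 0.70, "Salad"),
]

print([d[2] for d in global_nms(detections)])
print([d[2] for d in per_class_nms(detections)])
```

In this example the global variant keeps only the highest-scoring box and discards the overlapping Tomato detection, while the per-class variant keeps one box per class. That is one way a single global pass can cost accuracy in exchange for doing less work.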

Therefore you'll have to evaluate the exported TFLite model and compare its accuracy with the original TensorFlow model.

model.evaluate_tflite('model.tflite', test_data)

You can download the TensorFlow Lite model file using the left sidebar of Colab. Right-click on the model.tflite file and choose Download to download it to your local computer.

This model can be integrated into an Android or an iOS app using the ObjectDetector API of the TensorFlow Lite Task Library.

See the TFLite Object Detection sample app for more details on how the model is used in a working app.

(Optional) Test the TFLite model on your image

You can test the trained TFLite model using images from the internet.

  • Replace the INPUT_IMAGE_URL below with your desired input image.
  • Adjust the DETECTION_THRESHOLD to change the sensitivity of the model. A lower threshold means the model will pick up more objects, but there will also be more false detections. A higher threshold means the model will only pick up objects that it has confidently detected.
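
The effect of the threshold can be sketched as a simple filter over detection scores. The value 0.3, the helper filter_detections, and the sample scores below are assumptions for illustration only; they are not output from the trained model.

```python
DETECTION_THRESHOLD = 0.3  # illustrative value, not a recommendation


def filter_detections(detections, threshold):
    """Keep only detections whose confidence score reaches the threshold."""
    return [d for d in detections if d["score"] >= threshold]


# Made-up detections with confidence scores in [0, 1].
raw_detections = [
    {"label": "Salad", "score": 0.91},
    {"label": "Tomato", "score": 0.42},
    {"label": "Cheese", "score": 0.18},
]

for threshold in (0.1, DETECTION_THRESHOLD, 0.5):
    kept = filter_detections(raw_detections, threshold)
    print(threshold, [d["label"] for d in kept])
```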

Although running the model in Python currently requires some boilerplate code, integrating the model into a mobile app only requires a few lines of code.

Load the trained TFLite model and define some visualization functions

Run object detection and show the detection results

(Optional) Compile For the Edge TPU

Now that you have a quantized EfficientDet-Lite model, you can compile it and deploy it to a Coral Edge TPU.

Step 1. Install the EdgeTPU Compiler

 curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -

 echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | sudo tee /etc/apt/sources.list.d/coral-edgetpu.list

 sudo apt-get update

 sudo apt-get install edgetpu-compiler

Step 2. Select number of Edge TPUs, Compile

The Edge TPU has 8MB of SRAM for caching model parameters. This means that for models larger than 8MB, inference time increases because model parameters have to be transferred over during inference. One way to avoid this is model pipelining: splitting the model into segments that can each have a dedicated Edge TPU. This can significantly improve latency.

The table below can be used as a reference for the number of Edge TPUs to use. The larger models will not compile for a single TPU because the intermediate tensors can't fit in on-chip memory.

Model architecture    Minimum TPUs    Recommended TPUs
EfficientDet-Lite0    1               1
EfficientDet-Lite1    1               1
EfficientDet-Lite2    1               2
EfficientDet-Lite3    2               2
EfficientDet-Lite4    2               3
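
If you script the compilation, the table can be encoded as a small lookup. The sketch below is only one possible way to do that; choose_tpu_count is a hypothetical helper, and the policy of using the recommended count when enough devices are available is an assumption.

```python
# (minimum, recommended) Edge TPU counts copied from the table above.
EDGE_TPU_COUNTS = {
    "EfficientDet-Lite0": (1, 1),
    "EfficientDet-Lite1": (1, 1),
    "EfficientDet-Lite2": (1, 2),
    "EfficientDet-Lite3": (2, 2),
    "EfficientDet-Lite4": (2, 3),
}


def choose_tpu_count(model_name, available):
    """Pick how many Edge TPUs to use, given how many are available."""
    minimum, recommended = EDGE_TPU_COUNTS[model_name]
    if available < minimum:
        raise ValueError(f"{model_name} needs at least {minimum} Edge TPU(s)")
    # Use the recommended count when possible, otherwise whatever is available.
    return min(available, recommended)


print(choose_tpu_count("EfficientDet-Lite4", available=3))
print(choose_tpu_count("EfficientDet-Lite2", available=1))
```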

Step 3. Download, Run Model

Once compiled, the model(s) can be run on Edge TPU(s) for object detection. First, download the compiled TensorFlow Lite model file using the left sidebar of Colab. Right-click on the model_edgetpu.tflite file and choose Download to download it to your local computer.

Now you can run the model in your preferred manner. Examples of detection include:

Advanced Usage

This section covers advanced usage topics like adjusting the model and the training hyperparameters.

Load the dataset

Load your own data

You can upload your own dataset to work through this tutorial. Upload your dataset by using the left sidebar in Colab.


If you prefer not to upload your dataset to the cloud, you can also run the library locally by following the guide.

Load your data with a different data format

The Model Maker library also supports the object_detector.DataLoader.from_pascal_voc method to load data in the PASCAL VOC format. makesense.ai and LabelImg are tools that can annotate images and save the annotations as XML files in the PASCAL VOC data format:

object_detector.DataLoader.from_pascal_voc(image_dir, annotations_dir, label_map={1: "person", 2: "notperson"})
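
For readers who have not seen a PASCAL VOC annotation file, the sketch below parses a small, made-up XML annotation with Python's standard library. The file name, image size, and box coordinates are invented for illustration, and parse_voc is a hypothetical helper, not part of Model Maker. The labels reuse those from the label_map above.

```python
import xml.etree.ElementTree as ET

# A made-up annotation in the usual PASCAL VOC layout (pixel coordinates).
SAMPLE_XML = """\
<annotation>
  <filename>example.jpg</filename>
  <size><width>640</width><height>480</height><depth>3</depth></size>
  <object>
    <name>person</name>
    <bndbox><xmin>48</xmin><ymin>240</ymin><xmax>195</xmax><ymax>371</ymax></bndbox>
  </object>
  <object>
    <name>notperson</name>
    <bndbox><xmin>300</xmin><ymin>100</ymin><xmax>420</xmax><ymax>180</ymax></bndbox>
  </object>
</annotation>
"""


def parse_voc(xml_text):
    """Return (filename, objects) from one PASCAL VOC annotation string."""
    root = ET.fromstring(xml_text)
    objects = []
    for obj in root.iter("object"):
        box = obj.find("bndbox")
        objects.append({
            "label": obj.findtext("name"),
            "box": tuple(int(box.findtext(tag))
                         for tag in ("xmin", "ymin", "xmax", "ymax")),
        })
    return root.findtext("filename"), objects


filename, objects = parse_voc(SAMPLE_XML)
print(filename, objects)
```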

Customize the EfficientDet model hyperparameters

The model and training pipeline parameters you can adjust are:

  • model_dir: The location to save the model checkpoint files. If not set, a temporary directory will be used.
  • steps_per_execution: Number of steps per training execution.
  • moving_average_decay: Float. The decay to use for maintaining moving averages of the trained parameters.
  • var_freeze_expr: A regular expression matched against the prefix of variable names to select the variables to freeze, that is, to keep unchanged during training. More specifically, the codebase uses re.match(var_freeze_expr, variable_name) to select the variables to be frozen.
  • tflite_max_detections: integer, 25 by default. The max number of output detections in the TFLite model.
  • strategy: A string specifying which distribution strategy to use. Accepted values are 'tpu', 'gpus', and None. 'tpu' means to use TPUStrategy. 'gpus' means to use MirroredStrategy for multiple GPUs. If None, the TF default with OneDeviceStrategy is used.
  • tpu: The Cloud TPU to use for training. This should be either the name used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 url.
  • use_xla: Use XLA even if strategy is not tpu. If strategy is tpu, always use XLA, and this flag has no effect.
  • profile: Enable profile mode.
  • debug: Enable debug mode.

Other parameters that can be adjusted are shown in hparams_config.py.

For instance, you can set the var_freeze_expr='efficientnet' which freezes the variables with name prefix efficientnet (default is '(efficientnet|fpn_cells|resample_p6)'). This allows the model to freeze untrainable variables and keep their value the same through training.

spec = model_spec.get('efficientdet_lite0')
spec.config.var_freeze_expr = 'efficientnet'
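
Because the variables are selected with re.match, the expression is anchored at the start of each variable name. The sketch below shows which names the default and the custom expression would select. The variable names are made up for illustration and are not taken from the actual EfficientDet-Lite checkpoint; frozen is a hypothetical helper.

```python
import re

DEFAULT_EXPR = "(efficientnet|fpn_cells|resample_p6)"
CUSTOM_EXPR = "efficientnet"

# Hypothetical variable names, invented for this example.
variable_names = [
    "efficientnet-lite0/stem/conv2d/kernel",
    "fpn_cells/cell_0/fnode0/conv/kernel",
    "resample_p6/conv2d/kernel",
    "class_net/class-predict/kernel",
    "box_net/box-predict/kernel",
]


def frozen(expr, names):
    """Names whose prefix matches expr, i.e. the variables that would be frozen."""
    return [n for n in names if re.match(expr, n)]


print(frozen(DEFAULT_EXPR, variable_names))
print(frozen(CUSTOM_EXPR, variable_names))
```

With the narrower expression, only the names starting with efficientnet are selected, so in this toy list the fpn_cells and resample_p6 variables would become trainable.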

Change the Model Architecture

You can change the model architecture by changing the model_spec. For instance, change the model_spec to the EfficientDet-Lite4 model.

spec = model_spec.get('efficientdet_lite4')

Tune the training hyperparameters

The create function is the driver function that the Model Maker library uses to create models. The model_spec parameter defines the model specification. Currently, the object_detector.EfficientDetSpec class is supported. The create function comprises the following steps:

  1. Creates the model for the object detection according to model_spec.
  2. Trains the model. The default epochs and the default batch size are set by the epochs and batch_size variables in the model_spec object. You can also tune training hyperparameters such as epochs and batch_size, which affect the model accuracy. For instance:
  • epochs: Integer, 50 by default. More epochs could achieve better accuracy, but may lead to overfitting.
  • batch_size: Integer, 64 by default. The number of samples to use in one training step.
  • train_whole_model: Boolean, False by default. If true, train the whole model. Otherwise, only train the layers that do not match var_freeze_expr.

For example, you can train with fewer epochs and only the head layer. You can increase the number of epochs for better results.

model = object_detector.create(train_data, model_spec=spec, epochs=10, validation_data=validation_data)

Export to different formats

The export format can be one of the following, or a list of them:

  • ExportFormat.TFLITE
  • ExportFormat.LABEL
  • ExportFormat.SAVED_MODEL

By default, it exports only the TensorFlow Lite model file, which contains the model metadata, so that you can later use it in an on-device ML application. The label file is embedded in the metadata.

In many on-device ML applications, the model size is an important factor. Therefore, it is recommended that you quantize the model to make it smaller and potentially faster. For EfficientDet-Lite models, full integer quantization is used by default. Refer to Post-training quantization for more detail.

model.export(export_dir='.')

You can also choose to export other files related to the model for better examination. For instance, exporting both the saved model and the label file as follows:

model.export(export_dir='.', export_format=[ExportFormat.SAVED_MODEL, ExportFormat.LABEL])

Customize Post-training quantization on the TensorFlow Lite model

Post-training quantization is a conversion technique that can reduce model size and inference latency, while also improving CPU and hardware accelerator inference speed, with little degradation in model accuracy. For this reason, it is widely used to optimize models.
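
To give a feel for where the size reduction and the small accuracy loss come from, here is a minimal pure-Python sketch of affine int8 quantization on a handful of made-up weights. It is a generic illustration under simplifying assumptions (a single scale and zero point for all values), not the algorithm the TensorFlow Lite converter applies.

```python
import struct

weights = [-0.51, -0.2, 0.0, 0.13, 0.37, 0.76]  # made-up float weights

# Affine int8 quantization: real_value is approximated by scale * (q - zero_point).
lo, hi = min(weights), max(weights)
scale = (hi - lo) / 255
zero_point = round(-128 - lo / scale)


def quantize(x):
    """Map a float to an int8 value, clamped to [-128, 127]."""
    return max(-128, min(127, round(x / scale) + zero_point))


def dequantize(q):
    """Map an int8 value back to an approximate float."""
    return scale * (q - zero_point)


quantized = [quantize(w) for w in weights]
restored = [dequantize(q) for q in quantized]
max_error = max(abs(w - r) for w, r in zip(weights, restored))

# Storage per value: 4 bytes (float32), 2 bytes (float16), 1 byte (int8).
float32_bytes = struct.calcsize("f") * len(weights)
float16_bytes = struct.calcsize("e") * len(weights)
int8_bytes = struct.calcsize("b") * len(weights)

print(quantized)
print(max_error, scale / 2)
print(float32_bytes, float16_bytes, int8_bytes)
```

Each value shrinks from 4 bytes to 1 byte with int8, or to 2 bytes with float16, and the rounding error per weight stays within half a quantization step in this example.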

The Model Maker library applies a default post-training quantization technique when exporting the model. If you want to customize post-training quantization, Model Maker also supports multiple post-training quantization options through QuantizationConfig. Take float16 quantization as an example. First, define the quantization config.

config = QuantizationConfig.for_float16()

Then export the TensorFlow Lite model with that configuration.

model.export(export_dir='.', tflite_filename='model_fp16.tflite', quantization_config=config)

Read more

You can read our object detection example to learn technical details. For more information, please refer to: