TensorFlow Lite Metadata Writer API

TensorFlow Lite Model Metadata is a standard model description format. It contains rich semantics for general model information, inputs/outputs, and associated files, which makes the model self-descriptive and exchangeable.

Model Metadata is currently used in the following two primary use cases:

  1. Enable easy model inference using the TensorFlow Lite Task Library and codegen tools. Model Metadata contains the mandatory information required during inference, such as the label file in image classification, the sampling rate of the audio input in audio classification, and the tokenizer type for processing the input string in natural language models.

  2. Enable model creators to include documentation, such as descriptions of model inputs/outputs or instructions on how to use the model. Model users can view this documentation via visualization tools such as Netron.

TensorFlow Lite Metadata Writer API provides an easy-to-use API to create Model Metadata for popular ML tasks supported by the TFLite Task Library. This notebook shows examples of how to populate the metadata for the following tasks:

  * Image classifiers
  * Object detectors
  * Image segmenters
  * Natural language classifiers
  * Audio classifiers

Metadata writers for BERT natural language classifiers and BERT question answerers are coming soon.

If you want to add metadata for use cases that are not supported, use the Flatbuffers Python API instead. See the tutorial here.

Prerequisites

Install the TensorFlow Lite Support PyPI package.

pip install tflite-support-nightly
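
If the package installed correctly, the metadata writers should be importable. A quick sanity check (not part of the official steps):

python -c "from tflite_support import metadata_writers; print('tflite-support is ready')"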

Create Model Metadata for Task Library and Codegen

Image classifiers

See the image classifier model compatibility requirements for more details about the supported model format.

Step 1: Import the required packages.

from tflite_support.metadata_writers import image_classifier
from tflite_support.metadata_writers import writer_utils

Step 2: Download the example image classifier, mobilenet_v2_1.0_224.tflite, and the label file.

curl -L https://github.com/tensorflow/tflite-support/raw/master/tensorflow_lite_support/metadata/python/tests/testdata/image_classifier/mobilenet_v2_1.0_224.tflite -o mobilenet_v2_1.0_224.tflite
curl -L https://github.com/tensorflow/tflite-support/raw/master/tensorflow_lite_support/metadata/python/tests/testdata/image_classifier/labels.txt -o mobilenet_labels.txt
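
The label file is plain text with one class name per line, ordered to match the model's output indices. The first few lines of mobilenet_labels.txt look roughly like this (illustrative excerpt):

background
tench
goldfish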

Step 3: Create metadata writer and populate.

ImageClassifierWriter = image_classifier.MetadataWriter
_MODEL_PATH = "mobilenet_v2_1.0_224.tflite"
# Task Library expects label files that are in the same format as the one below.
_LABEL_FILE = "mobilenet_labels.txt"
_SAVE_TO_PATH = "mobilenet_v2_1.0_224_metadata.tflite"
# Normalization parameters are required when preprocessing the image. They are
# optional if the image pixel values are in the range of [0, 255] and the input
# tensor is quantized to uint8. See the introduction to normalization and
# quantization parameters below for more details:
# https://www.tensorflow.org/lite/models/convert/metadata#normalization_and_quantization_parameters
_INPUT_NORM_MEAN = 127.5
_INPUT_NORM_STD = 127.5

# Create the metadata writer.
writer = ImageClassifierWriter.create_for_inference(
    writer_utils.load_file(_MODEL_PATH), [_INPUT_NORM_MEAN], [_INPUT_NORM_STD],
    [_LABEL_FILE])

# Verify the metadata generated by metadata writer.
print(writer.get_metadata_json())

# Populate the metadata into the model.
writer_utils.save_file(writer.populate(), _SAVE_TO_PATH)
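
Once populated, the model can be consumed directly by the Task Library. Below is a minimal sketch using the Task Library's Python vision bindings; the test image "test.jpg" is a hypothetical local file, not part of this tutorial:

from tflite_support.task import vision

# Load the metadata-populated model into an image classifier.
classifier = vision.ImageClassifier.create_from_file(_SAVE_TO_PATH)
# Wrap a local image file (hypothetical "test.jpg") as a TensorImage.
image = vision.TensorImage.create_from_file("test.jpg")
# Classify; the labels come from the label file packed into the metadata.
print(classifier.classify(image))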

Object detectors

See the object detector model compatibility requirements for more details about the supported model format.

Step 1: Import the required packages.

from tflite_support.metadata_writers import object_detector
from tflite_support.metadata_writers import writer_utils

Step 2: Download the example object detector, ssd_mobilenet_v1.tflite, and the label file.

curl -L https://github.com/tensorflow/tflite-support/raw/master/tensorflow_lite_support/metadata/python/tests/testdata/object_detector/ssd_mobilenet_v1.tflite -o ssd_mobilenet_v1.tflite
curl -L https://github.com/tensorflow/tflite-support/raw/master/tensorflow_lite_support/metadata/python/tests/testdata/object_detector/labelmap.txt -o ssd_mobilenet_labels.txt

Step 3: Create metadata writer and populate.

ObjectDetectorWriter = object_detector.MetadataWriter
_MODEL_PATH = "ssd_mobilenet_v1.tflite"
# Task Library expects label files that are in the same format as the one below.
_LABEL_FILE = "ssd_mobilenet_labels.txt"
_SAVE_TO_PATH = "ssd_mobilenet_v1_metadata.tflite"
# Normalization parameters are required when preprocessing the image. They are
# optional if the image pixel values are in the range of [0, 255] and the input
# tensor is quantized to uint8. See the introduction to normalization and
# quantization parameters below for more details:
# https://www.tensorflow.org/lite/models/convert/metadata#normalization_and_quantization_parameters
_INPUT_NORM_MEAN = 127.5
_INPUT_NORM_STD = 127.5

# Create the metadata writer.
writer = ObjectDetectorWriter.create_for_inference(
    writer_utils.load_file(_MODEL_PATH), [_INPUT_NORM_MEAN], [_INPUT_NORM_STD],
    [_LABEL_FILE])

# Verify the metadata generated by metadata writer.
print(writer.get_metadata_json())

# Populate the metadata into the model.
writer_utils.save_file(writer.populate(), _SAVE_TO_PATH)
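
As with the image classifier above, the populated model plugs straight into the Task Library. A minimal sketch with the Python bindings, again assuming a hypothetical local "test.jpg":

from tflite_support.task import vision

# Load the metadata-populated model into an object detector.
detector = vision.ObjectDetector.create_from_file(_SAVE_TO_PATH)
image = vision.TensorImage.create_from_file("test.jpg")
# Each detection carries a bounding box and labels from the packed label file.
print(detector.detect(image))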

Image segmenters

See the image segmenter model compatibility requirements for more details about the supported model format.

Step 1: Import the required packages.

from tflite_support.metadata_writers import image_segmenter
from tflite_support.metadata_writers import writer_utils

Step 2: Download the example image segmenter, deeplabv3.tflite, and the label file.

curl -L https://github.com/tensorflow/tflite-support/raw/master/tensorflow_lite_support/metadata/python/tests/testdata/image_segmenter/deeplabv3.tflite -o deeplabv3.tflite
curl -L https://github.com/tensorflow/tflite-support/raw/master/tensorflow_lite_support/metadata/python/tests/testdata/image_segmenter/labelmap.txt -o deeplabv3_labels.txt

Step 3: Create metadata writer and populate.

ImageSegmenterWriter = image_segmenter.MetadataWriter
_MODEL_PATH = "deeplabv3.tflite"
# Task Library expects label files that are in the same format as the one below.
_LABEL_FILE = "deeplabv3_labels.txt"
_SAVE_TO_PATH = "deeplabv3_metadata.tflite"
# Normalization parameters are required when preprocessing the image. They are
# optional if the image pixel values are in the range of [0, 255] and the input
# tensor is quantized to uint8. See the introduction to normalization and
# quantization parameters below for more details:
# https://www.tensorflow.org/lite/models/convert/metadata#normalization_and_quantization_parameters
_INPUT_NORM_MEAN = 127.5
_INPUT_NORM_STD = 127.5

# Create the metadata writer.
writer = ImageSegmenterWriter.create_for_inference(
    writer_utils.load_file(_MODEL_PATH), [_INPUT_NORM_MEAN], [_INPUT_NORM_STD],
    [_LABEL_FILE])

# Verify the metadata generated by metadata writer.
print(writer.get_metadata_json())

# Populate the metadata into the model.
writer_utils.save_file(writer.populate(), _SAVE_TO_PATH)
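
The same Task Library pattern applies to segmentation. A minimal sketch, assuming the Python bindings and a hypothetical local "test.jpg":

from tflite_support.task import vision

# Load the metadata-populated model into an image segmenter.
segmenter = vision.ImageSegmenter.create_from_file(_SAVE_TO_PATH)
image = vision.TensorImage.create_from_file("test.jpg")
# The result contains per-pixel category masks plus the packed labels.
print(segmenter.segment(image))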

Natural language classifiers

See the natural language classifier model compatibility requirements for more details about the supported model format.

Step 1: Import the required packages.

from tflite_support.metadata_writers import nl_classifier
from tflite_support.metadata_writers import metadata_info
from tflite_support.metadata_writers import writer_utils

Step 2: Download the example natural language classifier, movie_review.tflite, the label file, and the vocab file.

curl -L https://github.com/tensorflow/tflite-support/raw/master/tensorflow_lite_support/metadata/python/tests/testdata/nl_classifier/movie_review.tflite -o movie_review.tflite
curl -L https://github.com/tensorflow/tflite-support/raw/master/tensorflow_lite_support/metadata/python/tests/testdata/nl_classifier/labels.txt -o movie_review_labels.txt
curl -L https://storage.googleapis.com/download.tensorflow.org/models/tflite_support/nl_classifier/vocab.txt -o movie_review_vocab.txt

Step 3: Create metadata writer and populate.

NLClassifierWriter = nl_classifier.MetadataWriter
_MODEL_PATH = "movie_review.tflite"
# Task Library expects label files and vocab files that are in the same formats
# as the ones below.
_LABEL_FILE = "movie_review_labels.txt"
_VOCAB_FILE = "movie_review_vocab.txt"
# NLClassifier supports tokenizing the input string with a regex tokenizer. See
# more details about how to set up RegexTokenizer below:
# https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/metadata/python/metadata_writers/metadata_info.py#L130
_DELIM_REGEX_PATTERN = r"[^\w\']+"
_SAVE_TO_PATH = "movie_review_metadata.tflite"

# Create the metadata writer.
writer = NLClassifierWriter.create_for_inference(
    writer_utils.load_file(_MODEL_PATH),
    metadata_info.RegexTokenizerMd(_DELIM_REGEX_PATTERN, _VOCAB_FILE),
    [_LABEL_FILE])

# Verify the metadata generated by metadata writer.
print(writer.get_metadata_json())

# Populate the metadata into the model.
writer_utils.save_file(writer.populate(), _SAVE_TO_PATH)
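
To see what the delimiter pattern above does, you can try it with Python's re module. RegexTokenizer splits the input on runs of characters matching the pattern, so word characters and apostrophes stay together (a small illustrative sketch):

import re

_DELIM_REGEX_PATTERN = r"[^\w\']+"
# Splitting keeps contractions like "isn't" intact.
print(re.split(_DELIM_REGEX_PATTERN, "This isn't a bad movie."))
# ['This', "isn't", 'a', 'bad', 'movie', '']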

Audio classifiers

See the audio classifier model compatibility requirements for more details about the supported model format.

Step 1: Import the required packages.

from tflite_support.metadata_writers import audio_classifier
from tflite_support.metadata_writers import metadata_info
from tflite_support.metadata_writers import writer_utils

Step 2: Download the example audio classifier, yamnet.tflite, and the label file.

curl -L https://github.com/tensorflow/tflite-support/raw/master/tensorflow_lite_support/metadata/python/tests/testdata/audio_classifier/yamnet_wavin_quantized_mel_relu6.tflite -o yamnet.tflite
curl -L https://github.com/tensorflow/tflite-support/raw/master/tensorflow_lite_support/metadata/python/tests/testdata/audio_classifier/yamnet_521_labels.txt -o yamnet_labels.txt

Step 3: Create metadata writer and populate.

AudioClassifierWriter = audio_classifier.MetadataWriter
_MODEL_PATH = "yamnet.tflite"
# Task Library expects label files that are in the same format as the one below.
_LABEL_FILE = "yamnet_labels.txt"
# Expected sampling rate of the input audio buffer.
_SAMPLE_RATE = 16000
# Expected number of channels of the input audio buffer. Note that the Task
# Library only supports single-channel audio so far.
_CHANNELS = 1
_SAVE_TO_PATH = "yamnet_metadata.tflite"

# Create the metadata writer.
writer = AudioClassifierWriter.create_for_inference(
    writer_utils.load_file(_MODEL_PATH), _SAMPLE_RATE, _CHANNELS, [_LABEL_FILE])

# Verify the metadata generated by metadata writer.
print(writer.get_metadata_json())

# Populate the metadata into the model.
writer_utils.save_file(writer.populate(), _SAVE_TO_PATH)

Create Model Metadata with semantic information

You can fill in more descriptive information about the model and each tensor through the Metadata Writer API to help improve model understanding. This is done through the 'create_from_metadata_info' method in each metadata writer. In general, you fill in data through the parameters of 'create_from_metadata_info', i.e. general_md, input_md, and output_md. See the example below, which creates rich Model Metadata for an image classifier.

Step 1: Import the required packages.

from tflite_support.metadata_writers import image_classifier
from tflite_support.metadata_writers import metadata_info
from tflite_support.metadata_writers import writer_utils
from tflite_support import metadata_schema_py_generated as _metadata_fb

Step 2: Download the example image classifier, mobilenet_v2_1.0_224.tflite, and the label file.

curl -L https://github.com/tensorflow/tflite-support/raw/master/tensorflow_lite_support/metadata/python/tests/testdata/image_classifier/mobilenet_v2_1.0_224.tflite -o mobilenet_v2_1.0_224.tflite
curl -L https://github.com/tensorflow/tflite-support/raw/master/tensorflow_lite_support/metadata/python/tests/testdata/image_classifier/labels.txt -o mobilenet_labels.txt

Step 3: Create model and tensor information.

model_buffer = writer_utils.load_file("mobilenet_v2_1.0_224.tflite")

# Create general model information.
general_md = metadata_info.GeneralMd(
    name="ImageClassifier",
    version="v1",
    description=("Identify the most prominent object in the image from a "
                 "known set of categories."),
    author="TensorFlow Lite",
    licenses="Apache License. Version 2.0")

# Create input tensor information.
input_md = metadata_info.InputImageTensorMd(
    name="input image",
    description=("Input image to be classified. The expected image is "
                 "128 x 128, with three channels (red, blue, and green) per "
                 "pixel. Each element in the tensor is a value between min and "
                 "max, where (per-channel) min is [0] and max is [255]."),
    norm_mean=[127.5],
    norm_std=[127.5],
    color_space_type=_metadata_fb.ColorSpaceType.RGB,
    tensor_type=writer_utils.get_input_tensor_types(model_buffer)[0])

# Create output tensor information.
output_md = metadata_info.ClassificationTensorMd(
    name="probability",
    description="Probabilities of the 1001 labels respectively.",
    label_files=[
        metadata_info.LabelFileMd(file_path="mobilenet_labels.txt",
                                  locale="en")
    ],
    tensor_type=writer_utils.get_output_tensor_types(model_buffer)[0])

Step 4: Create metadata writer and populate.

ImageClassifierWriter = image_classifier.MetadataWriter
_SAVE_TO_PATH = "mobilenet_v2_1.0_224_metadata.tflite"

# Create the metadata writer.
writer = ImageClassifierWriter.create_from_metadata_info(
    model_buffer, general_md, input_md, output_md)

# Verify the metadata generated by metadata writer.
print(writer.get_metadata_json())

# Populate the metadata into the model.
writer_utils.save_file(writer.populate(), _SAVE_TO_PATH)

Read the metadata populated into your model

You can display the metadata and associated files in a TFLite model through the following code:

from tflite_support import metadata

displayer = metadata.MetadataDisplayer.with_model_file("mobilenet_v2_1.0_224_metadata.tflite")
print("Metadata populated:")
print(displayer.get_metadata_json())

print("Associated file(s) populated:")
for file_name in displayer.get_packed_associated_file_list():
  print("file name: ", file_name)
  print("file content:")
  print(displayer.get_associated_file_buffer(file_name))
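
Since get_associated_file_buffer returns the raw bytes of each packed file, you can also write them back out to disk; a small convenience sketch (the "exported_" prefix is just an illustrative choice):

# Export each packed associated file next to the notebook.
for file_name in displayer.get_packed_associated_file_list():
  with open("exported_" + file_name, "wb") as f:
    f.write(displayer.get_associated_file_buffer(file_name))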