Converting TensorFlow Text operators to TensorFlow Lite


Overview

Machine learning models are frequently deployed to mobile, embedded, and IoT devices with TensorFlow Lite to improve data privacy and lower response times. These models often require support for text-processing operations. TensorFlow Text version 2.7 and higher provides improved performance, reduced binary sizes, and operations specifically optimized for these environments.

Text operators

A subset of TensorFlow Text classes and functions can be used from within a TensorFlow Lite model; this guide demonstrates the WhitespaceTokenizer.

Model example

pip install -U "tensorflow-text==2.11.*"
import numpy as np
import tensorflow as tf
import tensorflow_text as tf_text

from tensorflow.lite.python import interpreter

The following code example shows the conversion process and inference in Python using a simple test model. Note that the output of a model cannot be a tf.RaggedTensor object when you are using TensorFlow Lite. However, you can return the components of a tf.RaggedTensor object, or convert it to a dense tensor with its to_tensor method, as shown in the variant after the model definition below. See the RaggedTensor guide for more details.

class TokenizerModel(tf.keras.Model):

  def __init__(self, **kwargs):
    super().__init__(**kwargs)
    # WhitespaceTokenizer is one of the TensorFlow Text operators with a
    # TensorFlow Lite custom op implementation.
    self.tokenizer = tf_text.WhitespaceTokenizer()

  @tf.function(input_signature=[
      tf.TensorSpec(shape=[None], dtype=tf.string, name='input')
  ])
  def call(self, input_tensor):
    # Return the flat values of the ragged tokenization result, since a
    # tf.RaggedTensor cannot be returned directly by a TFLite model.
    return { 'tokens': self.tokenizer.tokenize(input_tensor).flat_values }
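
If a dense output is preferred over the flattened token values, the call method could instead pad the ragged result with to_tensor. This is a minimal sketch of a hypothetical variant, not used in the rest of this example:

  @tf.function(input_signature=[
      tf.TensorSpec(shape=[None], dtype=tf.string, name='input')
  ])
  def call(self, input_tensor):
    # tokenize() returns a tf.RaggedTensor; to_tensor() pads it into a
    # rectangular dense tensor that TensorFlow Lite can return directly.
    tokens = self.tokenizer.tokenize(input_tensor)
    return { 'tokens': tokens.to_tensor(default_value='') }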

# Test input data.
input_data = np.array(['Some minds are better kept apart'])

# Define a Keras model.
model = TokenizerModel()

# Perform TensorFlow Text inference.
tf_result = model(tf.constant(input_data))
print('TensorFlow result = ', tf_result['tokens'])
TensorFlow result =  tf.Tensor([b'Some' b'minds' b'are' b'better' b'kept' b'apart'], shape=(6,), dtype=string)

Convert the TensorFlow model to TensorFlow Lite

When converting a TensorFlow model with TensorFlow Text operators to TensorFlow Lite, you need to indicate to the TFLiteConverter that the model uses custom operators by setting the allow_custom_ops attribute, as in the example below. You can then run the model conversion as you normally would. Review the TensorFlow Lite converter documentation for a detailed guide on the basics of model conversion.

# Convert to TensorFlow Lite.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
converter.allow_custom_ops = True
tflite_model = converter.convert()
2024-02-25 12:36:16.933848: W tensorflow/compiler/mlir/lite/flatbuffer_export.cc:2057] The following operation(s) need TFLite custom op implementation(s):
Custom ops: TFText>WhitespaceTokenizeWithOffsetsV2
Details:
    tf.TFText>WhitespaceTokenizeWithOffsetsV2(tensor<?x!tf_type.string>, tensor<!tf_type.string>) -> (tensor<?x!tf_type.string>, tensor<?xi64>, tensor<?xi32>, tensor<?xi32>) : {device = ""}
See instructions: https://www.tensorflow.org/lite/guide/ops_custom
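
To deploy the converted model to a device, you typically write the serialized flatbuffer to disk. A minimal sketch, where the tokenizer.tflite filename is just illustrative:

# Write the serialized TFLite flatbuffer to disk for on-device deployment.
# ('tokenizer.tflite' is a hypothetical filename.)
with open('tokenizer.tflite', 'wb') as f:
  f.write(tflite_model)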

Inference

For the TensorFlow Lite interpreter to properly read a model containing TensorFlow Text operators, you must configure it to use these custom operators and provide registration methods for them. Use tf_text.tflite_registrar.SELECT_TFTEXT_OPS to provide the full suite of registration functions for the supported TensorFlow Text operators to InterpreterWithCustomOps.

Note that while the example below shows inference in Python, the steps are similar in other languages with some minor API translations and the need to build the tflite_registrar into your binary. See TensorFlow Lite Inference for more details.

# Perform TensorFlow Lite inference.
interp = interpreter.InterpreterWithCustomOps(
    model_content=tflite_model,
    custom_op_registerers=tf_text.tflite_registrar.SELECT_TFTEXT_OPS)
interp.get_signature_list()
{'serving_default': {'inputs': ['input'], 'outputs': ['tokens']}}

Next, the TensorFlow Lite interpreter is invoked with the input, producing a result that matches the TensorFlow result from above.

tokenize = interp.get_signature_runner('serving_default')
output = tokenize(input=input_data)
print('TensorFlow Lite result = ', output['tokens'])
TensorFlow Lite result =  [b'Some' b'minds' b'are' b'better' b'kept' b'apart']
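
As a quick sanity check, not part of the original example, you can assert that the two results match element for element:

# Compare the TensorFlow and TensorFlow Lite tokenizations.
assert np.array_equal(tf_result['tokens'].numpy(), output['tokens'])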