TensorFlow Lite supports carrying a TensorFlow model's input/output specifications over to the converted TensorFlow Lite model. These input/output specifications are called "signatures", and they can be specified when building a SavedModel or when creating concrete functions.
Signatures in TensorFlow Lite provide the following features:
- They specify the inputs and outputs of the converted TensorFlow Lite model, respecting the original TensorFlow model's signatures.
- They allow a single TensorFlow Lite model to support multiple entry points.
A signature is composed of three pieces:
- Inputs: Map from the input name in the signature to an input tensor.
- Outputs: Map from the output name in the signature to an output tensor.
- Signature key: Name that identifies an entry point of the graph.
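Concretely, the interpreter reports these three pieces as a nested mapping from signature key to input and output tensor names. As a preview, get_signature_list() returns a structure like the following (the keys and tensor names match the example model built later in this guide):

# Shape of the mapping returned by tf.lite.Interpreter.get_signature_list()
# for a model with two entry points; names match the example model below.
signatures = {
    'encode': {'inputs': ['x'], 'outputs': ['encoded_result']},
    'decode': {'inputs': ['x'], 'outputs': ['decoded_result']},
}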
Setup
import tensorflow as tf
Example model
Suppose we have a TensorFlow model with two tasks, encoding and decoding:
class Model(tf.Module):

  @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
  def encode(self, x):
    result = tf.strings.as_string(x)
    return {
        "encoded_result": result
    }

  @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])
  def decode(self, x):
    result = tf.strings.to_number(x)
    return {
        "decoded_result": result
    }
In terms of signatures, the above TensorFlow model can be summarized as follows:
Signature
- Key: encode
- Inputs: {"x"}
- Outputs: {"encoded_result"}

Signature
- Key: decode
- Inputs: {"x"}
- Outputs: {"decoded_result"}
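Before converting, the two entry points can be exercised eagerly as a quick sanity check. This is an optional sketch, not part of the original walkthrough; the sample values are arbitrary:

model = Model()
# Encode floats to strings, then decode the strings back to floats.
encoded = model.encode(tf.constant([1.0, 2.0], dtype=tf.float32))
print(encoded)  # {'encoded_result': tf.Tensor([b'1.000000' b'2.000000'], ...)}
decoded = model.decode(encoded['encoded_result'])
print(decoded)  # {'decoded_result': tf.Tensor([1. 2.], ...)}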
Convert a model with Signatures
The TensorFlow Lite converter APIs carry the above signature information into the converted TensorFlow Lite model.
This conversion functionality is available in all of the converter APIs starting from TensorFlow version 2.7.0. See the example usages below.
From Saved Model
model = Model()

# Save the model
SAVED_MODEL_PATH = 'content/saved_models/coding'

tf.saved_model.save(
    model, SAVED_MODEL_PATH,
    signatures={
        'encode': model.encode.get_concrete_function(),
        'decode': model.decode.get_concrete_function()
    })

# Convert the saved model using TFLiteConverter
converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL_PATH)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
    tf.lite.OpsSet.SELECT_TF_OPS  # enable TensorFlow ops.
]
tflite_model = converter.convert()

# Print the signatures from the converted model
interpreter = tf.lite.Interpreter(model_content=tflite_model)
signatures = interpreter.get_signature_list()
print(signatures)
INFO:tensorflow:Assets written to: content/saved_models/coding/assets
{'decode': {'inputs': ['x'], 'outputs': ['decoded_result']}, 'encode': {'inputs': ['x'], 'outputs': ['encoded_result']} }
2023-09-02 11:15:25.547952: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:378] Ignored output_format.
2023-09-02 11:15:25.547984: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:381] Ignored drop_control_dependency.
2023-09-02 11:15:25.584283: W tensorflow/compiler/mlir/lite/flatbuffer_export.cc:2178] TFLite interpreter needs to link Flex delegate in order to run the model since it contains the following Select TFop(s):
Flex ops: FlexAsString, FlexStringToNumber
Details:
  tf.AsString(tensor<?xf32>) -> (tensor<?x!tf_type.string>) : {device = "", fill = "", precision = -1 : i64, scientific = false, shortest = false, width = -1 : i64}
  tf.StringToNumber(tensor<?x!tf_type.string>) -> (tensor<?xf32>) : {device = "", out_type = f32}
See instructions: https://www.tensorflow.org/lite/guide/ops_select
INFO: Created TensorFlow Lite delegate for select TF ops.
INFO: TfLiteFlexDelegate delegate: 1 nodes delegated out of 1 nodes with 1 partitions.
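A common follow-up step is to write the converted flatbuffer to disk so it can be loaded by path later. This is a minimal sketch; the file name below is an arbitrary assumption:

# Write the converted model to disk (path is illustrative).
with open('content/coding.tflite', 'wb') as f:
  f.write(tflite_model)

# The interpreter can load the model by path instead of by content.
interpreter = tf.lite.Interpreter(model_path='content/coding.tflite')
print(interpreter.get_signature_list())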
From Keras Model
# Generate a Keras model.
keras_model = tf.keras.Sequential(
    [
        tf.keras.layers.Dense(2, input_dim=4, activation='relu', name='x'),
        tf.keras.layers.Dense(1, activation='relu', name='output'),
    ]
)

# Convert the keras model using TFLiteConverter.
# Keras model converter API uses the default signature automatically.
converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
tflite_model = converter.convert()

# Print the signatures from the converted model
interpreter = tf.lite.Interpreter(model_content=tflite_model)
signatures = interpreter.get_signature_list()
print(signatures)
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmp_1tf3ntv/assets
{'serving_default': {'inputs': ['x_input'], 'outputs': ['output']} }
2023-09-02 11:15:26.085482: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:378] Ignored output_format.
2023-09-02 11:15:26.085512: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:381] Ignored drop_control_dependency.
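To verify the default signature end to end, it can be invoked through a signature runner. A minimal sketch, using the input name 'x_input' and output name 'output' reported above; the all-ones batch is an arbitrary sample:

# Run the default signature with a batch of one sample.
serving_default = interpreter.get_signature_runner('serving_default')
output = serving_default(x_input=tf.ones((1, 4), dtype=tf.float32))
print(output)  # {'output': array([[...]], dtype=float32)}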
From Concrete Functions
model = Model()

# Convert the concrete functions using TFLiteConverter
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [model.encode.get_concrete_function(),
     model.decode.get_concrete_function()], model)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
    tf.lite.OpsSet.SELECT_TF_OPS  # enable TensorFlow ops.
]
tflite_model = converter.convert()

# Print the signatures from the converted model
interpreter = tf.lite.Interpreter(model_content=tflite_model)
signatures = interpreter.get_signature_list()
print(signatures)
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpu1nd0d05/assets
{'decode': {'inputs': ['x'], 'outputs': ['decoded_result']}, 'encode': {'inputs': ['x'], 'outputs': ['encoded_result']} }
2023-09-02 11:15:26.259528: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:378] Ignored output_format.
2023-09-02 11:15:26.259557: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:381] Ignored drop_control_dependency.
2023-09-02 11:15:26.289823: W tensorflow/compiler/mlir/lite/flatbuffer_export.cc:2178] TFLite interpreter needs to link Flex delegate in order to run the model since it contains the following Select TFop(s):
Flex ops: FlexAsString, FlexStringToNumber
Details:
  tf.AsString(tensor<?xf32>) -> (tensor<?x!tf_type.string>) : {device = "", fill = "", precision = -1 : i64, scientific = false, shortest = false, width = -1 : i64}
  tf.StringToNumber(tensor<?x!tf_type.string>) -> (tensor<?xf32>) : {device = "", out_type = f32}
See instructions: https://www.tensorflow.org/lite/guide/ops_select
Run Signatures
TensorFlow Lite inference APIs support signature-based execution:
- Accessing the input/output tensors through the names of the inputs and outputs, as specified by the signature.
- Running each entry point of the graph separately, identified by its signature key.
- Support for the SavedModel's initialization procedure.
Java, C++, and Python language bindings are currently available. See the examples in the sections below.
Java
try (Interpreter interpreter = new Interpreter(file_of_tensorflowlite_model)) {
  // Run encoding signature.
  Map<String, Object> inputs = new HashMap<>();
  inputs.put("x", input);
  Map<String, Object> outputs = new HashMap<>();
  outputs.put("encoded_result", encoded_result);
  interpreter.runSignature(inputs, outputs, "encode");

  // Run decoding signature. Reassign fresh maps rather than
  // redeclaring `inputs`/`outputs` in the same scope.
  inputs = new HashMap<>();
  inputs.put("x", encoded_result);
  outputs = new HashMap<>();
  outputs.put("decoded_result", decoded_result);
  interpreter.runSignature(inputs, outputs, "decode");
}
C++
SignatureRunner* encode_runner =
    interpreter->GetSignatureRunner("encode");
encode_runner->ResizeInputTensor("x", {100});
encode_runner->AllocateTensors();

TfLiteTensor* input_tensor = encode_runner->input_tensor("x");
float* input = GetTensorData<float>(input_tensor);
// Fill `input`.

encode_runner->Invoke();

const TfLiteTensor* output_tensor = encode_runner->output_tensor(
    "encoded_result");
float* output = GetTensorData<float>(output_tensor);
// Access `output`.
Python
# Load the TFLite model in TFLite Interpreter
interpreter = tf.lite.Interpreter(model_content=tflite_model)
# Print the signatures from the converted model
signatures = interpreter.get_signature_list()
print('Signature:', signatures)
# encode and decode are callable with input as arguments.
encode = interpreter.get_signature_runner('encode')
decode = interpreter.get_signature_runner('decode')
# 'encoded' and 'decoded' are dictionaries with all outputs from the inference.
input = tf.constant([1, 2, 3], dtype=tf.float32)
print('Input:', input)
encoded = encode(x=input)
print('Encoded result:', encoded)
decoded = decode(x=encoded['encoded_result'])
print('Decoded result:', decoded)
Signature: {'decode': {'inputs': ['x'], 'outputs': ['decoded_result']}, 'encode': {'inputs': ['x'], 'outputs': ['encoded_result']} }
Input: tf.Tensor([1. 2. 3.], shape=(3,), dtype=float32)
Encoded result: {'encoded_result': array([b'1.000000', b'2.000000', b'3.000000'], dtype=object)}
Decoded result: {'decoded_result': array([1., 2., 3.], dtype=float32)}
Known limitations
- As the TFLite interpreter does not guarantee thread safety, signature runners from the same interpreter must not be executed concurrently; if you need parallel inference, give each thread its own interpreter, as sketched after this list.
- Support for C/iOS/Swift is not available yet.
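A minimal sketch of the one-interpreter-per-thread pattern mentioned above; the worker structure and sample inputs are illustrative assumptions, not part of the original guide:

import threading

def worker(model_bytes, x):
  # Each thread owns its own interpreter, so no signature runners
  # from a shared interpreter ever run concurrently.
  interpreter = tf.lite.Interpreter(model_content=model_bytes)
  encode = interpreter.get_signature_runner('encode')
  print(encode(x=x))

threads = [
    threading.Thread(target=worker,
                     args=(tflite_model, tf.constant([float(i)])))
    for i in range(2)
]
for t in threads:
  t.start()
for t in threads:
  t.join()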
Updates
- Version 2.7
- The multiple signature feature is implemented.
- All of the version 2 converter APIs generate signature-enabled TensorFlow Lite models.
- Version 2.5
- The signature feature is available through the from_saved_model converter API.