TensorFlow Lite의 서명

TensorFlow.org에서 보기 Google Colab에서 실행 GitHub에서 소스 보기 노트북 다운로드

TensorFlow Lite는 TensorFlow 모델의 입력/출력 사양을 TensorFlow Lite 모델로 변환하는 것을 지원합니다. 입/출력 사양을 "서명"이라고 합니다. SavedModel을 구축하거나 구체적인 기능을 생성할 때 서명을 지정할 수 있습니다.

TensorFlow Lite의 서명은 다음 기능을 제공합니다.

  • TensorFlow 모델의 서명을 적용하여 변환된 TensorFlow Lite 모델의 입력 및 출력을 지정합니다.
  • 단일 TensorFlow Lite 모델이 여러 진입점을 지원할 수 있습니다.

서명은 세 부분으로 구성됩니다.

  • 입력: 서명의 입력 이름에서 입력 텐서로의 입력에 대한 매핑입니다.
  • 출력: 서명의 출력 이름에서 출력 텐서로의 출력 매핑을 위한 맵입니다.
  • 서명 키: 그래프의 진입점을 식별하는 이름입니다.

설정

import tensorflow as tf
2022-12-15 01:06:27.563712: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2022-12-15 01:06:27.563811: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2022-12-15 01:06:27.563821: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.

예제 모델

TensorFlow 모델로 인코딩 및 디코딩과 같은 두 가지 작업이 있다고 가정해 보겠습니다.

class Model(tf.Module):

  @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
  def encode(self, x):
    result = tf.strings.as_string(x)
    return {
         "encoded_result": result
    }

  @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])
  def decode(self, x):
    result = tf.strings.to_number(x)
    return {
         "decoded_result": result
    }

서명 측면에서 위의 TensorFlow 모델은 다음과 같이 요약될 수 있습니다.

  • 서명

    • 키: 인코딩
    • 입력: {"x"}
    • 출력: {"encoded_result"}
  • 서명

    • 키: 디코딩
    • 입력: {"x"}
    • 출력: {"decoded_result"}

서명이 있는 모델 변환

TensorFlow Lite 변환기 API는 위의 서명 정보를 변환된 TensorFlow Lite 모델로 가져옵니다.

이 변환 기능은 TensorFlow 버전 2.7.0부터 모든 변환기 API에서 사용할 수 있습니다. 사용 예를 참조하세요.

SavedModel에서

model = Model()

# Save the model
SAVED_MODEL_PATH = 'content/saved_models/coding'

tf.saved_model.save(
    model, SAVED_MODEL_PATH,
    signatures={
      'encode': model.encode.get_concrete_function(),
      'decode': model.decode.get_concrete_function()
    })

# Convert the saved model using TFLiteConverter
converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL_PATH)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
    tf.lite.OpsSet.SELECT_TF_OPS  # enable TensorFlow ops.
]
tflite_model = converter.convert()

# Print the signatures from the converted model
interpreter = tf.lite.Interpreter(model_content=tflite_model)
signatures = interpreter.get_signature_list()
print(signatures)
INFO:tensorflow:Assets written to: content/saved_models/coding/assets
{'decode': {'inputs': ['x'], 'outputs': ['decoded_result']}, 'encode': {'inputs': ['x'], 'outputs': ['encoded_result']} }
2022-12-15 01:06:32.364466: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:362] Ignored output_format.
2022-12-15 01:06:32.364500: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:365] Ignored drop_control_dependency.
2022-12-15 01:06:32.409841: W tensorflow/compiler/mlir/lite/flatbuffer_export.cc:2046] TFLite interpreter needs to link Flex delegate in order to run the model since it contains the following Select TFop(s):
Flex ops: FlexAsString, FlexStringToNumber
Details:
    tf.AsString(tensor<?xf32>) -> (tensor<?x!tf_type.string>) : {device = "", fill = "", precision = -1 : i64, scientific = false, shortest = false, width = -1 : i64}
    tf.StringToNumber(tensor<?x!tf_type.string>) -> (tensor<?xf32>) : {device = "", out_type = f32}
See instructions: https://www.tensorflow.org/lite/guide/ops_select
INFO: Created TensorFlow Lite delegate for select TF ops.
INFO: TfLiteFlexDelegate delegate: 1 nodes delegated out of 1 nodes with 1 partitions.

Keras 모델에서

# Generate a Keras model.
keras_model = tf.keras.Sequential(
    [
        tf.keras.layers.Dense(2, input_dim=4, activation='relu', name='x'),
        tf.keras.layers.Dense(1, activation='relu', name='output'),
    ]
)

# Convert the keras model using TFLiteConverter.
# Keras model converter API uses the default signature automatically.
converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
tflite_model = converter.convert()

# Print the signatures from the converted model
interpreter = tf.lite.Interpreter(model_content=tflite_model)

signatures = interpreter.get_signature_list()
print(signatures)
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmp24we75_2/assets
{'serving_default': {'inputs': ['x_input'], 'outputs': ['output']} }
2022-12-15 01:06:33.126947: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:362] Ignored output_format.
2022-12-15 01:06:33.126990: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:365] Ignored drop_control_dependency.

구체적인 기능에서

model = Model()

# Convert the concrete functions using TFLiteConverter
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [model.encode.get_concrete_function(),
     model.decode.get_concrete_function()], model)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
    tf.lite.OpsSet.SELECT_TF_OPS  # enable TensorFlow ops.
]
tflite_model = converter.convert()

# Print the signatures from the converted model
interpreter = tf.lite.Interpreter(model_content=tflite_model)
signatures = interpreter.get_signature_list()
print(signatures)
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpl2ma5ilo/assets
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpl2ma5ilo/assets
{'decode': {'inputs': ['x'], 'outputs': ['decoded_result']}, 'encode': {'inputs': ['x'], 'outputs': ['encoded_result']} }
2022-12-15 01:06:33.323787: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:362] Ignored output_format.
2022-12-15 01:06:33.323827: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:365] Ignored drop_control_dependency.
2022-12-15 01:06:33.361670: W tensorflow/compiler/mlir/lite/flatbuffer_export.cc:2046] TFLite interpreter needs to link Flex delegate in order to run the model since it contains the following Select TFop(s):
Flex ops: FlexAsString, FlexStringToNumber
Details:
    tf.AsString(tensor<?xf32>) -> (tensor<?x!tf_type.string>) : {device = "", fill = "", precision = -1 : i64, scientific = false, shortest = false, width = -1 : i64}
    tf.StringToNumber(tensor<?x!tf_type.string>) -> (tensor<?xf32>) : {device = "", out_type = f32}
See instructions: https://www.tensorflow.org/lite/guide/ops_select

서명 실행

TensorFlow 추론 API는 서명 기반 실행을 지원합니다.

  • 서명으로 지정된 입력 및 출력 이름을 통해 입력/출력 텐서에 액세스합니다.
  • 서명 키로 식별되는 그래프의 각 진입점을 별도로 실행합니다.
  • SavedModel의 초기화 절차를 지원합니다.

Java, C++ 및 Python 언어 바인딩을 현재 사용할 수 있습니다. 아래 섹션의 예를 참조하세요.

Java

try (Interpreter interpreter = new Interpreter(file_of_tensorflowlite_model)) {
  // Run encoding signature.
  Map&lt;String, Object&gt; inputs = new HashMap&lt;&gt;();
  inputs.put("x", input);
  Map&lt;String, Object&gt; outputs = new HashMap&lt;&gt;();
  outputs.put("encoded_result", encoded_result);
  interpreter.runSignature(inputs, outputs, "encode");

  // Run decoding signature.
  Map&lt;String, Object&gt; inputs = new HashMap&lt;&gt;();
  inputs.put("x", encoded_result);
  Map&lt;String, Object&gt; outputs = new HashMap&lt;&gt;();
  outputs.put("decoded_result", decoded_result);
  interpreter.runSignature(inputs, outputs, "decode");
}

C++

SignatureRunner* encode_runner =
    interpreter-&gt;GetSignatureRunner("encode");
encode_runner-&gt;ResizeInputTensor("x", {100});
encode_runner-&gt;AllocateTensors();

TfLiteTensor* input_tensor = encode_runner-&gt;input_tensor("x");
float* input = input_tensor-&gt;data.f;
// Fill `input`.

encode_runner-&gt;Invoke();

const TfLiteTensor* output_tensor = encode_runner-&gt;output_tensor(
    "encoded_result");
float* output = output_tensor-&gt;data.f;
// Access `output`.

Python

# Load the TFLite model in TFLite Interpreter
interpreter = tf.lite.Interpreter(model_content=tflite_model)

# Print the signatures from the converted model
signatures = interpreter.get_signature_list()
print('Signature:', signatures)

# encode and decode are callable with input as arguments.
encode = interpreter.get_signature_runner('encode')
decode = interpreter.get_signature_runner('decode')

# 'encoded' and 'decoded' are dictionaries with all outputs from the inference.
input = tf.constant([1, 2, 3], dtype=tf.float32)
print('Input:', input)
encoded = encode(x=input)
print('Encoded result:', encoded)
decoded = decode(x=encoded['encoded_result'])
print('Decoded result:', decoded)
Signature: {'decode': {'inputs': ['x'], 'outputs': ['decoded_result']}, 'encode': {'inputs': ['x'], 'outputs': ['encoded_result']} }
Input: tf.Tensor([1. 2. 3.], shape=(3,), dtype=float32)
Encoded result: {'encoded_result': array([b'1.000000', b'2.000000', b'3.000000'], dtype=object)}
Decoded result: {'decoded_result': array([1., 2., 3.], dtype=float32)}

알려진 제한 사항

  • TFLite 인터프리터는 스레드 안전을 보장하지 않으므로 동일한 인터프리터의 서명 실행자는 동시에 실행되지 않습니다.
  • C/iOS/Swift에 대한 지원은 아직 제공되지 않습니다.

업데이트

  • 버전 2.7
    • 다중 서명 기능이 구현됩니다.
    • 버전 2의 모든 변환기 API는 서명이 지원되는 TensorFlow Lite 모델을 생성합니다.
  • 버전 2.5
    • 서명 기능은 from_saved_model 변환기 API를 통해 사용할 수 있습니다.