Migrate your TFLite code to TF2


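TensorFlow Lite (TFLite) is a set of tools that helps developers run ML inference on-device (mobile, embedded, and IoT devices). The TFLite converter converts existing TF models into an optimized TFLite model format that runs efficiently on-device.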

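This document covers the changes you need to make to your TF-to-TFLite conversion code, followed by a few examples that perform the same conversions before and after migration.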

Changes to your TF-to-TFLite conversion code

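  • If you're using a legacy TF1 model format (such as a Keras file, a frozen GraphDef, checkpoints, or a tf.Session), update it to a TF1/TF2 SavedModel and use the TF2 converter API tf.lite.TFLiteConverter.from_saved_model(...) to convert it to a TFLite model (refer to Table 1).

  • Update the converter API flags (refer to Table 2).

  • Remove legacy APIs such as tf.lite.constants. (For example, replace tf.lite.constants.INT8 with tf.int8.)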
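As a minimal sketch of the updated flags, the snippet below does full-integer quantization with the TF2 converter, setting `inference_input_type` and `inference_output_type` to `tf.int8` directly where TF1 code used `tf.lite.constants.INT8`. It assumes a hypothetical toy model (`AddTwo`, computing `x + 2`) saved to a temporary directory, and the random representative dataset is only a placeholder; real calibration data should come from your actual input distribution.

```python
import tempfile

import numpy as np
import tensorflow as tf


# Hypothetical toy model used only for illustration: y = x + 2.
class AddTwo(tf.Module):
  @tf.function(input_signature=[tf.TensorSpec(shape=[3], dtype=tf.float32)])
  def __call__(self, x):
    return x + 2.0


# Export the toy model as a TF2 SavedModel with an explicit serving signature.
module = AddTwo()
saved_model_dir = tempfile.mkdtemp()
tf.saved_model.save(
    module, saved_model_dir,
    signatures=module.__call__.get_concrete_function())


# Placeholder calibration data (assumption): random samples in [0, 1).
def representative_dataset():
  for _ in range(10):
    yield [np.random.rand(3).astype(np.float32)]


converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8   # TF1: tf.lite.constants.INT8
converter.inference_output_type = tf.int8  # TF1: tf.lite.constants.INT8
tflite_model = converter.convert()
```

Restricting `target_spec.supported_ops` to int8 builtins makes conversion fail loudly if any op cannot be quantized, rather than silently falling back to float.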

Table 1: TFLite Python converter API changes

| TF1 API | TF2 API |
|---|---|
| tf.lite.TFLiteConverter.from_saved_model('saved_model/',..) | Supported |
| tf.lite.TFLiteConverter.from_keras_model_file('model.h5',..) | Removed (update to the SavedModel format) |
| tf.lite.TFLiteConverter.from_frozen_graph('model.pb',..) | Removed (update to the SavedModel format) |
| tf.lite.TFLiteConverter.from_session(sess,...) | Removed (update to the SavedModel format) |


Table 2: TFLite Python converter API flag changes

| TF1 API | TF2 API |
|---|---|
| allow_custom_ops, optimizations, representative_dataset, target_spec, inference_input_type, inference_output_type, experimental_new_converter, experimental_new_quantizer | Supported |
| input_tensors, output_tensors, input_arrays_with_shape, output_arrays, experimental_debug_info_func | Removed (unsupported converter API arguments) |
| change_concat_input_ranges, default_ranges_stats, get_input_arrays(), inference_type, quantized_input_stats, reorder_across_fake_quant | Removed (unsupported quantization workflows) |
| conversion_summary_dir, dump_graphviz_dir, dump_graphviz_video | Removed (visualize models with Netron or visualize.py instead) |
| output_format, drop_control_dependency | Removed (unsupported features in TF2) |
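One way to apply Table 2 mechanically is to encode it as data and sort an existing set of TF1 converter settings into flags to keep and flags to drop. The helper `split_tf1_flags` below is hypothetical (it is not part of the TFLite API); it only restates the table, and flag values are passed through untouched.

```python
# Table 2 restated as data (hypothetical helper, not a TFLite API).
SUPPORTED_IN_TF2 = {
    "allow_custom_ops",
    "optimizations",
    "representative_dataset",
    "target_spec",
    "inference_input_type",
    "inference_output_type",
    "experimental_new_converter",
    "experimental_new_quantizer",
}

_REMOVED_GROUPS = {
    "unsupported converter API argument": [
        "input_tensors", "output_tensors", "input_arrays_with_shape",
        "output_arrays", "experimental_debug_info_func",
    ],
    "unsupported quantization workflow": [
        "change_concat_input_ranges", "default_ranges_stats",
        "get_input_arrays", "inference_type", "quantized_input_stats",
        "reorder_across_fake_quant",
    ],
    "use Netron or visualize.py to visualize models instead": [
        "conversion_summary_dir", "dump_graphviz_dir", "dump_graphviz_video",
    ],
    "unsupported feature in TF2": [
        "output_format", "drop_control_dependency",
    ],
}

# Invert the groups into a flag -> reason lookup.
REMOVED_IN_TF2 = {
    flag: reason for reason, flags in _REMOVED_GROUPS.items() for flag in flags
}


def split_tf1_flags(flags):
  """Split TF1 converter settings into (kept, dropped) according to Table 2.

  `flags` maps converter attribute names to values. `kept` holds the
  entries that can be set on the TF2 converter unchanged; `dropped` maps
  every other name to the reason it was removed (or notes that Table 2
  does not list it).
  """
  kept, dropped = {}, {}
  for name, value in flags.items():
    if name in SUPPORTED_IN_TF2:
      kept[name] = value
    else:
      dropped[name] = REMOVED_IN_TF2.get(name, "not listed in Table 2")
  return kept, dropped


# Attribute names taken from the TF1 "Before" snippets; values are placeholders.
kept, dropped = split_tf1_flags({
    "optimizations": ["DEFAULT"],
    "change_concat_input_ranges": True,
})
print(kept)     # {'optimizations': ['DEFAULT']}
print(dropped)  # {'change_concat_input_ranges': 'unsupported quantization workflow'}
```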

Examples

The following examples walk through converting legacy TF1 models to TF1/TF2 SavedModels and then converting those to TF2 TFLite models.

Setup

Start with the necessary TensorFlow imports.

import tensorflow as tf
import tensorflow.compat.v1 as tf1
import numpy as np

import logging
logger = tf.get_logger()
logger.setLevel(logging.ERROR)

import shutil
def remove_dir(path):
  # Delete the directory tree at `path`, ignoring errors (e.g. if it does not exist).
  shutil.rmtree(path, ignore_errors=True)

Create all the necessary TF1 model formats.

# Create a TF1 SavedModel
SAVED_MODEL_DIR = "tf_saved_model/"
remove_dir(SAVED_MODEL_DIR)
with tf1.Graph().as_default() as g:
  with tf1.Session() as sess:
    input = tf1.placeholder(tf.float32, shape=(3,), name='input')
    output = input + 2
    # print("result: ", sess.run(output, {input: [0., 2., 4.]}))
    tf1.saved_model.simple_save(
        sess, SAVED_MODEL_DIR,
        inputs={'input': input}, 
        outputs={'output': output})
print("TF1 SavedModel path: ", SAVED_MODEL_DIR)

# Create a TF1 Keras model
KERAS_MODEL_PATH = 'tf_keras_model.h5'
model = tf1.keras.models.Sequential([
    tf1.keras.layers.InputLayer(input_shape=(128, 128, 3,), name='input'),
    tf1.keras.layers.Dense(units=16, input_shape=(128, 128, 3,), activation='relu'),
    tf1.keras.layers.Dense(units=1, name='output')
])
model.save(KERAS_MODEL_PATH, save_format='h5')
print("TF1 Keras Model path: ", KERAS_MODEL_PATH)

# Create a TF1 frozen GraphDef model
GRAPH_DEF_MODEL_PATH = tf.keras.utils.get_file(
    'mobilenet_v1_0.25_128',
    origin='https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.25_128_frozen.tgz',
    untar=True,
) + '/frozen_graph.pb'

print("TF1 frozen GraphDef path: ", GRAPH_DEF_MODEL_PATH)
TF1 SavedModel path:  tf_saved_model/
TF1 Keras Model path:  tf_keras_model.h5
Downloading data from https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.25_128_frozen.tgz
2617289/2617289 [==============================] - 0s 0us/step
TF1 frozen GraphDef path:  /home/kbuilder/.keras/datasets/mobilenet_v1_0.25_128/frozen_graph.pb
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/engine/training.py:3103: UserWarning: You are saving your model as an HDF5 file via `model.save()`. This file format is considered legacy. We recommend using instead the native Keras format, e.g. `model.save('my_model.keras')`.
  saving_api.save_model(

1. Convert a TF1 SavedModel to a TFLite model

Before: Converting with TF1

This is typical code for TF1-style TFLite conversion.

converter = tf1.lite.TFLiteConverter.from_saved_model(
    saved_model_dir=SAVED_MODEL_DIR,
    input_arrays=['input'],
    input_shapes={'input' : [3]}
)
converter.optimizations = {tf.lite.Optimize.DEFAULT}
converter.change_concat_input_ranges = True
tflite_model = converter.convert()
# Ignore warning: "Use '@tf.function' or '@defun' to decorate the function."
2023-11-07 19:56:54.397156: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:378] Ignored output_format.
2023-11-07 19:56:54.397194: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:381] Ignored drop_control_dependency.
2023-11-07 19:56:54.397201: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:387] Ignored change_concat_input_ranges.
Summary on the non-converted ops:
---------------------------------

 * Accepted dialects: tfl, builtin, func
 * Non-Converted Ops: 1, Total Ops 5, % non-converted = 20.00 %
 * 1 ARITH ops

- arith.constant:    1 occurrences  (f32: 1)

After: Converting with TF2

Convert the TF1 SavedModel directly to a TFLite model, setting only flags that the v2 converter supports (here, optimizations).

# Convert TF1 SavedModel to a TFLite model.
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir=SAVED_MODEL_DIR)
converter.optimizations = {tf.lite.Optimize.DEFAULT}
tflite_model = converter.convert()
2023-11-07 19:56:54.461613: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:378] Ignored output_format.
2023-11-07 19:56:54.461651: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:381] Ignored drop_control_dependency.
Summary on the non-converted ops:
---------------------------------

 * Accepted dialects: tfl, builtin, func
 * Non-Converted Ops: 1, Total Ops 5, % non-converted = 20.00 %
 * 1 ARITH ops

- arith.constant:    1 occurrences  (f32: 1)
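One way to confirm that the converted model still computes `input + 2` is to run it with `tf.lite.Interpreter`. The sketch below does not rely on the notebook state: it rebuilds the same TF1 graph under a temporary directory (in the notebook you could reuse `SAVED_MODEL_DIR` instead) and leaves out `optimizations` so the check runs on plain float32 math.

```python
import tempfile

import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf1

# Rebuild the TF1 SavedModel from this example in a fresh (non-existent) directory.
export_dir = tempfile.mkdtemp() + "/tf1_saved_model"
with tf1.Graph().as_default():
  with tf1.Session() as sess:
    x = tf1.placeholder(tf.float32, shape=(3,), name="input")
    y = x + 2
    tf1.saved_model.simple_save(
        sess, export_dir, inputs={"input": x}, outputs={"output": y})

# Convert with the TF2 converter, as in the "After" snippet.
converter = tf.lite.TFLiteConverter.from_saved_model(export_dir)
tflite_model = converter.convert()

# Run a single inference with the TFLite interpreter.
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
input_index = interpreter.get_input_details()[0]["index"]
output_index = interpreter.get_output_details()[0]["index"]
interpreter.set_tensor(input_index, np.array([0.0, 2.0, 4.0], dtype=np.float32))
interpreter.invoke()
result = interpreter.get_tensor(output_index)
print(result)  # [2. 4. 6.]
```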

2. Convert a TF1 Keras model file to a TFLite model

Before: Converting with TF1

This is typical code for TF1-style TFLite conversion.

converter = tf1.lite.TFLiteConverter.from_keras_model_file(model_file=KERAS_MODEL_PATH)
converter.optimizations = {tf.lite.Optimize.DEFAULT}
converter.change_concat_input_ranges = True
tflite_model = converter.convert()
2023-11-07 19:56:55.482788: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:378] Ignored output_format.
2023-11-07 19:56:55.482826: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:381] Ignored drop_control_dependency.
2023-11-07 19:56:55.482833: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:387] Ignored change_concat_input_ranges.
Summary on the non-converted ops:
---------------------------------

 * Accepted dialects: tfl, builtin, func
 * Non-Converted Ops: 9, Total Ops 35, % non-converted = 25.71 %
 * 9 ARITH ops

- arith.constant:    9 occurrences  (f32: 4, i32: 5)

After: Converting with TF2

First, convert the TF1 Keras model file to a TF2 SavedModel, and then convert that SavedModel to a TFLite model using only flags that the v2 converter supports.

# Convert TF1 Keras model file to TF2 SavedModel.
model = tf.keras.models.load_model(KERAS_MODEL_PATH)
model.save(filepath='saved_model_2/')

# Convert TF2 SavedModel to a TFLite model.
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir='saved_model_2/')
tflite_model = converter.convert()
2023-11-07 19:56:56.076436: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:378] Ignored output_format.
2023-11-07 19:56:56.076492: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:381] Ignored drop_control_dependency.
Summary on the non-converted ops:
---------------------------------

 * Accepted dialects: tfl, builtin, func
 * Non-Converted Ops: 9, Total Ops 35, % non-converted = 25.71 %
 * 9 ARITH ops

- arith.constant:    9 occurrences  (f32: 4, i32: 5)

3. Convert a TF1 frozen GraphDef to a TFLite model

Before: Converting with TF1

This is typical code for TF1-style TFLite conversion.

converter = tf1.lite.TFLiteConverter.from_frozen_graph(
    graph_def_file=GRAPH_DEF_MODEL_PATH,
    input_arrays=['input'],
    input_shapes={'input' : [1, 128, 128, 3]},
    output_arrays=['MobilenetV1/Predictions/Softmax'],
)
converter.optimizations = {tf.lite.Optimize.DEFAULT}
converter.change_concat_input_ranges = True
tflite_model = converter.convert()
2023-11-07 19:56:56.312103: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:378] Ignored output_format.
2023-11-07 19:56:56.312152: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:381] Ignored drop_control_dependency.
2023-11-07 19:56:56.312159: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:387] Ignored change_concat_input_ranges.
Summary on the non-converted ops:
---------------------------------

 * Accepted dialects: tfl, builtin, func
 * Non-Converted Ops: 38, Total Ops 91, % non-converted = 41.76 %
 * 38 ARITH ops

- arith.constant:   38 occurrences  (f32: 37, i32: 1)

After: Converting with TF2

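First, convert the TF1 frozen GraphDef to a TF1 SavedModel, and then convert that SavedModel to a TFLite model, setting only flags that the v2 converter supports (here, optimizations).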

## Convert TF1 frozen Graph to TF1 SavedModel.

# Load the graph as a v1.GraphDef
import pathlib
gdef = tf.compat.v1.GraphDef()
gdef.ParseFromString(pathlib.Path(GRAPH_DEF_MODEL_PATH).read_bytes())

# Convert the GraphDef to a tf.Graph
with tf.Graph().as_default() as g:
  tf.graph_util.import_graph_def(gdef, name="")

# Look up the input and output tensors.
input_tensor = g.get_tensor_by_name('input:0') 
output_tensor = g.get_tensor_by_name('MobilenetV1/Predictions/Softmax:0')

# Save the graph as a TF1 SavedModel
remove_dir('saved_model_3/')
with tf.compat.v1.Session(graph=g) as s:
  tf.compat.v1.saved_model.simple_save(
      session=s,
      export_dir='saved_model_3/',
      inputs={'input':input_tensor},
      outputs={'output':output_tensor})

# Convert TF1 SavedModel to a TFLite model.
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir='saved_model_3/')
converter.optimizations = {tf.lite.Optimize.DEFAULT}
tflite_model = converter.convert()
2023-11-07 19:56:57.205923: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:378] Ignored output_format.
2023-11-07 19:56:57.205972: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:381] Ignored drop_control_dependency.
Summary on the non-converted ops:
---------------------------------

 * Accepted dialects: tfl, builtin, func
 * Non-Converted Ops: 38, Total Ops 91, % non-converted = 41.76 %
 * 38 ARITH ops

- arith.constant:   38 occurrences  (f32: 37, i32: 1)
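In the TF1 snippet the input and output were named explicitly through `input_arrays` and `output_arrays`; in the TF2 flow they come from the `inputs`/`outputs` dicts passed to `simple_save`. A small helper (the hypothetical `describe_io`) can list the inputs and outputs the converter actually picked up. To stay self-contained, the sketch uses a toy TF1 graph that only mimics the MobileNet input (name `input`, shape `[1, 128, 128, 3]`) and replaces the network with a single mean-pooling op; that substitution is an assumption for illustration.

```python
import tempfile

import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf1


def describe_io(tflite_model):
  """Return (inputs, outputs) as lists of (name, shape, dtype) tuples."""
  interpreter = tf.lite.Interpreter(model_content=tflite_model)
  interpreter.allocate_tensors()

  def summarize(details):
    return [(d["name"], tuple(int(s) for s in d["shape"]),
             np.dtype(d["dtype"]).name) for d in details]

  return (summarize(interpreter.get_input_details()),
          summarize(interpreter.get_output_details()))


# Toy stand-in for the frozen MobileNet graph (assumption): same input name
# and shape, but a single mean-pooling op instead of the real network.
export_dir = tempfile.mkdtemp() + "/toy_saved_model"
with tf1.Graph().as_default():
  with tf1.Session() as sess:
    x = tf1.placeholder(tf.float32, shape=(1, 128, 128, 3), name="input")
    y = tf.reduce_mean(x, axis=[1, 2], name="output")
    tf1.saved_model.simple_save(
        sess, export_dir, inputs={"input": x}, outputs={"output": y})

converter = tf.lite.TFLiteConverter.from_saved_model(export_dir)
inputs, outputs = describe_io(converter.convert())
for name, shape, dtype in inputs + outputs:
  print(name, shape, dtype)
```

The exact tensor names in the converted model depend on the converter version, so checking shapes and dtypes is usually more robust than matching names.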

Further reading

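  • Refer to the TFLite Guide to learn more about the workflows and the latest features.
  • If you're using TF1 code or legacy TF1 model formats (Keras .h5 files, frozen GraphDef .pb files, etc.), update your code and migrate your models to the TF2 SavedModel format.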