简介
本指南演示如何使用 TensorFlow Extended (TFX) 创建和评估在设备端部署的机器学习模型。现在,TFX 为 TFLite 提供原生支持,这使得在移动设备上执行高效推断成为可能。
本指南将引导您对任何流水线进行更改以生成和评估 TFLite 模型。我们在此处提供了一个完整示例,演示 TFX 如何训练和评估使用 MNIST 数据集训练的 TFLite 模型。此外,我们还将展示如何使用同一个流水线同时导出标准的基于 Keras 的 SavedModel 和 TFLite SavedModel,使用户能够比较二者的质量。
我们假设您熟悉 TFX、相应组件和流水线。如果您不熟悉这些内容,请参阅此教程。
步骤
在 TFX 中创建和评估 TFLite 模型只需两个步骤。第一步是在 TFX Trainer 上下文中调用 TFLite 重写器,将训练的 TensorFlow 模型转换为 TFLite 模型。第二步是配置 Evaluator 以评估 TFLite 模型。现在,我们依次讨论这两个步骤。
在 Trainer 中调用 TFLite 重写器
TFX Trainer 要求在模块文件中指定用户定义的 run_fn
。此 run_fn
定义要训练的模型,对其进行指定迭代次数的训练,并导出训练后的模型。
我们将在本节的其余部分提供代码段,展示调用 TFLite 重写器和导出 TFLite 模型所需的更改。所有这些代码都位于 MNIST TFLite 模块的 run_fn
中。
如以下代码所示,我们必须首先创建一个签名,该签名会为每个特征接受 Tensor
作为输入。请注意,这与 TFX 中的大多数现有模型不同,现有模型接受序列化的 tf.Example proto 作为输入。
signatures = {
'serving_default':
_get_serve_tf_examples_fn(
model, tf_transform_output).get_concrete_function(
tf.TensorSpec(
shape=[None, 784],
dtype=tf.float32,
name='image_floats'))
}
然后,使用与平时相同的方式将 Keras 模型保存为 SavedModel。
temp_saving_model_dir = os.path.join(fn_args.serving_model_dir, 'temp')
model.save(temp_saving_model_dir, save_format='tf', signatures=signatures)
最后,我们创建 TFLite 重写器 (tfrw
) 的实例,并在 SavedModel 上调用它以获取 TFLite 模型。我们将此 TFLite 模型存储在 run_fn
的调用者提供的 serving_model_dir
中。这样,TFLite 模型就会存储在所有下游 TFX 组件预期查找模型的位置。
tfrw = rewriter_factory.create_rewriter(
rewriter_factory.TFLITE_REWRITER, name='tflite_rewriter')
converters.rewrite_saved_model(temp_saving_model_dir,
fn_args.serving_model_dir,
tfrw,
rewriter.ModelType.TFLITE_MODEL)
评估 TFLite 模型
TFX Evaluator 可以分析训练的模型,以了解模型在各种指标上的质量。除了分析 SavedModel 之外,TFX Evaluator 现在还可以分析 TFLite 模型。
以下代码段(从 MNIST 流水线复制)展示了如何配置分析 TFLite 模型的 Evaluator。
# Informs the evaluator that the model is a TFLite model.
eval_config_lite.model_specs[0].model_type = 'tf_lite'
...
# Uses TFMA to compute the evaluation statistics over features of a TFLite
# model.
model_analyzer_lite = Evaluator(
examples=example_gen.outputs['examples'],
model=trainer_lite.outputs['model'],
eval_config=eval_config_lite,
instance_name='mnist_lite')
如上所示,我们只需将 model_type
字段更改为 tf_lite
。无需其他配置更改即可分析 TFLite 模型。无论是分析 TFLite 模型还是 SavedModel,Evaluator
的输出都将具有完全相同的结构。