Введение
В этом руководстве показано, как Tensorflow Extended (TFX) может создавать и оценивать модели машинного обучения, которые будут развернуты на устройстве. TFX теперь обеспечивает встроенную поддержку TFLite , что позволяет выполнять высокоэффективный логический вывод на мобильных устройствах.
В этом руководстве описаны изменения, которые можно внести в любой конвейер для создания и оценки моделей TFLite. Здесь мы приводим полный пример, демонстрирующий, как TFX может обучать и оценивать модели TFLite, обученные на основе набора данных MNIST . Кроме того, мы показываем, как один и тот же конвейер можно использовать для одновременного экспорта как стандартной модели SavedModel на основе Keras , так и модели TFLite, что позволяет пользователям сравнивать их качество.
Мы предполагаем, что вы знакомы с TFX, нашими компонентами и конвейерами. Если нет, то посмотрите этот урок .
Шаги
Для создания и оценки модели TFLite в TFX требуется всего два шага. Первым шагом является вызов переписчика TFLite в контексте TFX Trainer для преобразования обученной модели TensorFlow в модель TFLite. Второй шаг — настройка Evaluator для оценки моделей TFLite. Теперь мы обсудим каждый по очереди.
Вызов рерайтера TFLite в трейнере.
TFX Trainer ожидает, что определяемый пользователем run_fn
будет указан в файле модуля. Этот run_fn
определяет модель для обучения, обучает ее для указанного количества итераций и экспортирует обученную модель.
В оставшейся части этого раздела мы предоставляем фрагменты кода, которые показывают изменения, необходимые для вызова переписчика TFLite и экспорта модели TFLite. Весь этот код находится в run_fn
модуля MNIST TFLite .
Как показано в приведенном ниже коде, мы должны сначала создать подпись, которая принимает Tensor
для каждой функции в качестве входных данных. Обратите внимание, что это отход от большинства существующих моделей в TFX, которые принимают сериализованные прототипы tf.Example в качестве входных данных.
signatures = {
'serving_default':
_get_serve_tf_examples_fn(
model, tf_transform_output).get_concrete_function(
tf.TensorSpec(
shape=[None, 784],
dtype=tf.float32,
name='image_floats'))
}
Затем модель Keras сохраняется как SavedModel так же, как обычно.
temp_saving_model_dir = os.path.join(fn_args.serving_model_dir, 'temp')
model.save(temp_saving_model_dir, save_format='tf', signatures=signatures)
Наконец, мы создаем экземпляр переписчика TFLite ( tfrw
) и вызываем его в SavedModel для получения модели TFLite. Мы сохраняем эту модель TFLite в serving_model_dir
, предоставленном вызывающей run_fn
. Таким образом, модель TFLite сохраняется в том месте, где все нижестоящие компоненты TFX ожидают ее найти.
tfrw = rewriter_factory.create_rewriter(
rewriter_factory.TFLITE_REWRITER, name='tflite_rewriter')
converters.rewrite_saved_model(temp_saving_model_dir,
fn_args.serving_model_dir,
tfrw,
rewriter.ModelType.TFLITE_MODEL)
Оценка модели TFLite.
TFX Evaluator предоставляет возможность анализировать обученные модели, чтобы понять их качество по широкому диапазону показателей. Помимо анализа сохраненных моделей, TFX Evaluator теперь также может анализировать модели TFLite.
В следующем фрагменте кода (воспроизведенном из конвейера MNIST ) показано, как настроить Evaluator, который анализирует модель TFLite.
# Informs the evaluator that the model is a TFLite model.
eval_config_lite.model_specs[0].model_type = 'tf_lite'
...
# Uses TFMA to compute the evaluation statistics over features of a TFLite
# model.
model_analyzer_lite = Evaluator(
examples=example_gen.outputs['examples'],
model=trainer_lite.outputs['model'],
eval_config=eval_config_lite,
).with_id('mnist_lite')
Как показано выше, единственное изменение, которое нам нужно сделать, это установить для поля model_type
значение tf_lite
. Никаких других изменений конфигурации для анализа модели TFLite не требуется. Независимо от того, анализируется ли модель TFLite или SavedModel, выходные данные Evaluator
будут иметь точно такую же структуру.
Однако обратите внимание, что оценщик предполагает, что модель TFLite сохранена в файле с именем tflite
внутри train_lite.outputs['model'].