Компонент конвейера Trainer TFX обучает модель TensorFlow.
Тренер и TensorFlow
Тренер делает широкое использование Python TensorFlow API для подготовки моделей.
Компонент
Тренер берет:
- tf. Примеры, используемые для обучения и eval.
- Предоставляемый пользователем файл модуля, который определяет логику тренера.
- Protobuf определение поезда арг и Eval арг.
- (Необязательно) Схема данных, созданная компонентом конвейера SchemaGen и при необходимости измененная разработчиком.
- (Необязательно) граф преобразования, созданный вышестоящим компонентом преобразования.
- (Необязательно) предварительно обученные модели, используемые для таких сценариев, как теплый старт.
- (Необязательно) гиперпараметры, которые будут переданы в функцию пользовательского модуля. Детали интеграции с тюнером можно найти здесь .
Тренер генерирует: по крайней мере, одну модель для вывода / обслуживания (обычно в SavedModelFormat) и, возможно, другую модель для eval (обычно EvalSavedModel).
Мы обеспечиваем поддержку альтернативных моделей форматов , такие как TFLite через модель переписывание библиотеку . См. Ссылку на Библиотеку перезаписи моделей, где приведены примеры преобразования моделей Estimator и Keras.
Универсальный тренер
Универсальный тренажер позволяет разработчикам использовать любой API-интерфейс модели TensorFlow с компонентом Trainer. В дополнение к оценщикам TensorFlow разработчики могут использовать модели Keras или настраиваемые циклы обучения. Для получения дополнительной информации, пожалуйста , см RFC для общего тренера .
Настройка компонента трейнера
Типичный конвейерный DSL-код для универсального трейнера будет выглядеть так:
from tfx.components import Trainer
...
trainer = Trainer(
module_file=module_file,
examples=transform.outputs['transformed_examples'],
transform_graph=transform.outputs['transform_graph'],
train_args=trainer_pb2.TrainArgs(num_steps=10000),
eval_args=trainer_pb2.EvalArgs(num_steps=5000))
Тренер вызывает учебный модуль, который указан в module_file
параметра. Вместо trainer_fn
, run_fn
требуется в файле модуля , если GenericExecutor
указан в custom_executor_spec
. trainer_fn
был ответственен за создание модели. В дополнении к этому, run_fn
также необходимо обрабатывать учебную часть и вывести подготовку модели к желаемому месту , заданному FnArgs :
from tfx.components.trainer.fn_args_utils import FnArgs
def run_fn(fn_args: FnArgs) -> None:
"""Build the TF model and train it."""
model = _build_keras_model()
model.fit(...)
# Save model to fn_args.serving_model_dir.
model.save(fn_args.serving_model_dir, ...)
Ниже приведен пример файла модуля с run_fn
.
Обратите внимание, что если компонент Transform не используется в конвейере, то Trainer будет брать примеры напрямую из ExampleGen:
trainer = Trainer(
module_file=module_file,
examples=example_gen.outputs['examples'],
schema=infer_schema.outputs['schema'],
train_args=trainer_pb2.TrainArgs(num_steps=10000),
eval_args=trainer_pb2.EvalArgs(num_steps=5000))
Более подробная информация доступна в справочнике API Trainer .