Компонент конвейера TFX Trainer

Оптимизируйте свои подборки Сохраняйте и классифицируйте контент в соответствии со своими настройками.

Компонент конвейера Trainer TFX обучает модель TensorFlow.

Тренер и TensorFlow

Тренер делает широкое использование Python TensorFlow API для подготовки моделей.

Компонент

Тренер берет:

  • tf. Примеры, используемые для обучения и eval.
  • Предоставляемый пользователем файл модуля, который определяет логику тренера.
  • Protobuf определение поезда арг и Eval арг.
  • (Необязательно) Схема данных, созданная компонентом конвейера SchemaGen и при необходимости измененная разработчиком.
  • (Необязательно) граф преобразования, созданный вышестоящим компонентом преобразования.
  • (Необязательно) предварительно обученные модели, используемые для таких сценариев, как теплый старт.
  • (Необязательно) гиперпараметры, которые будут переданы в функцию пользовательского модуля. Детали интеграции с тюнером можно найти здесь .

Тренер генерирует: по крайней мере, одну модель для вывода / обслуживания (обычно в SavedModelFormat) и, возможно, другую модель для eval (обычно EvalSavedModel).

Мы обеспечиваем поддержку альтернативных моделей форматов , такие как TFLite через модель переписывание библиотеку . См. Ссылку на Библиотеку перезаписи моделей, где приведены примеры преобразования моделей Estimator и Keras.

Универсальный тренер

Универсальный тренажер позволяет разработчикам использовать любой API-интерфейс модели TensorFlow с компонентом Trainer. В дополнение к оценщикам TensorFlow разработчики могут использовать модели Keras или настраиваемые циклы обучения. Для получения дополнительной информации, пожалуйста , см RFC для общего тренера .

Настройка компонента трейнера

Типичный конвейерный DSL-код для универсального трейнера будет выглядеть так:

from tfx.components import Trainer

...

trainer = Trainer(
    module_file=module_file,
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Тренер вызывает учебный модуль, который указан в module_file параметра. Вместо trainer_fn , run_fn требуется в файле модуля , если GenericExecutor указан в custom_executor_spec . trainer_fn был ответственен за создание модели. В дополнении к этому, run_fn также необходимо обрабатывать учебную часть и вывести подготовку модели к желаемому месту , заданному FnArgs :

from tfx.components.trainer.fn_args_utils import FnArgs

def run_fn(fn_args: FnArgs) -> None:
  """Build the TF model and train it."""
  model = _build_keras_model()
  model.fit(...)
  # Save model to fn_args.serving_model_dir.
  model.save(fn_args.serving_model_dir, ...)

Ниже приведен пример файла модуля с run_fn .

Обратите внимание, что если компонент Transform не используется в конвейере, то Trainer будет брать примеры напрямую из ExampleGen:

trainer = Trainer(
    module_file=module_file,
    examples=example_gen.outputs['examples'],
    schema=infer_schema.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Более подробная информация доступна в справочнике API Trainer .