¡Únase a la comunidad de SIG TFX-Addons y ayude a que TFX sea aún mejor!

El componente de canalización Trainer TFX

El componente de canalización Trainer TFX entrena un modelo de TensorFlow.

Entrenador y TensorFlow

Trainer hace un uso extensivo de la API de Python TensorFlow para entrenar modelos.

Componente

El entrenador toma:

  • tf.Ejemplos utilizados para entrenamiento y evaluación.
  • Un archivo de módulo proporcionado por el usuario que define la lógica del entrenador.
  • Definición de protobuf de argumentos de tren y argumentos de evaluación.
  • (Opcional) Un esquema de datos creado por un componente de canalización de SchemaGen y, opcionalmente, alterado por el desarrollador.
  • (Opcional) gráfico de transformación producido por un componente de transformación ascendente.
  • (Opcional) modelos entrenados previamente utilizados para escenarios como inicio en caliente.
  • Hiperparámetros (opcionales), que se pasarán a la función del módulo de usuario. Los detalles de la integración con Tuner se pueden encontrar aquí .

Entrenador emite: al menos un modelo para inferencia / servicio (normalmente en SavedModelFormat) y, opcionalmente, otro modelo para eval (normalmente un EvalSavedModel).

Brindamos soporte para formatos de modelos alternativos como TFLite a través de la biblioteca de reescritura de modelos . Consulte el enlace a la biblioteca de reescritura de modelos para ver ejemplos de cómo convertir los modelos Estimator y Keras.

Entrenador genérico

El entrenador genérico permite a los desarrolladores utilizar cualquier modelo de API de TensorFlow con el componente de entrenador. Además de los Estimadores de TensorFlow, los desarrolladores pueden usar modelos de Keras o ciclos de entrenamiento personalizados. Para obtener más información, consulte el RFC para entrenador genérico .

Configuración del componente de entrenador

El código DSL de canalización típico para el entrenador genérico se vería así:

from tfx.components import Trainer

...

trainer = Trainer(
    module_file=module_file,
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Trainer invoca un módulo de formación, que se especifica en el parámetro module_file . En lugar de trainer_fn , se requiere un run_fn en el archivo del módulo si se especifica GenericExecutor en custom_executor_spec . El trainer_fn fue el responsable de crear el modelo. Además de eso, run_fn también necesita manejar la parte de entrenamiento y generar el modelo entrenado en la ubicación deseada dada por FnArgs :

from tfx.components.trainer.fn_args_utils import FnArgs

def run_fn(fn_args: FnArgs) -> None:
  """Build the TF model and train it."""
  model = _build_keras_model()
  model.fit(...)
  # Save model to fn_args.serving_model_dir.
  model.save(fn_args.serving_model_dir, ...)

Aquí hay unarchivo de módulo de ejemplo con run_fn .

Tenga en cuenta que si el componente Transform no se utiliza en la canalización, el Entrenador tomaría los ejemplos de ExampleGen directamente:

trainer = Trainer(
    module_file=module_file,
    examples=example_gen.outputs['examples'],
    schema=infer_schema.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Hay más detalles disponibles en la referencia de la API de entrenador .