Únase a nosotros en DevFest para Ucrania del 14 al 15 de junio En línea Regístrese ahora

El componente de canalización Trainer TFX

El componente de canalización Trainer TFX entrena un modelo de TensorFlow.

Entrenador y TensorFlow

Entrenador hace un amplio uso de la Python TensorFlow API para los modelos de formación.

Componente

El entrenador toma:

  • tf.Ejemplos utilizados para entrenamiento y evaluación.
  • Un archivo de módulo proporcionado por el usuario que define la lógica del entrenador.
  • Protobuf definición de args args tren y eval.
  • (Opcional) Un esquema de datos creado por un componente de canalización de SchemaGen y modificado opcionalmente por el desarrollador.
  • (Opcional) gráfico de transformación producido por un componente Transform ascendente.
  • (Opcional) modelos preentrenados utilizados para escenarios como arranque en caliente.
  • (Opcional) hiperparámetros, que se pasarán a la función del módulo de usuario. Los detalles de la integración con el sintonizador se pueden encontrar aquí .

El entrenador emite: al menos un modelo para inferencia/servicio (normalmente en formato de modelo guardado) y, opcionalmente, otro modelo para evaluación (normalmente, un modelo guardado de evaluación).

Proporcionamos soporte para formatos alternativos como modelo TFLite a través de la Biblioteca Modelo reescritura . Consulte el enlace a la biblioteca de reescritura de modelos para ver ejemplos de cómo convertir modelos de Estimator y Keras.

Entrenador genérico

El entrenador genérico permite a los desarrolladores usar cualquier modelo de API de TensorFlow con el componente Entrenador. Además de TensorFlow Estimators, los desarrolladores pueden usar modelos Keras o ciclos de entrenamiento personalizados. Para más detalles, ver el RFC genérico para el entrenador .

Configuración del componente de entrenador

El código DSL de canalización típico para el Entrenador genérico se vería así:

from tfx.components import Trainer

...

trainer = Trainer(
    module_file=module_file,
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Entrenador invoca un módulo de formación, que se especifica en el module_file parámetro. En lugar de trainer_fn , un run_fn es necesaria en el archivo de módulo si el GenericExecutor se especifica en el custom_executor_spec . El trainer_fn fue responsable de crear el modelo. Además de eso, run_fn también tiene que manejar la parte de entrenamiento y la salida del modelo entrenado a un lugar deseado dada por FnArgs :

from tfx.components.trainer.fn_args_utils import FnArgs

def run_fn(fn_args: FnArgs) -> None:
  """Build the TF model and train it."""
  model = _build_keras_model()
  model.fit(...)
  # Save model to fn_args.serving_model_dir.
  model.save(fn_args.serving_model_dir, ...)

He aquí un archivo de módulo ejemplo con run_fn .

Tenga en cuenta que si el componente Transform no se usa en la canalización, el Entrenador tomaría los ejemplos de ExampleGen directamente:

trainer = Trainer(
    module_file=module_file,
    examples=example_gen.outputs['examples'],
    schema=infer_schema.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Más detalles están disponibles en la referencia de la API entrenador .