El componente de canalización Trainer TFX

El componente de canalización Trainer TFX entrena un modelo de TensorFlow.

Entrenador y TensorFlow

Trainer hace un uso extensivo de la API Python TensorFlow para entrenar modelos.

Componente

El entrenador toma:

  • tf.Ejemplos utilizados para entrenamiento y evaluación.
  • Un archivo de módulo proporcionado por el usuario que define la lógica del entrenador.
  • Definición de Protobuf de argumentos de tren y argumentos de evaluación.
  • (Opcional) Un esquema de datos creado por un componente de canalización de SchemaGen y, opcionalmente, modificado por el desarrollador.
  • (Opcional) gráfico de transformación producido por un componente Transform ascendente.
  • (Opcional) Modelos previamente entrenados utilizados para escenarios como el inicio en caliente.
  • Hiperparámetros (opcionales), que se pasarán a la función del módulo de usuario. Los detalles de la integración con Tuner se pueden encontrar aquí .

El entrenador emite: al menos un modelo para inferencia/publicación (normalmente en SavedModelFormat) y, opcionalmente, otro modelo para evaluación (normalmente un EvalSavedModel).

Brindamos soporte para formatos de modelos alternativos como TFLite a través de la Biblioteca de reescritura de modelos . Consulte el enlace a la Biblioteca de reescritura de modelos para ver ejemplos de cómo convertir los modelos Estimator y Keras.

Entrenador genérico

El entrenador genérico permite a los desarrolladores utilizar cualquier API modelo de TensorFlow con el componente Trainer. Además de los estimadores de TensorFlow, los desarrolladores pueden utilizar modelos Keras o bucles de entrenamiento personalizados. Para obtener más información, consulte el RFC para el entrenador genérico .

Configuración del componente de entrenador

El código DSL de canalización típico para el Entrenador genérico se vería así:

from tfx.components import Trainer

...

trainer = Trainer(
    module_file=module_file,
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Trainer invoca un módulo de capacitación, que se especifica en el parámetro module_file . En lugar de trainer_fn , se requiere run_fn en el archivo del módulo si GenericExecutor se especifica en custom_executor_spec . trainer_fn fue responsable de crear el modelo. Además de eso, run_fn también necesita manejar la parte de entrenamiento y enviar el modelo entrenado a la ubicación deseada proporcionada por FnArgs :

from tfx.components.trainer.fn_args_utils import FnArgs

def run_fn(fn_args: FnArgs) -> None:
  """Build the TF model and train it."""
  model = _build_keras_model()
  model.fit(...)
  # Save model to fn_args.serving_model_dir.
  model.save(fn_args.serving_model_dir, ...)

Aquí hay un archivo de módulo de ejemplo con run_fn .

Tenga en cuenta que si el componente Transformar no se utiliza en el proceso, el formador tomará los ejemplos de EjemploGen directamente:

trainer = Trainer(
    module_file=module_file,
    examples=example_gen.outputs['examples'],
    schema=infer_schema.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Más detalles están disponibles en la referencia de Trainer API .