¡Únase a la comunidad de SIG TFX-Addons y ayude a que TFX sea aún mejor!
Se usó la API de Cloud Translation para traducir esta página.
Switch to English

El componente de canalización Trainer TFX

El componente de canalización Trainer TFX entrena un modelo de TensorFlow.

Entrenador y TensorFlow

Trainer hace un uso extensivo de la API de Python TensorFlow para entrenar modelos.

Componente

El entrenador toma:

  • tf.Ejemplos utilizados para entrenamiento y evaluación.
  • Un archivo de módulo proporcionado por el usuario que define la lógica del entrenador.
  • Definición de protobuf de argumentos de tren y argumentos de evaluación.
  • (Opcional) Un esquema de datos creado por un componente de canalización de SchemaGen y opcionalmente alterado por el desarrollador.
  • (Opcional) gráfico de transformación producido por un componente de transformación ascendente.
  • (Opcional) modelos entrenados previamente utilizados para escenarios como inicio en caliente.
  • Hiperparámetros (opcionales), que se pasarán a la función del módulo de usuario. Los detalles de la integración con Tuner se pueden encontrar aquí .

El entrenador emite: al menos un modelo para inferencia / servicio (normalmente en SavedModelFormat) y, opcionalmente, otro modelo para eval (normalmente un EvalSavedModel).

Brindamos soporte para formatos de modelos alternativos como TFLite a través de la biblioteca de reescritura de modelos . Consulte el enlace a la biblioteca de reescritura de modelos para ver ejemplos de cómo convertir los modelos Estimator y Keras.

Entrenador genérico

El entrenador genérico permite a los desarrolladores usar cualquier modelo de API de TensorFlow con el componente de entrenador. Además de los Estimadores de TensorFlow, los desarrolladores pueden usar modelos de Keras o ciclos de entrenamiento personalizados. Para obtener más información, consulte el RFC para entrenador genérico .

Configuración del componente de entrenador

El código DSL de canalización típico para el entrenador genérico se vería así:

from tfx.components import Trainer

...

trainer = Trainer(
    module_file=module_file,
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Trainer invoca un módulo de formación, que se especifica en el parámetro module_file . En lugar de trainer_fn , se requiere un run_fn en el archivo del módulo si se especifica GenericExecutor en custom_executor_spec . El trainer_fn fue el responsable de crear el modelo. Además de eso, run_fn también necesita manejar la parte de entrenamiento y generar el modelo entrenado en la ubicación deseada dada por FnArgs :

from tfx.components.trainer.fn_args_utils import FnArgs

def run_fn(fn_args: FnArgs) -> None:
  """Build the TF model and train it."""
  model = _build_keras_model()
  model.fit(...)
  # Save model to fn_args.serving_model_dir.
  model.save(fn_args.serving_model_dir, ...)

Aquí hay un archivo de módulo de ejemplo con run_fn .

Tenga en cuenta que si el componente Transform no se usa en la canalización, entonces el Entrenador tomaría los ejemplos de ExampleGen directamente:

trainer = Trainer(
    module_file=module_file,
    examples=example_gen.outputs['examples'],
    schema=infer_schema.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Entrenador basado en estimador (obsoleto)

Para obtener información sobre el uso de un modelo basado en Estimator con TFX y Trainer, consulte Diseño de código de modelado de TensorFlow con tf.Estimator para TFX .

Configuración de un componente de entrenador para usar el ejecutor basado en Estimador

El código típico de Python DSL de canalización se ve así:

from tfx.components import Trainer
from tfx.components.trainer.executor import Executor
from tfx.dsl.components.base import executor_spec

...

trainer = Trainer(
      custom_executor_spec=executor_spec.ExecutorClassSpec(Executor),
      module_file=module_file,
      examples=transform.outputs['transformed_examples'],
      schema=infer_schema.outputs['schema'],
      base_model=latest_model_resolver.outputs['latest_model'],
      transform_graph=transform.outputs['transform_graph'],
      train_args=trainer_pb2.TrainArgs(num_steps=10000),
      eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Trainer invoca un módulo de formación, que se especifica en el parámetro module_file . Un módulo de formación típico se ve así:

# TFX will call this function
def trainer_fn(trainer_fn_args, schema):
  """Build the estimator using the high level API.

  Args:
    trainer_fn_args: Holds args used to train the model as name/value pairs.
    schema: Holds the schema of the training examples.

  Returns:
    A dict of the following:

      - estimator: The estimator that will be used for training and eval.
      - train_spec: Spec for training.
      - eval_spec: Spec for eval.
      - eval_input_receiver_fn: Input function for eval.
  """
  # Number of nodes in the first layer of the DNN
  first_dnn_layer_size = 100
  num_dnn_layers = 4
  dnn_decay_factor = 0.7

  train_batch_size = 40
  eval_batch_size = 40

  tf_transform_output = tft.TFTransformOutput(trainer_fn_args.transform_output)

  train_input_fn = lambda: _input_fn(  # pylint: disable=g-long-lambda
      trainer_fn_args.train_files,
      tf_transform_output,
      batch_size=train_batch_size)

  eval_input_fn = lambda: _input_fn(  # pylint: disable=g-long-lambda
      trainer_fn_args.eval_files,
      tf_transform_output,
      batch_size=eval_batch_size)

  train_spec = tf.estimator.TrainSpec(  # pylint: disable=g-long-lambda
      train_input_fn,
      max_steps=trainer_fn_args.train_steps)

  serving_receiver_fn = lambda: _example_serving_receiver_fn(  # pylint: disable=g-long-lambda
      tf_transform_output, schema)

  exporter = tf.estimator.FinalExporter('chicago-taxi', serving_receiver_fn)
  eval_spec = tf.estimator.EvalSpec(
      eval_input_fn,
      steps=trainer_fn_args.eval_steps,
      exporters=[exporter],
      name='chicago-taxi-eval')

  run_config = tf.estimator.RunConfig(
      save_checkpoints_steps=999, keep_checkpoint_max=1)

  run_config = run_config.replace(model_dir=trainer_fn_args.serving_model_dir)
  warm_start_from = trainer_fn_args.base_model[
      0] if trainer_fn_args.base_model else None

  estimator = _build_estimator(
      # Construct layers sizes with exponetial decay
      hidden_units=[
          max(2, int(first_dnn_layer_size * dnn_decay_factor**i))
          for i in range(num_dnn_layers)
      ],
      config=run_config,
      warm_start_from=warm_start_from)

  # Create an input receiver for TFMA processing
  receiver_fn = lambda: _eval_input_receiver_fn(  # pylint: disable=g-long-lambda
      tf_transform_output, schema)

  return {
      'estimator': estimator,
      'train_spec': train_spec,
      'eval_spec': eval_spec,
      'eval_input_receiver_fn': receiver_fn
  }