¡El Día de la Comunidad de ML es el 9 de noviembre! Únase a nosotros para recibir actualizaciones de TensorFlow, JAX, y más Más información

El componente de canalización Tuner TFX

El componente Tuner ajusta los hiperparámetros del modelo.

Componente de sintonizador y biblioteca KerasTuner

El componente sintonizador hace un amplio uso de la Python KerasTuner API para hiperparámetros de ajuste.

Componente

El sintonizador toma:

  • tf.Ejemplos utilizados para entrenamiento y evaluación.
  • Un archivo de módulo proporcionado por el usuario (o módulo fn) que define la lógica de ajuste, incluida la definición del modelo, el espacio de búsqueda de hiperparámetros, el objetivo, etc.
  • Protobuf definición de args args tren y eval.
  • (Opcional) Protobuf definición de argumentos de ajuste.
  • (Opcional) gráfico de transformación producido por un componente de transformación ascendente.
  • (Opcional) Un esquema de datos creado por un componente de canalización de SchemaGen y opcionalmente alterado por el desarrollador.

Con los datos, el modelo y el objetivo dados, Tuner sintoniza los hiperparámetros y emite el mejor resultado.

Instrucciones

Una función módulo de usuario tuner_fn se requiere con la siguiente firma para el sintonizador:

...
from keras_tuner.engine import base_tuner

TunerFnResult = NamedTuple('TunerFnResult', [('tuner', base_tuner.BaseTuner),
                                             ('fit_kwargs', Dict[Text, Any])])

def tuner_fn(fn_args: FnArgs) -> TunerFnResult:
  """Build the tuner using the KerasTuner API.
  Args:
    fn_args: Holds args as name/value pairs.

      - working_dir: working dir for tuning.
      - train_files: List of file paths containing training tf.Example data.
      - eval_files: List of file paths containing eval tf.Example data.
      - train_steps: number of train steps.
      - eval_steps: number of eval steps.
      - schema_path: optional schema of the input data.
      - transform_graph_path: optional transform graph produced by TFT.
  Returns:
    A namedtuple contains the following:
      - tuner: A BaseTuner that will be used for tuning.
      - fit_kwargs: Args to pass to tuner's run_trial function for fitting the
                    model , e.g., the training and validation dataset. Required
                    args depend on the above tuner's implementation.
  """
  ...

En esta función, usted define los espacios de búsqueda del modelo y del hiperparámetro, y elige el objetivo y el algoritmo para el ajuste. El componente Tuner toma este código de módulo como entrada, sintoniza los hiperparámetros y emite el mejor resultado.

Trainer puede tomar los hiperparámetros de salida de Tuner como entrada y utilizarlos en su código de módulo de usuario. La definición de canalización se ve así:

...
tuner = Tuner(
    module_file=module_file,  # Contains `tuner_fn`.
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=trainer_pb2.TrainArgs(num_steps=20),
    eval_args=trainer_pb2.EvalArgs(num_steps=5))

trainer = Trainer(
    module_file=module_file,  # Contains `run_fn`.
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    schema=schema_gen.outputs['schema'],
    # This will be passed to `run_fn`.
    hyperparameters=tuner.outputs['best_hyperparameters'],
    train_args=trainer_pb2.TrainArgs(num_steps=100),
    eval_args=trainer_pb2.EvalArgs(num_steps=5))
...

Es posible que no desee ajustar los hiperparámetros cada vez que vuelva a entrenar su modelo. Una vez que haya utilizado sintonizador para determinar un buen conjunto de hiperparámetros, puede quitar el sintonizador de su tubería y el uso ImporterNode importar el artefacto sintonizador de un entrenamiento previo a ejecutar alimentación para el instructor.

hparams_importer = ImporterNode(
    instance_name='import_hparams',
    # This can be Tuner's output file or manually edited file. The file contains
    # text format of hyperparameters (keras_tuner.HyperParameters.get_config())
    source_uri='path/to/best_hyperparameters.txt',
    artifact_type=HyperParameters)

trainer = Trainer(
    ...
    # An alternative is directly use the tuned hyperparameters in Trainer's user
    # module code and set hyperparameters to None here.
    hyperparameters = hparams_importer.outputs['result'])

Ajuste en Google Cloud Platform (GCP)

Cuando se ejecuta en Google Cloud Platform (GCP), el componente Tuner puede aprovechar dos servicios:

AI Platform Vizier como el backend del ajuste de hiperparámetros

AI Plataforma Visir es un servicio gestionado que realiza la optimización cuadro negro, basado en el Google Visir tecnología.

CloudTuner es una implementación de KerasTuner que habla con el servicio de IA Plataforma Visir como backend estudio. Desde CloudTuner es una subclase de keras_tuner.Tuner , que puede ser utilizado como una gota en el reemplazo en el tuner_fn módulo, y ejecutar como una parte del componente de TFX Tuner.

A continuación se muestra un fragmento de código que muestra cómo utilizar CloudTuner . Tenga en cuenta que la configuración de CloudTuner requiere elementos que son específicos de GCP, como el project_id y region .

...
from tensorflow_cloud import CloudTuner

...
def tuner_fn(fn_args: FnArgs) -> TunerFnResult:
  """An implementation of tuner_fn that instantiates CloudTuner."""

  ...
  tuner = CloudTuner(
      _build_model,
      hyperparameters=...,
      ...
      project_id=...,       # GCP Project ID
      region=...,           # GCP Region where Vizier service is run.
  )

  ...
  return TuneFnResult(
      tuner=tuner,
      fit_kwargs={...}
  )

Ajuste paralelo en el grupo de trabajadores distribuidos de Cloud AI Platform Training

El marco KerasTuner como implementación subyacente del componente Tuner tiene la capacidad de realizar búsquedas de hiperparámetros en paralelo. Mientras que el componente de sintonizador no tiene capacidad de ejecutar más de un trabajador de búsqueda en paralelo, utilizando el componente de extensión sintonizador Google Cloud Platform AI , que proporciona la capacidad de ejecutar la sintonización en paralelo, utilizando una Capacitación Laboral plataforma IA como una bandada trabajador distribuido gerente. TuneArgs es la configuración dada a este componente. Este es un reemplazo directo del componente de sintonizador estándar.

tuner = google_cloud_ai_platform.Tuner(
    ...   # Same kwargs as the above stock Tuner component.
    tune_args=proto.TuneArgs(num_parallel_trials=3),  # 3-worker parallel
    custom_config={
        # Configures Cloud AI Platform-specific configs . For for details, see
        # https://cloud.google.com/ai-platform/training/docs/reference/rest/v1/projects.jobs#traininginput.
        TUNING_ARGS_KEY:
            {
                'project': ...,
                'region': ...,
                # Configuration of machines for each master/worker in the flock.
                'masterConfig': ...,
                'workerConfig': ...,
                ...
            }
    })
...

El comportamiento y la salida del componente de extensión Tuner es el mismo que el componente de sintonizador, excepto que múltiples búsquedas hiperparámetro se ejecutan en paralelo en diferentes máquinas de los trabajadores, y como resultado, los num_trials se completará más rápidamente. Esto es particularmente efectivo cuando el algoritmo de búsqueda es vergonzosamente paralelizable, como RandomSearch . Sin embargo, si el algoritmo de búsqueda utiliza información de resultados de ensayos anteriores, como lo hace el algoritmo Google Vizier implementado en AI Platform Vizier, una búsqueda excesivamente paralela afectaría negativamente la eficacia de la búsqueda.

Ejemplo E2E

Ejemplo de E2E CloudTuner en GCP

Tutorial de KerasTuner

Tutorial de CloudTuner

Propuesta

Más detalles están disponibles en la referencia de la API de sintonizador .