¡Únase a la comunidad de SIG TFX-Addons y ayude a que TFX sea aún mejor!

El componente de canalización Tuner TFX

El componente Tuner ajusta los hiperparámetros del modelo.

Componente de sintonizador y biblioteca KerasTuner

El componente Tuner hace un uso extensivo de la API de Python KerasTuner para ajustar hiperparámetros.

Componente

El sintonizador toma:

  • tf.Ejemplos utilizados para entrenamiento y evaluación.
  • Un archivo de módulo proporcionado por el usuario (o módulo fn) que define la lógica de ajuste, incluida la definición del modelo, el espacio de búsqueda de hiperparámetros, el objetivo, etc.
  • Definición de protobuf de argumentos de tren y argumentos de evaluación.
  • (Opcional) Definición de Protobuf de argumentos de ajuste.
  • (Opcional) gráfico de transformación producido por un componente de transformación ascendente.
  • (Opcional) Un esquema de datos creado por un componente de canalización de SchemaGen y opcionalmente alterado por el desarrollador.

Con los datos, el modelo y el objetivo dados, Tuner sintoniza los hiperparámetros y emite el mejor resultado.

Instrucciones

Se tuner_fn función de módulo de usuario tuner_fn con la siguiente firma para Tuner:

...
from kerastuner.engine import base_tuner

TunerFnResult = NamedTuple('TunerFnResult', [('tuner', base_tuner.BaseTuner),
                                             ('fit_kwargs', Dict[Text, Any])])

def tuner_fn(fn_args: FnArgs) -> TunerFnResult:
  """Build the tuner using the KerasTuner API.
  Args:
    fn_args: Holds args as name/value pairs.

      - working_dir: working dir for tuning.
      - train_files: List of file paths containing training tf.Example data.
      - eval_files: List of file paths containing eval tf.Example data.
      - train_steps: number of train steps.
      - eval_steps: number of eval steps.
      - schema_path: optional schema of the input data.
      - transform_graph_path: optional transform graph produced by TFT.
  Returns:
    A namedtuple contains the following:
      - tuner: A BaseTuner that will be used for tuning.
      - fit_kwargs: Args to pass to tuner's run_trial function for fitting the
                    model , e.g., the training and validation dataset. Required
                    args depend on the above tuner's implementation.
  """
  ...

En esta función, usted define los espacios de búsqueda del modelo y del hiperparámetro, y elige el objetivo y el algoritmo para el ajuste. El componente Tuner toma este código de módulo como entrada, sintoniza los hiperparámetros y emite el mejor resultado.

Trainer puede tomar los hiperparámetros de salida de Tuner como entrada y utilizarlos en su código de módulo de usuario. La definición de canalización se ve así:

...
tuner = Tuner(
    module_file=module_file,  # Contains `tuner_fn`.
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=trainer_pb2.TrainArgs(num_steps=20),
    eval_args=trainer_pb2.EvalArgs(num_steps=5))

trainer = Trainer(
    module_file=module_file,  # Contains `run_fn`.
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    schema=schema_gen.outputs['schema'],
    # This will be passed to `run_fn`.
    hyperparameters=tuner.outputs['best_hyperparameters'],
    train_args=trainer_pb2.TrainArgs(num_steps=100),
    eval_args=trainer_pb2.EvalArgs(num_steps=5))
...

Es posible que no desee ajustar los hiperparámetros cada vez que vuelva a entrenar su modelo. Una vez que haya utilizado Tuner para determinar un buen conjunto de hiperparámetros, puede eliminar Tuner de su canalización y utilizar ImporterNode para importar el artefacto Tuner de una ejecución de entrenamiento anterior para alimentarlo a Trainer.

hparams_importer = ImporterNode(
    instance_name='import_hparams',
    # This can be Tuner's output file or manually edited file. The file contains
    # text format of hyperparameters (kerastuner.HyperParameters.get_config())
    source_uri='path/to/best_hyperparameters.txt',
    artifact_type=HyperParameters)

trainer = Trainer(
    ...
    # An alternative is directly use the tuned hyperparameters in Trainer's user
    # module code and set hyperparameters to None here.
    hyperparameters = hparams_importer.outputs['result'])

Ajuste en Google Cloud Platform (GCP)

Cuando se ejecuta en Google Cloud Platform (GCP), el componente Tuner puede aprovechar dos servicios:

AI Platform Vizier como el backend del ajuste de hiperparámetros

AI Platform Vizier es un servicio administrado que realiza una optimización de caja negra, basada en la tecnología Google Vizier .

CloudTuner es una implementación de KerasTuner que se comunica con el servicio AI Platform Vizier como backend del estudio. Dado que CloudTuner es una subclase de kerastuner.Tuner , puede usarse como un reemplazo tuner_fn en el módulo tuner_fn y ejecutarse como parte del componente TFX Tuner.

A continuación se muestra un fragmento de código que muestra cómo usar CloudTuner . Tenga en cuenta que la configuración de CloudTuner requiere elementos que son específicos de GCP, como project_id y region .

...
from tensorflow_cloud import CloudTuner

...
def tuner_fn(fn_args: FnArgs) -> TunerFnResult:
  """An implementation of tuner_fn that instantiates CloudTuner."""

  ...
  tuner = CloudTuner(
      _build_model,
      hyperparameters=...,
      ...
      project_id=...,       # GCP Project ID
      region=...,           # GCP Region where Vizier service is run.
  )

  ...
  return TuneFnResult(
      tuner=tuner,
      fit_kwargs={...}
  )

Ajuste paralelo en el grupo de trabajadores distribuidos de Cloud AI Platform Training

El marco KerasTuner como implementación subyacente del componente Tuner tiene la capacidad de realizar búsquedas de hiperparámetros en paralelo. Si bien el componente Stock Tuner no tiene la capacidad de ejecutar más de un trabajador de búsqueda en paralelo, al usar el componente Tuner de extensión de Google Cloud AI Platform , brinda la capacidad de ejecutar un ajuste paralelo, usando un trabajo de entrenamiento de AI Platform como un grupo de trabajadores distribuidos gerente. TuneArgs es la configuración dada a este componente. Este es un reemplazo directo del componente de sintonizador estándar.

tuner = google_cloud_ai_platform.Tuner(
    ...   # Same kwargs as the above stock Tuner component.
    tune_args=proto.TuneArgs(num_parallel_trials=3),  # 3-worker parallel
    custom_config={
        # Configures Cloud AI Platform-specific configs . For for details, see
        # https://cloud.google.com/ai-platform/training/docs/reference/rest/v1/projects.jobs#traininginput.
        TUNING_ARGS_KEY:
            {
                'project': ...,
                'region': ...,
                # Configuration of machines for each master/worker in the flock.
                'masterConfig': ...,
                'workerConfig': ...,
                ...
            }
    })
...

El comportamiento y la salida del componente Tuner de extensión es el mismo que el del componente Tuner estándar, excepto que se ejecutan múltiples búsquedas de hiperparámetros en paralelo en diferentes máquinas trabajadoras y, como resultado, num_trials se completará más rápido. Esto es particularmente efectivo cuando el algoritmo de búsqueda es vergonzosamente paralelizable, como RandomSearch . Sin embargo, si el algoritmo de búsqueda utiliza información de resultados de ensayos anteriores, como lo hace el algoritmo Google Vizier implementado en AI Platform Vizier, una búsqueda excesivamente paralela afectaría negativamente la eficacia de la búsqueda.

Ejemplo E2E

Ejemplo de E2E CloudTuner en GCP

Tutorial de KerasTuner

Tutorial de CloudTuner

Propuesta

Hay más detalles disponibles en la referencia de la API de Tuner .