The Tuner TFX Pipeline Component

The Tuner component tunes the hyperparameters for the model.

Tuner Component and KerasTuner Library

The Tuner component makes extensive use of the Python KerasTuner API for tuning hyperparameters.


Tuner takes:

  • tf.Examples used for training and eval.
  • A user-provided module file (or module fn) that defines the tuning logic, including the model definition, hyperparameter search space, objective, etc.
  • Protobuf definition of train args and eval args.
  • (Optional) Protobuf definition of tuning args.
  • (Optional) transform graph produced by an upstream Transform component.
  • (Optional) A data schema created by a SchemaGen pipeline component and optionally altered by the developer.

With the given data, model, and objective, Tuner tunes the hyperparameters and emits the best result.
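As a rough mental model, the search loop samples a candidate from the search space, scores it against the objective, and keeps the best one. The sketch below is a toy random search in plain Python, not KerasTuner itself. `SEARCH_SPACE`, `toy_objective`, and `random_search` are hypothetical stand-ins, and `toy_objective` replaces what would really be a training and evaluation run.

```python
import random
from typing import Callable, Dict, List, Tuple

# Hypothetical search space: each hyperparameter maps to its candidate values.
SEARCH_SPACE: Dict[str, List] = {
    "learning_rate": [1e-4, 1e-3, 1e-2],
    "num_units": [16, 32, 64, 128],
}


def toy_objective(hparams: Dict) -> float:
    """Stand-in for 'train the model, return validation loss' (lower is better)."""
    # Pretend the best model uses learning_rate=1e-3 and num_units=64.
    return (abs(hparams["learning_rate"] - 1e-3) * 100
            + abs(hparams["num_units"] - 64) / 64)


def random_search(space: Dict[str, List], objective: Callable[[Dict], float],
                  num_trials: int, seed: int = 0) -> Tuple[Dict, float]:
    """Samples num_trials random candidates and returns the best one and its score."""
    rng = random.Random(seed)
    best_hparams, best_score = None, float("inf")
    for _ in range(num_trials):
        candidate = {name: rng.choice(values) for name, values in space.items()}
        score = objective(candidate)
        if score < best_score:
            best_hparams, best_score = candidate, score
    return best_hparams, best_score


best, score = random_search(SEARCH_SPACE, toy_objective, num_trials=10)
print(best, score)
```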


A user module function tuner_fn with the following signature is required for Tuner:

from typing import Any, Dict, NamedTuple, Text

from keras_tuner.engine import base_tuner
from tfx.v1.components import FnArgs

TunerFnResult = NamedTuple('TunerFnResult', [('tuner', base_tuner.BaseTuner),
                                             ('fit_kwargs', Dict[Text, Any])])

def tuner_fn(fn_args: FnArgs) -> TunerFnResult:
  """Build the tuner using the KerasTuner API.

  Args:
    fn_args: Holds args as name/value pairs.
      - working_dir: working dir for tuning.
      - train_files: List of file paths containing training tf.Example data.
      - eval_files: List of file paths containing eval tf.Example data.
      - train_steps: number of train steps.
      - eval_steps: number of eval steps.
      - schema_path: optional schema of the input data.
      - transform_graph_path: optional transform graph produced by TFT.

  Returns:
    A namedtuple containing the following:
      - tuner: A BaseTuner that will be used for tuning.
      - fit_kwargs: Args to pass to the tuner's run_trial function for fitting
                    the model, e.g., the training and validation datasets.
                    Required args depend on the above tuner's implementation.
  """
  ...

In this function, you define both the model and hyperparameter search spaces, and choose the objective and algorithm for tuning. The Tuner component takes this module code as input, tunes the hyperparameters, and emits the best result.
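The contract between tuner_fn and the component can be illustrated with dependency-free stand-ins. Everything in this sketch is hypothetical: `FakeFnArgs` mirrors some of the FnArgs fields listed in the docstring, `RecordingTuner` stands in for a KerasTuner tuner, `TunerFnResult` is redefined locally, and `drive` is an assumed approximation of what the component does with the result (call the tuner's `search` with `fit_kwargs`, which KerasTuner forwards to `run_trial`, then read back the best hyperparameters).

```python
from dataclasses import dataclass
from typing import Any, Dict, List, NamedTuple, Optional


@dataclass
class FakeFnArgs:
    """Hypothetical stand-in holding a subset of the FnArgs fields."""
    working_dir: str
    train_files: List[str]
    eval_files: List[str]
    train_steps: int
    eval_steps: int
    schema_path: Optional[str] = None
    transform_graph_path: Optional[str] = None


class RecordingTuner:
    """Hypothetical stand-in for a KerasTuner tuner; records what search() receives."""

    def __init__(self, working_dir: str):
        self.working_dir = working_dir
        self.search_kwargs: Dict[str, Any] = {}

    def search(self, **fit_kwargs: Any) -> None:
        self.search_kwargs = fit_kwargs

    def best_hyperparameter_values(self) -> Dict[str, Any]:
        # A real tuner would return the values of its best trial.
        return {"learning_rate": 1e-3}


class TunerFnResult(NamedTuple):
    tuner: Any
    fit_kwargs: Dict[str, Any]


def tuner_fn(fn_args: FakeFnArgs) -> TunerFnResult:
    tuner = RecordingTuner(working_dir=fn_args.working_dir)
    # In a real module, x and validation_data would be datasets built from the files.
    fit_kwargs = {
        "x": fn_args.train_files,
        "validation_data": fn_args.eval_files,
        "steps_per_epoch": fn_args.train_steps,
        "validation_steps": fn_args.eval_steps,
    }
    return TunerFnResult(tuner=tuner, fit_kwargs=fit_kwargs)


def drive(fn_args: FakeFnArgs) -> Dict[str, Any]:
    """Assumed approximation of the component: build the tuner, search, return best values."""
    result = tuner_fn(fn_args)
    result.tuner.search(**result.fit_kwargs)
    return result.tuner.best_hyperparameter_values()


args = FakeFnArgs("/tmp/tuning", ["train-0.gz"], ["eval-0.gz"], train_steps=20, eval_steps=5)
print(drive(args))  # {'learning_rate': 0.001}
```

Keeping fit_kwargs as a plain dictionary is what lets each tuner implementation decide which fitting arguments it needs.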

Trainer can take Tuner's output hyperparameters as input and utilize them in its user module code. The pipeline definition looks like this:

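tuner = Tuner(
    module_file=module_file,  # Contains `tuner_fn`.
    examples=transform.outputs['transformed_examples'],
    train_args=proto.TrainArgs(num_steps=20),  # Step counts are illustrative.
    eval_args=proto.EvalArgs(num_steps=5))

trainer = Trainer(
    module_file=module_file,  # Contains `run_fn`.
    examples=transform.outputs['transformed_examples'],
    # This will be passed to `run_fn`.
    hyperparameters=tuner.outputs['best_hyperparameters'],
    train_args=proto.TrainArgs(num_steps=100), eval_args=proto.EvalArgs(num_steps=5))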
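In the Trainer's module, the tuned hyperparameters arrive as a config dictionary; TFX's example pipelines, for instance, read it from `fn_args.hyperparameters` and rebuild it with `keras_tuner.HyperParameters.from_config`. The standard-library sketch below assumes the `{"space": ..., "values": ...}` shape produced by `HyperParameters.get_config()`. `DEFAULT_HPARAMS` and `resolve_hparams` are hypothetical names, and the fallback to defaults covers runs where no hyperparameters are passed.

```python
from typing import Any, Dict, Optional

# Hypothetical defaults used when the pipeline passes no hyperparameters.
DEFAULT_HPARAMS: Dict[str, Any] = {"learning_rate": 1e-2, "num_units": 32}


def resolve_hparams(config: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Returns hyperparameter values from a get_config()-style dict, or the defaults.

    Assumes config looks like {"space": [...], "values": {name: value}}.
    """
    resolved = dict(DEFAULT_HPARAMS)  # Copy so the defaults are never mutated.
    if config:
        resolved.update(config.get("values", {}))
    return resolved


tuned = {"space": [], "values": {"learning_rate": 1e-3, "num_units": 64}}
print(resolve_hparams(tuned))  # {'learning_rate': 0.001, 'num_units': 64}
print(resolve_hparams(None))   # {'learning_rate': 0.01, 'num_units': 32}
```

Merging over defaults means a tuned config only has to cover the hyperparameters that were actually searched.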

You might not want to tune the hyperparameters every time you retrain your model. Once you have used Tuner to determine a good set of hyperparameters, you can remove Tuner from your pipeline and use an Importer node to import the Tuner artifact from a previous training run and feed it to Trainer.

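from tfx.v1.dsl import Importer
from tfx.v1.types.standard_artifacts import HyperParameters

hparams_importer = Importer(
    # This can be Tuner's output file or a manually edited file. The file contains
    # the text format of hyperparameters (keras_tuner.HyperParameters.get_config()).
    source_uri=...,  # URI of the hyperparameters file.
    artifact_type=HyperParameters)

trainer = Trainer(
    ...,  # Other Trainer arguments.
    # An alternative is to use the tuned hyperparameters directly in Trainer's
    # user module code and set hyperparameters to None here.
    hyperparameters=hparams_importer.outputs['result'])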
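Because the imported file is the text form of `keras_tuner.HyperParameters.get_config()`, it can be edited by hand before import. This sketch assumes the file is serialized as JSON with `space` and `values` entries; `override_hparam` and the file name used in the example are hypothetical. Only `values` is changed, so the recorded search space stays as the Tuner produced it.

```python
import json
import os
import tempfile
from typing import Any


def override_hparam(path: str, name: str, value: Any) -> dict:
    """Loads a get_config()-style JSON file, overrides one value, and writes it back."""
    with open(path) as f:
        config = json.load(f)
    config["values"][name] = value
    with open(path, "w") as f:
        json.dump(config, f)
    return config


# Example: write a hypothetical best-hyperparameters file, then edit it.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "best_hyperparameters.txt")
    with open(path, "w") as f:
        json.dump({"space": [], "values": {"learning_rate": 0.001, "num_units": 64}}, f)
    override_hparam(path, "num_units", 128)
    with open(path) as f:
        print(json.load(f)["values"])  # {'learning_rate': 0.001, 'num_units': 128}
```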

Tuning on Google Cloud Platform (GCP)

When running on the Google Cloud Platform (GCP), the Tuner component can take advantage of two services:

AI Platform Vizier as the backend of hyperparameter tuning

AI Platform Vizier is a managed service that performs black-box optimization, based on the Google Vizier technology.

CloudTuner is an implementation of KerasTuner that talks to the AI Platform Vizier service as the study backend. Since CloudTuner is a subclass of keras_tuner.Tuner, it can be used as a drop-in replacement in the tuner_fn module and executes as part of the TFX Tuner component.

Below is a code snippet that shows how to use CloudTuner. Note that configuring CloudTuner requires GCP-specific items, such as the project_id and region.

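from tensorflow_cloud import CloudTuner
from tfx.v1.components import FnArgs, TunerFnResult

def tuner_fn(fn_args: FnArgs) -> TunerFnResult:
  """An implementation of tuner_fn that instantiates CloudTuner."""

  tuner = CloudTuner(
      _build_model,         # Hypothetical model-building function in this module.
      project_id=...,       # GCP Project ID
      region=...,           # GCP Region where Vizier service is run.
      objective=...,        # e.g. 'val_accuracy'
      hyperparameters=...)  # keras_tuner.HyperParameters search space.

  return TunerFnResult(
      tuner=tuner,
      fit_kwargs={'x': ..., 'validation_data': ...})  # Training and eval data.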

Parallel tuning on Cloud AI Platform Training distributed worker flock

The KerasTuner framework, as the underlying implementation of the Tuner component, can conduct hyperparameter search in parallel. While the stock Tuner component cannot execute more than one search worker in parallel, the Google Cloud AI Platform extension Tuner component can run parallel tuning, using an AI Platform Training Job as a distributed worker flock manager. TuneArgs is the configuration given to this component. This component is a drop-in replacement for the stock Tuner component.
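Parallel tuning changes who evaluates the candidates, not how they are chosen. The sketch below imitates `num_parallel_trials` with a standard-library thread pool; `objective` and `parallel_random_search` are hypothetical, and in the extension component each trial would run on its own AI Platform Training worker rather than on a thread.

```python
import random
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Tuple


def objective(hparams: Dict[str, float]) -> float:
    """Hypothetical validation loss, minimized at learning_rate=0.01, dropout=0.2."""
    return (hparams["learning_rate"] - 0.01) ** 2 + (hparams["dropout"] - 0.2) ** 2


def parallel_random_search(num_trials: int, num_parallel_trials: int,
                           seed: int = 0) -> Tuple[Dict[str, float], float]:
    """Evaluates num_trials random candidates using num_parallel_trials workers."""
    rng = random.Random(seed)
    # Random search can draw every candidate up front: no trial depends on another.
    candidates: List[Dict[str, float]] = [
        {"learning_rate": rng.uniform(1e-4, 1e-1), "dropout": rng.uniform(0.0, 0.5)}
        for _ in range(num_trials)
    ]
    with ThreadPoolExecutor(max_workers=num_parallel_trials) as pool:
        scores = list(pool.map(objective, candidates))  # map keeps input order.
    best_index = min(range(num_trials), key=scores.__getitem__)
    return candidates[best_index], scores[best_index]


best, loss = parallel_random_search(num_trials=12, num_parallel_trials=3)
print(best, loss)
```

Because every candidate is drawn before evaluation starts, this sketch returns the same result for any number of workers; only the wall-clock time changes.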

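from tfx.v1 import proto
from tfx.v1.extensions import google_cloud_ai_platform

tuner = google_cloud_ai_platform.Tuner(
    ...,  # Same kwargs as the above stock Tuner component.
    tune_args=proto.TuneArgs(num_parallel_trials=3),  # 3-worker parallel
    custom_config={
        # Configures Cloud AI Platform-specific configs. For details, see the
        # AI Platform Training `TrainingInput` reference.
        google_cloud_ai_platform.TUNING_ARGS_KEY: {
            'project': ...,
            'region': ...,
            # Configuration of machines for each master/worker in the flock.
            'masterConfig': ...,
            'workerConfig': ...,
        }
    })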

The behavior and output of the extension Tuner component are the same as those of the stock Tuner component, except that multiple hyperparameter searches are executed in parallel on different worker machines; as a result, the num_trials trials complete faster. This is particularly effective when the search algorithm is embarrassingly parallel, such as RandomSearch. However, if the search algorithm uses information from the results of prior trials, as the Google Vizier algorithm implemented in AI Platform Vizier does, an excessively parallel search would negatively affect the efficacy of the search.
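The cost of parallelism for an adaptive algorithm shows up even in a toy hill climber, used here as a hypothetical stand-in for algorithms that learn from prior trials (it is not the Vizier algorithm). Each round proposes `batch_size` perturbations of the best point found so far and learns from them only once the whole batch has been evaluated, so a fixed budget of `num_trials` yields only ceil(num_trials / batch_size) rounds of feedback.

```python
import random
from typing import Callable, Tuple


def batched_hill_climb(objective: Callable[[float], float], num_trials: int,
                       batch_size: int, start: float = 0.0, step: float = 1.0,
                       seed: int = 0) -> Tuple[float, float, int]:
    """Minimizes a 1-D objective; returns (best_x, best_value, feedback_rounds)."""
    rng = random.Random(seed)
    best_x, best_value = start, objective(start)
    trials_left, rounds = num_trials, 0
    while trials_left > 0:
        k = min(batch_size, trials_left)
        # Every proposal in this batch is based on the same, possibly stale, best_x.
        proposals = [best_x + rng.gauss(0.0, step) for _ in range(k)]
        for x in proposals:
            value = objective(x)
            if value < best_value:
                best_x, best_value = x, value
        trials_left -= k
        rounds += 1
    return best_x, best_value, rounds


def shifted_square(x: float) -> float:
    """Hypothetical objective with its minimum at x = 3."""
    return (x - 3.0) ** 2


# rounds: 12 for batch_size=1, 4 for batch_size=3
for batch in (1, 3):
    _, best_value, rounds = batched_hill_climb(shifted_square, num_trials=12, batch_size=batch)
    print(f"batch_size={batch}: {rounds} feedback rounds, best value {best_value:.4f}")
```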

Links

  • E2E Example
  • E2E CloudTuner on GCP Example
  • KerasTuner tutorial
  • CloudTuner tutorial


More details are available in the Tuner API reference.