tfx.v1.components.Tuner

A TFX component for model hyperparameter tuning.

Inherits From: BaseComponent, BaseNode

Component outputs contain:

  • best_hyperparameters: Channel of type standard_artifacts.HyperParameters holding the best hyperparameters found during tuning.
  • tuner_results: Channel of type standard_artifacts.TunerResults holding the results of all trials. Experimental: subject to change, with no backward-compatibility guarantees.
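Inside a pipeline, best_hyperparameters is normally passed to the Trainer's hyperparameters argument. The artifact can also be read directly. The sketch below makes two assumptions about current TFX behavior. First, the artifact directory holds a file named best_hyperparameters.txt. Second, that file contains the JSON-serialized KerasTuner HyperParameters config, a dict with "space" and "values" keys. The helper name load_best_hparam_values is hypothetical.

```python
import json
import os
import tempfile
from typing import Any, Dict

# Assumption: file name used inside the best_hyperparameters artifact directory.
BEST_HP_FILE_NAME = "best_hyperparameters.txt"


def load_best_hparam_values(artifact_uri: str) -> Dict[str, Any]:
    """Return the chosen hyperparameter values from a best_hyperparameters artifact.

    Hypothetical helper: assumes the artifact stores a JSON-encoded KerasTuner
    HyperParameters config of the form {"space": [...], "values": {...}}.
    """
    path = os.path.join(artifact_uri, BEST_HP_FILE_NAME)
    with open(path, "r", encoding="utf-8") as f:
        config = json.load(f)
    # Only the concrete values are needed to rebuild a model; "space" describes
    # the search space and is ignored here.
    return config["values"]


if __name__ == "__main__":
    # Demo with a fabricated artifact directory (values are illustrative only).
    with tempfile.TemporaryDirectory() as uri:
        with open(os.path.join(uri, BEST_HP_FILE_NAME), "w", encoding="utf-8") as f:
            json.dump({"space": [], "values": {"learning_rate": 1e-3}}, f)
        print(load_best_hparam_values(uri))  # {'learning_rate': 0.001}
```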

See the Tuner guide for more details.

Args:

  • examples: A BaseChannel of type standard_artifacts.Examples, serving as the source of examples used in tuning (required).
  • schema: An optional BaseChannel of type standard_artifacts.Schema, serving as the schema of the training and eval data. Used when raw examples are provided.
  • transform_graph: An optional BaseChannel of type standard_artifacts.TransformGraph, serving as the input transform graph if present. Used when transformed examples are provided.
  • base_model: An optional BaseChannel of type standard_artifacts.Model, containing a model that will be used for training. This can be used for warm-starting, transfer learning, or model ensembling.
  • module_file: A path to a Python module file containing the UDF tuner definition. The module file must implement a function named tuner_fn at its top level with the following signature:

        def tuner_fn(fn_args: FnArgs) -> TunerFnResult:

    Exactly one of 'module_file' or 'tuner_fn' must be supplied.
  • tuner_fn: A Python path to the UDF tuner function. See 'module_file' for the required signature of the UDF. Exactly one of 'module_file' or 'tuner_fn' must be supplied.
  • train_args: A trainer_pb2.TrainArgs instance containing args used for training. Currently only splits and num_steps are available. By default (when splits is empty), training uses the train split.
  • eval_args: A trainer_pb2.EvalArgs instance containing args used for evaluation. Currently only splits and num_steps are available. By default (when splits is empty), evaluation uses the eval split.
  • tune_args: A tuner_pb2.TuneArgs instance containing args used for tuning. Currently only num_parallel_trials is available.
  • custom_config: A dict of additional training job parameters that will be passed to the user module.
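The argument rules above can be summarized as a small validation helper. Exactly one of module_file or tuner_fn must be supplied, and empty splits fall back to train and eval. This is a hypothetical sketch for illustration, not the component's actual validation code. The names resolve_tuner_args and ResolvedTunerArgs are invented. Splits are modeled as plain string lists rather than trainer_pb2 messages.

```python
from dataclasses import dataclass
from typing import List, Optional, Sequence


@dataclass
class ResolvedTunerArgs:
    """Hypothetical container for the resolved Tuner arguments."""
    module_file: Optional[str]
    tuner_fn: Optional[str]
    train_splits: List[str]
    eval_splits: List[str]


def resolve_tuner_args(
    module_file: Optional[str] = None,
    tuner_fn: Optional[str] = None,
    train_splits: Sequence[str] = (),
    eval_splits: Sequence[str] = (),
) -> ResolvedTunerArgs:
    # Exactly one of module_file / tuner_fn: reject "neither" and "both".
    if bool(module_file) == bool(tuner_fn):
        raise ValueError(
            "Exactly one of 'module_file' or 'tuner_fn' must be supplied.")
    return ResolvedTunerArgs(
        module_file=module_file,
        tuner_fn=tuner_fn,
        # Empty split lists fall back to the documented defaults.
        train_splits=list(train_splits) or ["train"],
        eval_splits=list(eval_splits) or ["eval"],
    )
```

For example, resolve_tuner_args(module_file="my_tuner.py") yields train_splits == ["train"] and eval_splits == ["eval"], while passing both module_file and tuner_fn raises ValueError.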

Attributes:

  • outputs: The component's output channel dict.