tfx.v1.extensions.google_cloud_ai_platform.Trainer

Cloud AI Platform Trainer component.

Inherits From: Trainer, BaseComponent, BaseNode


examples A Channel of type standard_artifacts.Examples, serving as the source of examples used in training (required). May be raw or transformed.
transformed_examples Deprecated field. Please set examples instead.
transform_graph An optional Channel of type standard_artifacts.TransformGraph, serving as the input transform graph if present.
schema An optional Channel of type standard_artifacts.Schema, serving as the schema of the training and eval data. Schema is optional when (1) transform_graph is provided, since it contains the schema, or (2) the user module does not use the schema (e.g., the schema is hardcoded).
base_model A Channel of type Model, containing a model that will be used for training. This can be used for warm-starting, transfer learning, or model ensembling.
hyperparameters A Channel of type standard_artifacts.HyperParameters, serving as the hyperparameters for the training module. The best hyperparameters output by Tuner can be fed into this.
module_file A path to a Python module file containing the UDF model definition. For the GenericExecutor (the default), the module_file must implement a function named run_fn at its top level with the signature def run_fn(trainer.fn_args_utils.FnArgs), and run_fn must save the trained model to FnArgs.serving_model_dir when it is executed. For the Estimator-based executor, the module_file must implement a function named trainer_fn at its top level with the signature def trainer_fn(trainer.fn_args_utils.FnArgs, tensorflow_metadata.proto.v0.schema_pb2) -> Dict: ..., where the returned Dict has the following key-value pairs: 'estimator': an instance of tf.estimator.Estimator; 'train_spec': an instance of tf.estimator.TrainSpec; 'eval_spec': an instance of tf.estimator.EvalSpec; 'eval_input_receiver_fn': an instance of tfma EvalInputReceiver.
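The run_fn contract can be sketched as a module file like the one below. This is a minimal sketch, not TFX code: FnArgs here is a hypothetical stand-in dataclass that mirrors only a few fields of trainer.fn_args_utils.FnArgs (train_files, serving_model_dir, custom_config), and the "model" is a JSON file so the example runs without TensorFlow. A real run_fn would build and fit a model (for example with Keras) and export it to fn_args.serving_model_dir.

```python
import json
import os
import tempfile
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class FnArgs:
    """Hypothetical stand-in for trainer.fn_args_utils.FnArgs (only a few fields)."""
    train_files: List[str]
    serving_model_dir: str
    custom_config: Dict[str, Any] = field(default_factory=dict)


def run_fn(fn_args: FnArgs) -> None:
    """Top-level entry point that the Trainer looks up in module_file."""
    # Placeholder "training": a real module would fit a model on fn_args.train_files.
    model = {"trained_on": fn_args.train_files, "config": fn_args.custom_config}

    # The contract: the trained model must be written to fn_args.serving_model_dir.
    os.makedirs(fn_args.serving_model_dir, exist_ok=True)
    with open(os.path.join(fn_args.serving_model_dir, "model.json"), "w") as f:
        json.dump(model, f)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        out_dir = os.path.join(tmp, "serving_model")
        run_fn(FnArgs(train_files=["train-00000.gz"], serving_model_dir=out_dir))
        print(os.listdir(out_dir))  # prints ['model.json']
```

The only hard requirement the sketch encodes is where the output goes; how the model is trained is entirely up to the user module.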
run_fn A Python path to the UDF model definition function for the generic trainer. See 'module_file' for details. Exactly one of 'module_file' or 'run_fn' must be supplied if Trainer uses the GenericExecutor (the default).
trainer_fn A Python path to the UDF model definition function for the Estimator-based trainer. See 'module_file' for the required signature of the UDF. Exactly one of 'module_file' or 'trainer_fn' must be supplied if Trainer uses the Estimator-based executor.
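The "exactly one of" rule shared by run_fn and trainer_fn can be expressed as a small check. This is a hypothetical helper written for illustration, not TFX's own validation code; the module file name and dotted function paths in the usage lines are placeholders.

```python
from typing import Optional


def check_exactly_one(module_file: Optional[str], fn_path: Optional[str], fn_name: str) -> None:
    """Raise ValueError unless exactly one of module_file and fn_path is set.

    fn_name is 'run_fn' for the GenericExecutor or 'trainer_fn' for the
    Estimator-based executor; it only affects the error message.
    """
    # Both set, or both missing, is an error: only the XOR case is valid.
    if (module_file is None) == (fn_path is None):
        raise ValueError(f"Exactly one of 'module_file' or '{fn_name}' must be supplied.")


if __name__ == "__main__":
    check_exactly_one("trainer_module.py", None, "run_fn")  # accepted
    check_exactly_one(None, "my_pkg.model.trainer_fn", "trainer_fn")  # accepted
    try:
        check_exactly_one("trainer_module.py", "my_pkg.model.run_fn", "run_fn")
    except ValueError as err:
        print(err)  # Exactly one of 'module_file' or 'run_fn' must be supplied.
```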
train_args A proto.TrainArgs instance containing arguments used for training. Currently only splits and num_steps are available. By default (when splits is empty), training uses the train split.
eval_args A proto.EvalArgs instance containing arguments used for evaluation. Currently only splits and num_steps are available. By default (when splits is empty), evaluation uses the eval split.
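The default-split behavior of train_args and eval_args boils down to "use the configured splits, otherwise fall back to a single default split". The resolve_splits helper below is a hypothetical illustration of that rule, not part of TFX. In a pipeline the arguments themselves are protos, e.g. tfx.proto.TrainArgs(num_steps=1000) and tfx.proto.EvalArgs(num_steps=100) in the tfx.v1 API (step counts are illustrative).

```python
from typing import List, Sequence


def resolve_splits(splits: Sequence[str], default_split: str) -> List[str]:
    """Return the configured splits, or [default_split] if none are given."""
    return list(splits) if splits else [default_split]


if __name__ == "__main__":
    print(resolve_splits([], "train"))  # ['train'] (TrainArgs default)
    print(resolve_splits([], "eval"))  # ['eval'] (EvalArgs default)
    print(resolve_splits(["train", "extra"], "train"))  # ['train', 'extra']
```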
custom_config A dict containing additional training job parameters that will be passed into the user module.
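Inside the user module, custom_config is typically read from the custom_config field of the FnArgs object passed to run_fn (verify the field name against your TFX version). The sketch below reads one entry from it; the 'learning_rate' key and the read_learning_rate helper are hypothetical, and SimpleNamespace stands in for FnArgs so the code runs without TFX.

```python
from types import SimpleNamespace
from typing import Any


def read_learning_rate(fn_args: Any, default: float = 1e-3) -> float:
    """Read a hypothetical 'learning_rate' entry from fn_args.custom_config."""
    # Treat a missing or None custom_config as empty so the default applies.
    config = getattr(fn_args, "custom_config", None) or {}
    return float(config.get("learning_rate", default))


if __name__ == "__main__":
    # SimpleNamespace stands in for the FnArgs object passed to run_fn.
    print(read_learning_rate(SimpleNamespace(custom_config={"learning_rate": 0.01})))  # 0.01
    print(read_learning_rate(SimpleNamespace(custom_config=None)))  # 0.001
```

Falling back to a default keeps the module usable when a pipeline sets no custom_config at all.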

outputs Component's output channel dict.