tfx.v1.components.Trainer

A TFX component to train a TensorFlow model.

The Trainer component is used to train and evaluate a model using the given inputs and a user-supplied run_fn.

An example of run_fn() can be found in the user-supplied code of the TFX penguin pipeline example.
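A run_fn has a simple contract (spelled out under module_file below): it receives an FnArgs object, trains, and saves the model to fn_args.serving_model_dir. The following is a minimal sketch of that contract that runs without TensorFlow or TFX. It makes several assumptions: a types.SimpleNamespace stands in for trainer.fn_args_utils.FnArgs and sets only serving_model_dir and custom_config, the "model" is a toy mean predictor written as JSON, and the training labels arrive through a hypothetical custom_config key named labels. A real run_fn would train a TensorFlow model and save it to fn_args.serving_model_dir instead.

```python
import json
import os
import tempfile
from types import SimpleNamespace


def run_fn(fn_args):
    """Toy run_fn: 'train' a mean predictor and save it to serving_model_dir."""
    # Hypothetical: labels come from custom_config so the sketch needs no data
    # files. A real run_fn would read its data from fn_args.train_files.
    labels = fn_args.custom_config.get("labels", [])
    mean = sum(labels) / len(labels) if labels else 0.0

    # The contract: the trained model must be written to serving_model_dir.
    os.makedirs(fn_args.serving_model_dir, exist_ok=True)
    with open(os.path.join(fn_args.serving_model_dir, "model.json"), "w") as f:
        json.dump({"mean": mean}, f)


# Stand-in for FnArgs with only the fields this sketch reads.
fake_args = SimpleNamespace(
    serving_model_dir=os.path.join(tempfile.mkdtemp(), "serving_model"),
    custom_config={"labels": [1.0, 2.0, 3.0]},
)
run_fn(fake_args)
with open(os.path.join(fake_args.serving_model_dir, "model.json")) as f:
    print(f.read())  # {"mean": 2.0}
```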

Example 1: Training locally

from tfx.v1 import proto
from tfx.v1.components import Trainer

# Uses a user-provided Python function that trains a model using TF.
# module_file, transform and infer_schema are defined elsewhere in the
# pipeline (the module path and the upstream components).
trainer = Trainer(
    module_file=module_file,
    examples=transform.outputs['transformed_examples'],
    schema=infer_schema.outputs['schema'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=proto.TrainArgs(splits=['train'], num_steps=10000),
    eval_args=proto.EvalArgs(splits=['eval'], num_steps=5000))
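The train_args and eval_args in Example 1 are protos, but the argument descriptions below also allow a dict with the same field names (splits, num_steps), and an empty splits field falls back to training on the 'train' split and evaluating on the 'eval' split. The plain-Python sketch below shows the dict form; effective_splits is a hypothetical helper, not part of TFX, that only mirrors that documented default.

```python
def effective_splits(args, default_split):
    """Return the splits a TrainArgs/EvalArgs-style dict would use.

    Hypothetical helper mirroring the documented default: an empty or missing
    'splits' field falls back to a single default split.
    """
    splits = list(args.get("splits") or [])
    return splits or [default_split]


# Dict form of the arguments; field names match the TrainArgs/EvalArgs protos.
train_args = {"splits": ["train"], "num_steps": 10000}
eval_args = {"num_steps": 5000}  # splits omitted, so the 'eval' default applies

print(effective_splits(train_args, "train"))  # ['train']
print(effective_splits(eval_args, "eval"))    # ['eval']
```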

Example 2: Training through a cloud provider

from tfx.dsl.components.base import executor_spec
from tfx.extensions.google_cloud_ai_platform.trainer import (
    executor as ai_platform_trainer_executor)
from tfx.v1 import proto
from tfx.v1.components import Trainer

# Train using Google Cloud AI Platform.
trainer = Trainer(
    custom_executor_spec=executor_spec.ExecutorClassSpec(
        ai_platform_trainer_executor.GenericExecutor),
    module_file=module_file,
    examples=transform.outputs['transformed_examples'],
    schema=infer_schema.outputs['schema'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=proto.TrainArgs(splits=['train'], num_steps=10000),
    eval_args=proto.EvalArgs(splits=['eval'], num_steps=5000))

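The component's outputs contain:

  • model: a Channel of type standard_artifacts.Model, holding the trained model.
  • model_run: a Channel of type standard_artifacts.ModelRun, serving as the working directory of the model run; it can hold non-model output such as TensorBoard logs.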

Please see the Trainer guide for more details.

Args

examples: A Channel of type standard_artifacts.Examples, serving as the source of examples used in training (required). May be raw or transformed.
transformed_examples: Deprecated field. Please set 'examples' instead.
transform_graph: An optional Channel of type standard_artifacts.TransformGraph, serving as the input transform graph if present.
schema: An optional Channel of type standard_artifacts.Schema, serving as the schema of the training and eval data. The schema is optional when (1) transform_graph is provided, since it contains the schema, or (2) the user module does not use the schema (e.g., the schema is hardcoded).

base_model: A Channel of type Model, containing a model that will be used for training. This can be used for warm-starting, transfer learning, or model ensembling.
hyperparameters: A Channel of type standard_artifacts.HyperParameters, serving as the hyperparameters for the training module. The best hyperparameters output by Tuner can be fed into this.
module_file: A path to a Python module file containing the UDF (user-defined function) model definition. The module_file must implement a function named run_fn at its top level with the signature def run_fn(trainer.fn_args_utils.FnArgs), and the trained model must be saved to FnArgs.serving_model_dir when this function is executed.

For an Estimator-based executor, the module_file must instead implement a function named trainer_fn at its top level, with the signature def trainer_fn(trainer.fn_args_utils.FnArgs, tensorflow_metadata.proto.v0.schema_pb2.Schema) -> Dict: ..., where the returned Dict has the following key-values:

  • 'estimator': an instance of tf.estimator.Estimator
  • 'train_spec': an instance of tf.estimator.TrainSpec
  • 'eval_spec': an instance of tf.estimator.EvalSpec
  • 'eval_input_receiver_fn': a function that returns a tfma EvalInputReceiver

Exactly one of 'module_file' or 'run_fn' must be supplied if Trainer uses GenericExecutor (the default). Use of a RuntimeParameter for this argument is experimental.

run_fn: A Python path to the UDF model definition function for the generic trainer. See 'module_file' for details. Exactly one of 'module_file' or 'run_fn' must be supplied if Trainer uses GenericExecutor (the default). Use of a RuntimeParameter for this argument is experimental.
trainer_fn: A Python path to the UDF model definition function for the Estimator-based trainer. See 'module_file' for the required signature of the UDF. Exactly one of 'module_file' or 'trainer_fn' must be supplied if Trainer uses an Estimator-based executor. Use of a RuntimeParameter for this argument is experimental.
train_args: A proto.TrainArgs instance or a dict, containing args used for training. Currently only splits and num_steps are available. If it is provided as a dict and any field is a RuntimeParameter, it should have the same field names as a TrainArgs proto message. The default behavior (when splits is empty) is to train on the 'train' split.
eval_args: A proto.EvalArgs instance or a dict, containing args used for evaluation. Currently only splits and num_steps are available. If it is provided as a dict and any field is a RuntimeParameter, it should have the same field names as an EvalArgs proto message. The default behavior (when splits is empty) is to evaluate on the 'eval' split.
custom_config: A dict containing additional training job parameters that will be passed into the user module.
custom_executor_spec: Optional custom executor spec. This is experimental and is subject to change in the future.
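For the Estimator-based trainer_fn described under module_file, the returned dict must carry four specific keys. A hypothetical pre-flight check for that shape might look like the sketch below. It is not part of TFX, checks key presence only (not value types), and uses placeholder objects in place of real tf.estimator objects.

```python
REQUIRED_TRAINER_FN_KEYS = frozenset(
    {"estimator", "train_spec", "eval_spec", "eval_input_receiver_fn"}
)


def missing_trainer_fn_keys(result):
    """Return, sorted, the required keys absent from a trainer_fn result.

    Hypothetical helper: it checks key presence only, not the value types.
    """
    if not isinstance(result, dict):
        raise TypeError("trainer_fn must return a dict")
    return sorted(REQUIRED_TRAINER_FN_KEYS - set(result))


# Placeholder values stand in for tf.estimator objects.
partial = {"estimator": object(), "train_spec": object()}
print(missing_trainer_fn_keys(partial))  # ['eval_input_receiver_fn', 'eval_spec']
```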

Raises

ValueError:

  • When both or neither of 'module_file' and a user function (e.g., trainer_fn or run_fn) are supplied.
  • When both or neither of 'examples' and 'transformed_examples' are supplied.
  • When 'transformed_examples' is supplied but 'transform_graph' is not.
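The three ValueError conditions above can be expressed as a few boolean checks. The sketch below is a hypothetical helper that mirrors only the documented rules; it is not the Trainer's own implementation, and plain strings stand in for Channels and module paths.

```python
def validate_trainer_args(examples=None, transformed_examples=None,
                          transform_graph=None, module_file=None,
                          run_fn=None, trainer_fn=None):
    """Raise ValueError for the documented invalid argument combinations.

    Hypothetical helper: mirrors the documented rules only.
    """
    has_user_fn = run_fn is not None or trainer_fn is not None
    # Rule 1: exactly one of module_file or a user function.
    if (module_file is not None) == has_user_fn:
        raise ValueError("Supply exactly one of 'module_file' or a user function.")
    # Rule 2: exactly one of examples or the deprecated transformed_examples.
    if (examples is not None) == (transformed_examples is not None):
        raise ValueError(
            "Supply exactly one of 'examples' or 'transformed_examples'.")
    # Rule 3: transformed_examples needs a transform_graph alongside it.
    if transformed_examples is not None and transform_graph is None:
        raise ValueError(
            "'transform_graph' is required with 'transformed_examples'.")


# Strings stand in for Channels and module paths.
validate_trainer_args(examples="examples_channel", module_file="trainer_module.py")
try:
    validate_trainer_args(examples="examples_channel",
                          module_file="trainer_module.py",
                          run_fn="my_pkg.run_fn")
except ValueError as err:
    print(err)  # Supply exactly one of 'module_file' or a user function.
```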

Attributes

outputs: The component's output channel dict.