

Local generic trainer executor for the TFX Trainer component.

Inherits From: BaseExecutor

The Trainer executor supplements TensorFlow training with a component to enable warm-start training of any user-specified TF model. The Trainer is a library built on top of TensorFlow that is expected to be integrated into a custom user-specified binary.

To include Trainer in a TFX pipeline, configure your pipeline similar to

For more details on the Trainer component itself, please refer to the TFX Trainer component guide. For a tutorial on TensorFlow, please refer to the TensorFlow tutorials.

How to create a trainer callback function to be used by this Trainer executor: TFX executes model training by first having the user create a run_fn callback method that defines and trains a TF model and then saves it to the provided location. This callback becomes the basis of the Executor for GenericTrainer. The Executor then runs run_fn with the correct parameters, which it obtains by resolving the input artifacts, output artifacts, and execution properties.
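The callback contract described above (define, train, save to the provided location) can be sketched as follows. This is a minimal, dependency-free illustration, not the TFX API: `FakeFnArgs` is a hypothetical stand-in whose field names (`train_files`, `eval_files`, `serving_model_dir`, `model_run_dir`, `custom_config`) mirror a subset of the arguments object TFX passes to `run_fn`, the one-pair-per-line "x,y" data format is assumed, and a least-squares line replaces a real TF model so the three steps stay visible.

```python
import json
import os
import tempfile
from dataclasses import dataclass, field
from typing import Any, Dict, List, Tuple


@dataclass
class FakeFnArgs:
    """Hypothetical stand-in for the arguments object TFX passes to run_fn.

    Field names mirror a subset of TFX's FnArgs; in TFX the executor fills
    them in from the resolved artifacts and execution properties.
    """

    train_files: List[str]
    eval_files: List[str]
    serving_model_dir: str
    model_run_dir: str
    custom_config: Dict[str, Any] = field(default_factory=dict)


def _read_xy(paths: List[str]) -> List[Tuple[float, float]]:
    # Assumed toy data format: one "x,y" pair per line.
    pairs = []
    for path in paths:
        with open(path) as f:
            for line in f:
                if line.strip():
                    x, y = line.strip().split(",")
                    pairs.append((float(x), float(y)))
    return pairs


def run_fn(fn_args: FakeFnArgs) -> None:
    """Defines, trains, and saves a toy model to the provided locations."""
    train = _read_xy(fn_args.train_files)

    # Define and train: fit y = w * x + b by ordinary least squares
    # (a stand-in for building and fitting a TF model).
    n = len(train)
    mean_x = sum(x for x, _ in train) / n
    mean_y = sum(y for _, y in train) / n
    var_x = sum((x - mean_x) ** 2 for x, _ in train)
    cov_xy = sum((x - mean_x) * (y - mean_y) for x, y in train)
    w = cov_xy / var_x
    b = mean_y - w * mean_x

    # Evaluate on the eval split.
    evals = _read_xy(fn_args.eval_files)
    eval_mse = sum((w * x + b - y) ** 2 for x, y in evals) / len(evals)

    # Save the model to serving_model_dir ...
    os.makedirs(fn_args.serving_model_dir, exist_ok=True)
    with open(os.path.join(fn_args.serving_model_dir, "model.json"), "w") as f:
        json.dump({"w": w, "b": b}, f)

    # ... and training-related outputs to model_run_dir.
    os.makedirs(fn_args.model_run_dir, exist_ok=True)
    with open(os.path.join(fn_args.model_run_dir, "metrics.json"), "w") as f:
        json.dump({"eval_mse": eval_mse}, f)


# Usage: train on points from y = 2x + 1, then read the saved outputs back.
with tempfile.TemporaryDirectory() as tmp:
    train_path = os.path.join(tmp, "train.csv")
    eval_path = os.path.join(tmp, "eval.csv")
    with open(train_path, "w") as f:
        f.write("\n".join(f"{x},{2 * x + 1}" for x in range(5)))
    with open(eval_path, "w") as f:
        f.write("10,21")
    args = FakeFnArgs(
        train_files=[train_path],
        eval_files=[eval_path],
        serving_model_dir=os.path.join(tmp, "serving_model"),
        model_run_dir=os.path.join(tmp, "model_run"),
    )
    run_fn(args)
    with open(os.path.join(args.serving_model_dir, "model.json")) as f:
        saved_model = json.load(f)
    with open(os.path.join(args.model_run_dir, "metrics.json")) as f:
        saved_metrics = json.load(f)
```

In a real module file, run_fn would build and fit a TF model and export it (for example as a SavedModel) to `fn_args.serving_model_dir`; the overall shape of the function stays the same.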

Child Classes

class Context



Methods

Do

Do(input_dict, output_dict, exec_properties)

Uses a user-supplied run_fn to train a TensorFlow model locally.

The Trainer Executor invokes a run_fn callback function that the user provides via the module_file parameter. In this function, the user defines the model and trains it, then saves the model and training-related files (e.g., TensorBoard logs) to the provided locations.
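How an executor can turn a module_file into a callable run_fn is sketched below. TFX uses its own internal utilities for this step; the hypothetical helper `load_run_fn` only illustrates the general mechanism with the standard-library `importlib`, and `types.SimpleNamespace` stands in for the real arguments object the executor builds.

```python
import importlib.util
import os
import tempfile
import types
from typing import Callable


def load_run_fn(module_file: str, fn_name: str = "run_fn") -> Callable:
    """Hypothetical helper: import a Python file by path and return `fn_name`."""
    module_name = os.path.splitext(os.path.basename(module_file))[0]
    spec = importlib.util.spec_from_file_location(module_name, module_file)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load a module from {module_file}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # Runs the user's module code.
    fn = getattr(module, fn_name, None)
    if not callable(fn):
        raise ValueError(f"{module_file} does not define a callable {fn_name!r}")
    return fn


# A user module file of the kind passed via module_file.
USER_MODULE = '''
import os


def run_fn(fn_args):
    # Stand-in for defining, training, and saving a TF model.
    os.makedirs(fn_args.serving_model_dir, exist_ok=True)
    with open(os.path.join(fn_args.serving_model_dir, "model.txt"), "w") as f:
        f.write("trained")
'''

with tempfile.TemporaryDirectory() as tmp:
    module_file = os.path.join(tmp, "user_module.py")
    with open(module_file, "w") as f:
        f.write(USER_MODULE)
    run_fn = load_run_fn(module_file)
    # SimpleNamespace stands in for the arguments object the executor builds.
    fn_args = types.SimpleNamespace(
        serving_model_dir=os.path.join(tmp, "serving_model")
    )
    run_fn(fn_args)
    model_written = os.path.isfile(os.path.join(tmp, "serving_model", "model.txt"))
```

Loading by file path keeps the user's training code outside the executor's own package, which is why the same executor can run any user-supplied model.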

input_dict Input dict from input key to a list of ML-Metadata Artifacts.

  • examples: Examples used for training; must include 'train' and 'eval' splits if custom splits are not specified in train_args and eval_args.
  • transform_graph: Optional input transform graph.
  • transform_output: Optional input transform graph, deprecated.
  • schema: Schema of the data.
output_dict Output dict from output key to a list of Artifacts.
  • model: Exported model.
  • model_run: Model training-related outputs (e.g., TensorBoard logs).
exec_properties A dict of execution properties.
  • train_args: JSON string of trainer_pb2.TrainArgs instance, providing args for training.
  • eval_args: JSON string of trainer_pb2.EvalArgs instance, providing args for eval.
  • module_file: Python module file containing UDF model definition. Exactly one of module_file, module_path and run_fn should be passed.
  • module_path: Python module path containing UDF model definition. Exactly one of module_file, module_path and run_fn should be passed.
  • run_fn: Python module path to the run function. Exactly one of module_file, module_path and run_fn should be passed.
  • warm_starting: Whether to warm-start training.
  • warm_start_from: Optional. If warm_starting is True, the directory containing the previous model to warm-start from.
  • custom_config: Optional. JSON-serialized dict of additional parameters to pass to the trainer function.
Returns

    None

Raises

    ValueError When not exactly one of module_file, module_path, and run_fn is present in exec_properties.
    RuntimeError If run_fn failed to generate a model in the desired location.
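The documented argument checks can be sketched in plain Python. The helper names below are hypothetical and the exact checks TFX performs may differ: treating a missing or empty model directory as "no model generated" is an assumption, and reading the key `splits` assumes TrainArgs/EvalArgs are JSON-serialized with their proto field names. The fallback to 'train' and 'eval' follows the examples entry in the argument list.

```python
import json
import os
import tempfile
from typing import Any, Dict, List

UDF_KEYS = ("module_file", "module_path", "run_fn")


def resolve_udf_source(exec_properties: Dict[str, Any]) -> str:
    """Return the one UDF key that is set, else raise ValueError."""
    present = [key for key in UDF_KEYS if exec_properties.get(key)]
    if len(present) != 1:
        raise ValueError(
            f"Exactly one of {', '.join(UDF_KEYS)} must be set; got {present}"
        )
    return present[0]


def resolve_splits(args_json: str, default_split: str) -> List[str]:
    """Split names from JSON-serialized TrainArgs/EvalArgs, or the default."""
    args = json.loads(args_json) if args_json else {}
    # Assumption: proto field names are preserved in the JSON ("splits").
    return list(args.get("splits") or [default_split])


def check_model_written(serving_model_dir: str) -> None:
    """Raise RuntimeError if run_fn left no model in the expected directory."""
    # Assumption: a missing or empty directory counts as "no model generated".
    if not os.path.isdir(serving_model_dir) or not os.listdir(serving_model_dir):
        raise RuntimeError(f"run_fn failed to generate a model in {serving_model_dir}")


# Usage with hypothetical execution properties.
exec_properties = {
    "module_file": "trainer_module.py",  # hypothetical file name
    "train_args": json.dumps({"splits": ["train"], "num_steps": 100}),
    "eval_args": json.dumps({"num_steps": 10}),
}
udf_source = resolve_udf_source(exec_properties)
train_splits = resolve_splits(exec_properties["train_args"], "train")
eval_splits = resolve_splits(exec_properties["eval_args"], "eval")

try:
    resolve_udf_source({"module_file": "a.py", "run_fn": "my_pkg.my_mod.run_fn"})
    conflict_rejected = False
except ValueError:
    conflict_rejected = True

with tempfile.TemporaryDirectory() as tmp:
    try:
        check_model_written(os.path.join(tmp, "serving_model"))
        missing_model_detected = False
    except RuntimeError:
        missing_model_detected = True
```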