tfx.components.trainer.executor.Executor

Class Executor

Local trainer used by the TFX Trainer component.

Inherits From: BaseExecutor

The Trainer executor supplements TensorFlow training with a component to enable warm-start training of any user-specified tf.estimator. The Trainer is a library built on top of TensorFlow that is expected to be integrated into a custom user-specified binary.

To include Trainer in a TFX pipeline, configure your pipeline similarly to the example at https://github.com/tensorflow/tfx/blob/master/tfx/examples/chicago_taxi_pipeline/taxi_pipeline_simple.py#L104.

For more details on the Trainer component itself, please refer to https://tensorflow.org/tfx/guide/trainer. For a tutorial on TF Estimator, please refer to https://www.tensorflow.org/extend/estimators.

How to create a trainer callback function for this Trainer executor: To have TFX run an estimator, first write a trainer_fn callback that returns the estimator along with some additional parameters, similar to https://github.com/tensorflow/tfx/blob/master/tfx/examples/chicago_taxi_pipeline/taxi_utils.py#L285. This callback is the basis of the Trainer executor: the executor trains and evaluates the returned estimator locally using the tf.estimator.train_and_evaluate API.
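As a rough illustration of such a callback, the sketch below follows the shape of the linked taxi_utils.py example as it stood around this release. The (hparams, schema) signature, the hparams attributes (train_files, eval_files, train_steps, eval_steps, serving_model_dir) and the four returned keys are assumptions taken from that example and may differ in your TFX version. The feature names, the linear model and the batch size are hypothetical placeholders.

```python
# Hypothetical feature names; replace them with features from your own schema.
FEATURE_KEY = "x"
LABEL_KEY = "label"


def trainer_fn(hparams, schema):
    """Returns an estimator plus the specs used to train and evaluate it.

    The signature and returned keys are assumed from the linked taxi_utils.py
    example. `schema` is unused in this simplified sketch.
    """
    # Imported lazily so that loading this sketch does not require TensorFlow;
    # a real module file would normally import these at the top.
    import tensorflow as tf
    import tensorflow_model_analysis as tfma

    feature_spec = {
        FEATURE_KEY: tf.io.FixedLenFeature([], tf.float32),
        LABEL_KEY: tf.io.FixedLenFeature([], tf.int64),
    }

    def make_input_fn(file_pattern):
        def input_fn():
            # Assumes gzip-compressed TFRecord files of tf.Example protos.
            return tf.data.experimental.make_batched_features_dataset(
                file_pattern=file_pattern,
                batch_size=32,
                features=feature_spec,
                reader=lambda files: tf.data.TFRecordDataset(
                    files, compression_type="GZIP"),
                label_key=LABEL_KEY,
            )
        return input_fn

    def eval_input_receiver_fn():
        # Parses serialized tf.Example protos for model analysis.
        serialized = tf.compat.v1.placeholder(
            dtype=tf.string, shape=[None], name="input_example_tensor")
        features = tf.io.parse_example(serialized, feature_spec)
        labels = features.pop(LABEL_KEY)
        return tfma.export.EvalInputReceiver(
            features=features,
            receiver_tensors={"examples": serialized},
            labels=labels,
        )

    estimator = tf.estimator.LinearClassifier(
        feature_columns=[tf.feature_column.numeric_column(FEATURE_KEY)],
        config=tf.estimator.RunConfig(model_dir=hparams.serving_model_dir),
    )
    train_spec = tf.estimator.TrainSpec(
        make_input_fn(hparams.train_files), max_steps=hparams.train_steps)
    eval_spec = tf.estimator.EvalSpec(
        make_input_fn(hparams.eval_files), steps=hparams.eval_steps)

    return {
        "estimator": estimator,
        "train_spec": train_spec,
        "eval_spec": eval_spec,
        "eval_input_receiver_fn": eval_input_receiver_fn,
    }
```

The executor, not the callback, is expected to call tf.estimator.train_and_evaluate with the returned estimator and specs, so the callback itself only constructs objects.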

__init__

__init__(context=None)

Constructs a Beam-based executor.

Child Classes

class Context

Methods

Do

Do(
    input_dict,
    output_dict,
    exec_properties
)

Uses a user-supplied tf.estimator to train a TensorFlow model locally.

The Trainer executor invokes a trainer_fn callback function provided by the user via the module_file parameter, then trains a TensorFlow model using the tf.estimator that this function returns.

Args:

  • input_dict: Input dict from input key to a list of ML-Metadata Artifacts.
    • examples: Examples used for training; must include 'train' and 'eval' splits.
    • transform_output: Optional input transform graph.
    • schema: Schema of the data.
  • output_dict: Output dict from output key to a list of Artifacts.
    • output: Exported model.
  • exec_properties: A dict of execution properties.
    • train_args: JSON string of trainer_pb2.TrainArgs instance, providing args for training.
    • eval_args: JSON string of trainer_pb2.EvalArgs instance, providing args for eval.
    • module_file: Python module file containing the user-defined function (UDF) that defines the model.
    • warm_starting: Whether to warm-start training.
    • warm_start_from: Optional. If warm_starting is True, the directory containing the previous model to warm-start from.
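
To make the shape of exec_properties concrete, here is a minimal sketch that builds the dict by hand using only the standard library. The num_steps field name is assumed from the trainer_pb2.TrainArgs and trainer_pb2.EvalArgs messages, and the module file path and step counts are hypothetical. In a real pipeline the Trainer component normally produces these values for you.

```python
import json

# Hypothetical path; point this at your own module file.
MODULE_FILE = "/path/to/my_trainer_module.py"

exec_properties = {
    # train_args and eval_args are JSON strings, not dicts or proto objects.
    # `num_steps` is assumed to be the step-count field of TrainArgs/EvalArgs.
    "train_args": json.dumps({"num_steps": 10000}),
    "eval_args": json.dumps({"num_steps": 5000}),
    "module_file": MODULE_FILE,
    "warm_starting": False,
    # Only consulted when warm_starting is True.
    "warm_start_from": None,
}

print(json.dumps(exec_properties, indent=2))
```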

Returns:

None

Raises:

  • ValueError: When neither or both of 'module_file' and 'trainer_fn' are present in 'exec_properties'.
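
The ValueError condition amounts to an "exactly one of" check. The sketch below illustrates that rule in isolation; it is not the TFX source, the function name check_trainer_source is hypothetical, and it assumes that "present" means the key holds a non-empty value.

```python
def check_trainer_source(exec_properties):
    """Raises ValueError unless exactly one trainer source is supplied."""
    # Assumption: a missing key, None, or an empty string counts as absent.
    has_module_file = bool(exec_properties.get("module_file"))
    has_trainer_fn = bool(exec_properties.get("trainer_fn"))
    # Equal flags mean both are present or both are absent.
    if has_module_file == has_trainer_fn:
        raise ValueError(
            "Exactly one of 'module_file' and 'trainer_fn' must be supplied "
            "in 'exec_properties'.")


# Usage: a single source passes; neither or both raises.
check_trainer_source({"module_file": "/path/to/my_trainer_module.py"})
for bad in ({}, {"module_file": "m.py", "trainer_fn": "pkg.mod.trainer_fn"}):
    try:
        check_trainer_source(bad)
    except ValueError as err:
        print("rejected:", err)
```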