tfx.components.Trainer

View source on GitHub

A TFX component to train a TensorFlow model.

Inherits From: BaseComponent

The Trainer component is used to train and evaluate a model using the given inputs and a user-supplied estimator.

Providing an estimator

The TFX executor will use the estimator provided in the module_file to train the model. The Trainer executor looks specifically for a trainer_fn() function within that file. Before training, the executor calls that function and expects it to return a dictionary with the following entries (a sketch of such a module is given below):

  • estimator: The estimator to be used by TensorFlow to train the model.
  • train_spec: The configuration to be used by the "train" part of the TensorFlow train_and_evaluate() call.
  • eval_spec: The configuration to be used by the "eval" part of the TensorFlow train_and_evaluate() call.
  • eval_input_receiver_fn: The configuration to be used by the ModelValidator component when validating the model.

An example of trainer_fn() can be found in the user-supplied code of the TFX Chicago Taxi pipeline example.

Please see https://www.tensorflow.org/guide/estimators for more details.
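
For orientation, the following is a minimal sketch of such a module, not the Chicago Taxi example's actual code: the feature column, model choice, and input-function wiring are illustrative assumptions, and the elided bodies (...) must be filled in with real input and receiver logic.

# Sketch of a module_file for the default (estimator-based) executor.
# Everything marked "assumed" is illustrative, not prescribed by TFX.
import tensorflow as tf


def trainer_fn(trainer_fn_args, schema):
  def _train_input_fn():
    # Build a tf.data input pipeline from trainer_fn_args.train_files (elided).
    ...

  def _eval_input_fn():
    # Build a tf.data input pipeline from trainer_fn_args.eval_files (elided).
    ...

  # Assumed feature and model; a real module derives features from the schema.
  feature_columns = [tf.feature_column.numeric_column('my_feature')]
  estimator = tf.estimator.DNNClassifier(
      hidden_units=[64, 32], feature_columns=feature_columns, n_classes=2)

  train_spec = tf.estimator.TrainSpec(
      _train_input_fn, max_steps=trainer_fn_args.train_steps)
  eval_spec = tf.estimator.EvalSpec(
      _eval_input_fn, steps=trainer_fn_args.eval_steps)

  def _eval_input_receiver_fn():
    # Return a tfma.export.EvalInputReceiver for model validation (elided).
    ...

  return {
      'estimator': estimator,
      'train_spec': train_spec,
      'eval_spec': eval_spec,
      'eval_input_receiver_fn': _eval_input_receiver_fn,
  }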

Example 1: Training locally

# Uses user-provided Python function that implements a model using TF-Learn.
trainer = Trainer(
    module_file=module_file,
    transformed_examples=transform.outputs['transformed_examples'],
    schema=infer_schema.outputs['schema'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Example 2: Training through a cloud provider

# Train using Google Cloud AI Platform.
trainer = Trainer(
    custom_executor_spec=executor_spec.ExecutorClassSpec(
        ai_platform_trainer_executor.Executor),
    module_file=module_file,
    transformed_examples=transform.outputs['transformed_examples'],
    schema=infer_schema.outputs['schema'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Args

examples A Channel of type standard_artifacts.Examples, serving as the source of examples used in training (required). May be raw or transformed.
transformed_examples Deprecated field. Please set 'examples' instead.
transform_graph An optional Channel of type standard_artifacts.TransformGraph, serving as the input transform graph if present.
schema A Channel of type standard_artifacts.Schema, serving as the schema of training and eval data.
base_model A Channel of type Model, containing a model that will be used for training. This can be used for warm starting, transfer learning, or model ensembling.
hyperparameters A Channel of type standard_artifacts.HyperParameters, serving as the hyperparameters for the training module. The best hyperparameters output by Tuner can be fed into this.
module_file A path to the Python module file containing the UDF model definition.

For the default executor, the module_file must implement a function named trainer_fn at its top level, with the following signature:

def trainer_fn(trainer_fn_args: trainer.executor.TrainerFnArgs,
               schema: tensorflow_metadata.proto.v0.schema_pb2.Schema) -> Dict: ...

where the returned Dict has the following keys and values:

  • estimator: an instance of tf.estimator.Estimator
  • train_spec: an instance of tf.estimator.TrainSpec
  • eval_spec: an instance of tf.estimator.EvalSpec
  • eval_input_receiver_fn: an instance of tfma.export.EvalInputReceiver

Exactly one of 'module_file' or 'trainer_fn' must be supplied.

For the generic executor, the module_file must implement a function named run_fn at its top level, with the signature def run_fn(trainer.executor.TrainerFnArgs); the trained model must be saved to TrainerFnArgs.serving_model_dir when this function executes. A sketch of such a run_fn follows this argument list.

run_fn A Python path to the UDF model definition function for the generic trainer. See 'module_file' for details. Exactly one of 'module_file' or 'run_fn' must be supplied if the Trainer uses the GenericExecutor.
trainer_fn A Python path to the UDF model definition function for the estimator-based trainer. See 'module_file' for the required signature of the UDF. Exactly one of 'module_file' or 'trainer_fn' must be supplied.
train_args A trainer_pb2.TrainArgs instance or a dict, containing args used for training. Currently, only num_steps is available. If provided as a dict, any RuntimeParameter fields must use the same field names as the TrainArgs proto message; see the sketch after this list.
eval_args A trainer_pb2.EvalArgs instance or a dict, containing args used for evaluation. Currently, only num_steps is available. If provided as a dict, any RuntimeParameter fields must use the same field names as the EvalArgs proto message.
custom_config A dict containing additional training job parameters that will be passed into the user module.
custom_executor_spec Optional custom executor spec.
output Optional Model channel for the result of exported models.
transform_output Backwards compatibility alias for the 'transform_graph' argument.
instance_name Optional unique instance name. Required if and only if multiple Trainer components are declared in the same pipeline.
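
As referenced above, here is a minimal sketch of a run_fn for the generic executor, assuming a toy Keras model; dataset construction from fn_args.train_files and fn_args.eval_files is elided.

# Sketch of a module_file for the generic executor. The model below is a
# placeholder assumption; build real tf.data datasets in practice.
import tensorflow as tf


def run_fn(fn_args):
  # fn_args is a trainer.executor.TrainerFnArgs instance.
  model = tf.keras.Sequential([
      tf.keras.layers.Dense(8, activation='relu'),
      tf.keras.layers.Dense(1),
  ])
  model.compile(optimizer='adam', loss='mse')
  # model.fit(train_dataset, steps_per_epoch=fn_args.train_steps, ...)

  # The trained model must be written to fn_args.serving_model_dir.
  model.save(fn_args.serving_model_dir, save_format='tf')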
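
The dict form of train_args and eval_args can carry RuntimeParameter values; a sketch, where the parameter name 'train_steps' and its default are arbitrary choices for this example:

from tfx.orchestration import data_types
from tfx.proto import trainer_pb2

# Dict field names mirror the TrainArgs proto message ('num_steps').
train_args = {
    'num_steps': data_types.RuntimeParameter(
        name='train_steps', default=10000, ptype=int),
}
eval_args = trainer_pb2.EvalArgs(num_steps=5000)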

Raises

ValueError

  • When both or neither of 'module_file' and a user function (trainer_fn or run_fn) is supplied.
  • When both or neither of 'examples' and 'transformed_examples' is supplied.
  • When 'transformed_examples' is supplied but 'transform_graph' is not.

Attributes

component_id DEPRECATED FUNCTION

component_type DEPRECATED FUNCTION
downstream_nodes

exec_properties

id Node id, unique across all TFX nodes in a pipeline.

If an instance name is available, the node id will be <node_class_name>.<instance_name>; otherwise, the node id will be <node_class_name>.

inputs

outputs

type

upstream_nodes

Child Classes

class DRIVER_CLASS

class SPEC_CLASS

Methods

add_downstream_node

View source

Experimental: Add another component that must run after this one.

This method enables task-based dependencies by enforcing execution order for synchronous pipelines on supported platforms. Currently, the supported platforms are Airflow, Beam, and Kubeflow Pipelines.

Note that this API call should be considered experimental, and may not work with asynchronous pipelines, sub-pipelines and pipelines with conditional nodes. We also recommend relying on data for capturing dependencies where possible to ensure data lineage is fully captured within MLMD.

It is symmetric with add_upstream_node.

Args
downstream_node a component that must run after this node.
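
For example, assuming a hypothetical pusher component declared in the same pipeline, trainer.add_downstream_node(pusher) forces pusher to wait for this trainer even when no artifact connects them; pusher.add_upstream_node(trainer) expresses the same constraint.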

add_upstream_node

View source

Experimental: Add another component that must run before this one.

This method enables task-based dependencies by enforcing execution order for synchronous pipelines on supported platforms. Currently, the supported platforms are Airflow, Beam, and Kubeflow Pipelines.

Note that this API call should be considered experimental, and may not work with asynchronous pipelines, sub-pipelines and pipelines with conditional nodes. We also recommend relying on data for capturing dependencies where possible to ensure data lineage is fully captured within MLMD.

It is symmetric with add_downstream_node.

Args
upstream_node a component that must run before this node.

from_json_dict

View source

Convert from dictionary data to an object.

get_id

View source

Gets the id of a node.

This can be used during pipeline authoring time. For example:

from tfx.components import Trainer

resolver = ResolverNode(
    ...,
    model=Channel(
        type=Model,
        producer_component_id=Trainer.get_id('my_trainer')))

Args
instance_name (Optional) instance name of a node. If given, the instance name will be taken into consideration when generating the id.

Returns
an id for the node.

to_json_dict

View source

Convert from an object to a JSON serializable dictionary.

Class Variables

  • EXECUTOR_SPEC