A TFX component to train a TensorFlow model.
Inherits From: BaseComponent, BaseNode
tfx.components.Trainer(
    examples: tfx.types.Channel = None,
    transformed_examples: Optional[tfx.types.Channel] = None,
    transform_graph: Optional[tfx.types.Channel] = None,
    schema: Optional[tfx.types.Channel] = None,
    base_model: Optional[tfx.types.Channel] = None,
    hyperparameters: Optional[tfx.types.Channel] = None,
    module_file: Optional[Union[Text, tfx.orchestration.data_types.RuntimeParameter]] = None,
    run_fn: Optional[Union[Text, tfx.orchestration.data_types.RuntimeParameter]] = None,
    trainer_fn: Optional[Union[Text, tfx.orchestration.data_types.RuntimeParameter]] = None,
    train_args: Union[trainer_pb2.TrainArgs, Dict[Text, Any]] = None,
    eval_args: Union[trainer_pb2.EvalArgs, Dict[Text, Any]] = None,
    custom_config: Optional[Dict[Text, Any]] = None,
    custom_executor_spec: Optional[tfx.dsl.components.base.executor_spec.ExecutorSpec] = None,
    output: Optional[tfx.types.Channel] = None,
    model_run: Optional[tfx.types.Channel] = None,
    transform_output: Optional[tfx.types.Channel] = None,
    instance_name: Optional[Text] = None
)
The Trainer component is used to train and evaluate a model using given inputs and a user-supplied estimator.
Providing an estimator
The TFX executor will use the estimator provided in the module_file to train the model. The Trainer executor looks specifically for a trainer_fn() function within that file. Before training, the executor calls that function and expects the following entries returned as a dictionary:
- estimator: The estimator to be used by TensorFlow to train the model.
- train_spec: The configuration to be used by the "train" part of the TensorFlow train_and_evaluate() call.
- eval_spec: The configuration to be used by the "eval" part of the TensorFlow train_and_evaluate() call.
- eval_input_receiver_fn: The configuration to be used by the ModelValidator component when validating the model.
An example of trainer_fn() can be found in the user-supplied code of the TFX Chicago Taxi pipeline example.
Please see https://www.tensorflow.org/guide/estimators for more details.
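As a sketch of the contract above, a module_file might define trainer_fn() along these lines. This is a minimal illustration, not the Chicago Taxi code; the model architecture and the helpers _make_input_fn and _make_eval_receiver_fn are hypothetical stand-ins for project-specific input-pipeline and TFMA export code.

```python
# Minimal sketch of a trainer_fn for the estimator-based Trainer executor.
# `_make_input_fn` and `_make_eval_receiver_fn` are hypothetical placeholders.
import tensorflow as tf


def _make_input_fn(file_pattern, schema):
  # Hypothetical: build an input_fn that reads tf.Examples matching
  # `file_pattern` and parses them according to `schema`.
  ...


def _make_eval_receiver_fn(schema):
  # Hypothetical: build an eval_input_receiver_fn via tfma.export utilities.
  ...


def trainer_fn(trainer_fn_args, schema):
  """Returns the dictionary the estimator-based executor expects."""
  estimator = tf.estimator.DNNClassifier(
      hidden_units=[64, 32],  # illustrative architecture
      feature_columns=[tf.feature_column.numeric_column('x')])

  train_spec = tf.estimator.TrainSpec(
      input_fn=_make_input_fn(trainer_fn_args.train_files, schema),
      max_steps=trainer_fn_args.train_steps)

  eval_spec = tf.estimator.EvalSpec(
      input_fn=_make_input_fn(trainer_fn_args.eval_files, schema),
      steps=trainer_fn_args.eval_steps)

  return {
      'estimator': estimator,
      'train_spec': train_spec,
      'eval_spec': eval_spec,
      'eval_input_receiver_fn': _make_eval_receiver_fn(schema),
  }
```

The executor hands the first three returned objects to tf.estimator.train_and_evaluate(); the eval_input_receiver_fn is used for TFMA export.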
Example 1: Training locally
# Uses user-provided Python function that implements a model using TF-Learn.
trainer = Trainer(
    module_file=module_file,
    transformed_examples=transform.outputs['transformed_examples'],
    schema=infer_schema.outputs['schema'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=trainer_pb2.TrainArgs(splits=['train'], num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(splits=['eval'], num_steps=5000))
Example 2: Training through a cloud provider
from tfx.extensions.google_cloud_ai_platform.trainer import executor as ai_platform_trainer_executor

# Train using Google Cloud AI Platform.
trainer = Trainer(
    custom_executor_spec=executor_spec.ExecutorClassSpec(
        ai_platform_trainer_executor.Executor),
    module_file=module_file,
    transformed_examples=transform.outputs['transformed_examples'],
    schema=infer_schema.outputs['schema'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=trainer_pb2.TrainArgs(splits=['train'], num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(splits=['eval'], num_steps=5000))
Args | |
---|---|
examples | A Channel of type standard_artifacts.Examples, serving as the source of examples used in training (required). May be raw or transformed. |
transformed_examples | Deprecated field. Please set 'examples' instead. |
transform_graph | An optional Channel of type standard_artifacts.TransformGraph, serving as the input transform graph if present. |
schema | An optional Channel of type standard_artifacts.Schema, serving as the schema of training and eval data. Schema is optional when 1) the provided transform_graph contains a schema, or 2) the user module bypasses the usage of schema, e.g., by hardcoding it. |
base_model | A Channel of type Model, containing a model that will be used for training. This can be used for warm-starting, transfer learning, or model ensembling. |
hyperparameters | A Channel of type standard_artifacts.HyperParameters, serving as the hyperparameters for the training module. The Tuner's output best hyperparameters can be fed into this. |
module_file | A path to a python module file containing the UDF model definition. For the default executor, the module_file must implement a function named trainer_fn with signature def trainer_fn(trainer.fn_args_utils.FnArgs, tensorflow_metadata.proto.v0.schema_pb2) -> Dict, where the returned Dict has the following key-values: 'estimator' (an instance of tf.estimator.Estimator), 'train_spec' (an instance of tf.estimator.TrainSpec), 'eval_spec' (an instance of tf.estimator.EvalSpec), and 'eval_input_receiver_fn' (an instance of tfma.export.EvalInputReceiver). Exactly one of 'module_file' or 'trainer_fn' must be supplied. For the generic executor, the module_file must implement a function named run_fn. |
run_fn | A python path to the UDF model definition function for the generic trainer. See 'module_file' for details. Exactly one of 'module_file' or 'run_fn' must be supplied if the Trainer uses the GenericExecutor. |
trainer_fn | A python path to the UDF model definition function for the estimator-based trainer. See 'module_file' for the required signature of the UDF. Exactly one of 'module_file' or 'trainer_fn' must be supplied. |
train_args | A trainer_pb2.TrainArgs instance or a dict, containing args used for training. Currently only splits and num_steps are available. If it's provided as a dict and any field is a RuntimeParameter, it should have the same field names as a TrainArgs proto message. Default behavior (when splits is empty) is to train on the train split. |
eval_args | A trainer_pb2.EvalArgs instance or a dict, containing args used for evaluation. Currently only splits and num_steps are available. If it's provided as a dict and any field is a RuntimeParameter, it should have the same field names as an EvalArgs proto message. Default behavior (when splits is empty) is to evaluate on the eval split. |
custom_config | A dict which contains additional training job parameters that will be passed into the user module. |
custom_executor_spec | Optional custom executor spec. |
output | Optional Model channel for the result of exported models. |
model_run | Optional ModelRun channel, serving as the working dir of models; can be used to output non-model related output (e.g., TensorBoard logs). |
transform_output | Backwards compatibility alias for the 'transform_graph' argument. |
instance_name | Optional unique instance name. Necessary iff multiple Trainer components are declared in the same pipeline. |
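As a sketch of the dict form of train_args described above: a num_steps field can be supplied as a RuntimeParameter, resolved when a pipeline run is launched. The surrounding variables (module_file, transform, infer_schema) are assumed from Example 1.

```python
# Sketch: train_args given as a dict whose num_steps is a RuntimeParameter.
from tfx.components import Trainer
from tfx.orchestration import data_types

train_steps = data_types.RuntimeParameter(
    name='train_steps', ptype=int, default=10000)

trainer = Trainer(
    module_file=module_file,
    transformed_examples=transform.outputs['transformed_examples'],
    schema=infer_schema.outputs['schema'],
    transform_graph=transform.outputs['transform_graph'],
    # Dict fields mirror the TrainArgs/EvalArgs proto field names.
    train_args={'num_steps': train_steps},
    eval_args={'num_steps': 5000})
```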
Raises | |
---|---|
ValueError | |
Attributes | |
---|---|
component_id | |
component_type | |
downstream_nodes | |
exec_properties | |
id | Node id, unique across all TFX nodes in a pipeline. |
inputs | |
outputs | |
type | |
upstream_nodes | |
Child Classes
Methods
add_downstream_node
add_downstream_node( downstream_node )
Experimental: Add another component that must run after this one.
This method enables task-based dependencies by enforcing execution order for synchronous pipelines on supported platforms. Currently, the supported platforms are Airflow, Beam, and Kubeflow Pipelines.
Note that this API call should be considered experimental, and may not work with asynchronous pipelines, sub-pipelines and pipelines with conditional nodes. We also recommend relying on data for capturing dependencies where possible to ensure data lineage is fully captured within MLMD.
It is symmetric with add_upstream_node.

Args | |
---|---|
downstream_node | a component that must run after this node. |
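For instance, to make a hypothetical pusher component (defined elsewhere in the same pipeline) wait for this trainer even though no artifact flows between them:

```python
# Task-based dependency: `pusher` runs only after `trainer` completes.
# Both components are assumed to be defined elsewhere in the pipeline.
trainer.add_downstream_node(pusher)
# Equivalently, from the other side:
# pusher.add_upstream_node(trainer)
```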
add_upstream_node
add_upstream_node( upstream_node )
Experimental: Add another component that must run before this one.
This method enables task-based dependencies by enforcing execution order for synchronous pipelines on supported platforms. Currently, the supported platforms are Airflow, Beam, and Kubeflow Pipelines.
Note that this API call should be considered experimental, and may not work with asynchronous pipelines, sub-pipelines and pipelines with conditional nodes. We also recommend relying on data for capturing dependencies where possible to ensure data lineage is fully captured within MLMD.
It is symmetric with add_downstream_node.

Args | |
---|---|
upstream_node | a component that must run before this node. |
from_json_dict
@classmethod
from_json_dict( dict_data: Dict[Text, Any] ) -> Any
Convert from dictionary data to an object.
get_id
@classmethod
get_id( instance_name: Optional[Text] = None )
Gets the id of a node.
This can be used during pipeline authoring time. For example:

from tfx.components import Trainer

resolver = ResolverNode(..., model=Channel(
    type=Model, producer_component_id=Trainer.get_id('my_trainer')))
Args | |
---|---|
instance_name | (Optional) instance name of a node. If given, the instance name will be taken into consideration when generating the id. |
Returns | |
---|---|
an id for the node. |
to_json_dict
to_json_dict() -> Dict[Text, Any]
Convert from an object to a JSON serializable dictionary.
with_id
with_id( id: Text ) -> "BaseNode"
with_platform_config
with_platform_config( config: message.Message ) -> "BaseComponent"
Attaches a proto-form platform config to a component.
The config will be a per-node platform-specific config.
Args | |
---|---|
config | platform config to attach to the component. |
Returns | |
---|---|
the same component itself. |
Class Variables | |
---|---|
EXECUTOR_SPEC | tfx.dsl.components.base.executor_spec.ExecutorClassSpec |