tfx.v1.extensions.google_cloud_ai_platform.Trainer
Cloud AI Platform Trainer component.
Inherits From: Trainer, BaseComponent, BaseNode
tfx.v1.extensions.google_cloud_ai_platform.Trainer(
    examples: Optional[tfx.v1.dsl.Channel] = None,
    transformed_examples: Optional[tfx.v1.dsl.Channel] = None,
    transform_graph: Optional[tfx.v1.dsl.Channel] = None,
    schema: Optional[tfx.v1.dsl.Channel] = None,
    base_model: Optional[tfx.v1.dsl.Channel] = None,
    hyperparameters: Optional[tfx.v1.dsl.Channel] = None,
    module_file: Optional[Union[str, tfx.v1.dsl.experimental.RuntimeParameter]] = None,
    run_fn: Optional[Union[str, tfx.v1.dsl.experimental.RuntimeParameter]] = None,
    trainer_fn: Optional[Union[str, tfx.v1.dsl.experimental.RuntimeParameter]] = None,
    train_args: Optional[Union[tfx.v1.proto.TrainArgs, tfx.v1.dsl.experimental.RuntimeParameter]] = None,
    eval_args: Optional[Union[tfx.v1.proto.EvalArgs, tfx.v1.dsl.experimental.RuntimeParameter]] = None,
    custom_config: Optional[Dict[str, Any]] = None
)
Args
examples: A Channel of type standard_artifacts.Examples, serving as the source of examples used in training (required). The examples may be raw or transformed.
transformed_examples: Deprecated field. Please set examples instead.
transform_graph: An optional Channel of type standard_artifacts.TransformGraph, serving as the input transform graph if present.
schema: An optional Channel of type standard_artifacts.Schema, serving as the schema of the training and eval data. The schema is optional when (1) transform_graph is provided, since it contains the schema, or (2) the user module does not use the schema (e.g., the schema is hardcoded).
base_model: A Channel of type Model, containing a model that will be used for training. This can be used for warm-starting, transfer learning, or model ensembling.
hyperparameters: A Channel of type standard_artifacts.HyperParameters, serving as the hyperparameters for the training module. The best hyperparameters output by Tuner can be fed into this.
module_file: A path to a Python module file containing the UDF model definition. The module_file must implement a function named run_fn at its top level with the signature def run_fn(trainer.fn_args_utils.FnArgs), and the trained model must be saved to FnArgs.serving_model_dir when this function is executed. For the Estimator-based Executor, the module_file must instead implement a function named trainer_fn at its top level, with the following signature:
def trainer_fn(trainer.fn_args_utils.FnArgs, tensorflow_metadata.proto.v0.schema_pb2) -> Dict: ...
where the returned Dict has the following key-values:
'estimator': an instance of tf.estimator.Estimator
'train_spec': an instance of tf.estimator.TrainSpec
'eval_spec': an instance of tf.estimator.EvalSpec
'eval_input_receiver_fn': an instance of tfma EvalInputReceiver.
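The run_fn contract can be sketched without TensorFlow to show only the control flow: the Trainer calls a top-level run_fn with an FnArgs object, and whatever the function exports under serving_model_dir becomes the trained model. This is a minimal sketch under stated assumptions: FakeFnArgs is a hypothetical stand-in carrying only a few of the fields of trainer.fn_args_utils.FnArgs, and writing a JSON file stands in for exporting a real SavedModel (a Keras-based run_fn would typically call model.save(fn_args.serving_model_dir) instead).

```python
import json
import os
import tempfile
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class FakeFnArgs:
    """Hypothetical stand-in for trainer.fn_args_utils.FnArgs (subset of fields)."""
    train_files: List[str]
    eval_files: List[str]
    train_steps: int
    eval_steps: int
    serving_model_dir: str
    custom_config: Dict[str, Any] = field(default_factory=dict)


def run_fn(fn_args: FakeFnArgs) -> None:
    """Top-level entry point that the Trainer looks up in module_file.

    A real implementation would build and fit a model here; this sketch only
    records the training configuration it received.
    """
    summary = {
        "num_train_files": len(fn_args.train_files),
        "train_steps": fn_args.train_steps,
        "eval_steps": fn_args.eval_steps,
    }
    # The contract: the trained model must end up under serving_model_dir.
    os.makedirs(fn_args.serving_model_dir, exist_ok=True)
    with open(os.path.join(fn_args.serving_model_dir, "model.json"), "w") as f:
        json.dump(summary, f)


# Local smoke test of the contract, using hypothetical file names.
tmp_dir = tempfile.mkdtemp()
args = FakeFnArgs(
    train_files=["train-00000.gz"],
    eval_files=["eval-00000.gz"],
    train_steps=100,
    eval_steps=5,
    serving_model_dir=os.path.join(tmp_dir, "serving_model"),
)
run_fn(args)
print(sorted(os.listdir(args.serving_model_dir)))  # ['model.json']
```

Keeping run_fn free of pipeline-specific code in this way makes it possible to exercise the module locally before submitting a Cloud AI Platform job.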
run_fn: A Python path to the UDF model definition function for the generic trainer. See 'module_file' for details. Exactly one of 'module_file' or 'run_fn' must be supplied if Trainer uses the GenericExecutor (the default).
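A "Python path" here is a dotted name such as my_project.trainer.run_fn (a hypothetical path). The sketch below shows one way such a path can be resolved to a callable with importlib, together with the exactly-one-of rule for module_file and run_fn. Both exactly_one_set and resolve_fn are hypothetical helpers written for illustration; TFX performs its own validation and loading internally, and this is not its implementation. The example resolves json.dumps only because it is importable everywhere.

```python
import importlib
from typing import Callable, Optional


def exactly_one_set(module_file: Optional[str], run_fn: Optional[str]) -> bool:
    """True when exactly one of module_file / run_fn is supplied (hypothetical helper)."""
    return bool(module_file) != bool(run_fn)


def resolve_fn(python_path: str) -> Callable:
    """Split 'pkg.module.fn' into module path and attribute name, then import it."""
    module_path, _, fn_name = python_path.rpartition(".")
    if not module_path or not fn_name:
        raise ValueError(f"Expected a dotted path like 'pkg.module.fn', got {python_path!r}")
    module = importlib.import_module(module_path)
    return getattr(module, fn_name)


module_file, run_fn_path = None, "json.dumps"
if not exactly_one_set(module_file, run_fn_path):
    raise ValueError("Exactly one of 'module_file' or 'run_fn' must be supplied.")
fn = resolve_fn(run_fn_path)
print(fn([1, 2]))  # [1, 2]
```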
trainer_fn: A Python path to the UDF model definition function for the Estimator-based trainer. See 'module_file' for the required signature of the UDF. Exactly one of 'module_file' or 'trainer_fn' must be supplied if Trainer uses the Estimator-based Executor.
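For the Estimator-based path, the dictionary returned by trainer_fn must carry the four keys listed under module_file. A quick self-check such as the one sketched here can catch a missing key before a job is submitted. missing_trainer_fn_keys is a hypothetical helper, and the object() placeholders stand in for the real tf.estimator and TFMA objects so that the sketch runs without TensorFlow.

```python
from typing import Any, Dict, List

# Keys the Estimator-based Trainer expects in trainer_fn's return value.
REQUIRED_KEYS = ("estimator", "train_spec", "eval_spec", "eval_input_receiver_fn")


def missing_trainer_fn_keys(result: Dict[str, Any]) -> List[str]:
    """Return the required keys absent from a trainer_fn result, in declaration order."""
    return [key for key in REQUIRED_KEYS if key not in result]


# Placeholders stand in for tf.estimator.Estimator, TrainSpec, and EvalSpec;
# the TFMA eval_input_receiver_fn has deliberately been left out.
partial_result = {"estimator": object(), "train_spec": object(), "eval_spec": object()}
print(missing_trainer_fn_keys(partial_result))  # ['eval_input_receiver_fn']
```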
train_args: A proto.TrainArgs instance containing args used for training. Currently only splits and num_steps are available. The default behavior (when splits is empty) is to train on the train split.
eval_args: A proto.EvalArgs instance containing args used for evaluation. Currently only splits and num_steps are available. The default behavior (when splits is empty) is to evaluate on the eval split.
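The split defaults for train_args and eval_args reduce to a simple rule: an empty splits list falls back to a single default split. The pure-Python helper below, effective_splits, is hypothetical and only mirrors the documented behavior. In a pipeline, the arguments themselves would be protos such as tfx.v1.proto.TrainArgs(num_steps=100) and tfx.v1.proto.EvalArgs(num_steps=5), with splits left empty to accept the defaults.

```python
from typing import List, Sequence


def effective_splits(splits: Sequence[str], default_split: str) -> List[str]:
    """Mirror the documented default: an empty splits list means [default_split]."""
    return list(splits) if splits else [default_split]


print(effective_splits([], "train"))                  # ['train']
print(effective_splits([], "eval"))                   # ['eval']
print(effective_splits(["train", "extra"], "train"))  # ['train', 'extra']
```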
custom_config: A dict containing additional training job parameters that will be passed into the user module.
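Inside the user module, these values are available as fn_args.custom_config. The sketch below shows how a run_fn might read its own options from that dict with explicit defaults. The keys use_gpu and learning_rate and the helper read_training_options are hypothetical, chosen for illustration only. The comment about cloud job settings reflects our understanding of the Cloud AI Platform extension and should be checked against the installed TFX version.

```python
from typing import Any, Dict, Mapping

# Hypothetical user-defined keys passed to the Trainer as custom_config.
custom_config: Dict[str, Any] = {
    "use_gpu": False,
    "learning_rate": 1e-3,
    # For Cloud AI Platform / Vertex AI runs, the same dict is also expected to
    # carry the job settings, keyed by constants such as
    # tfx.v1.extensions.google_cloud_ai_platform.TRAINING_ARGS_KEY.
}


def read_training_options(config: Mapping[str, Any]) -> Dict[str, Any]:
    """What a run_fn might do with fn_args.custom_config: read keys with defaults."""
    return {
        "use_gpu": bool(config.get("use_gpu", False)),
        "learning_rate": float(config.get("learning_rate", 1e-2)),
    }


print(read_training_options(custom_config))  # {'use_gpu': False, 'learning_rate': 0.001}
```

Reading every key with a default keeps the module usable when it is run locally with an empty custom_config.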
Attributes
outputs: Component's output channel dict.
Methods
with_node_execution_options
with_node_execution_options(
node_execution_options: utils.NodeExecutionOptions
) -> typing_extensions.Self
Class Variables
POST_EXECUTABLE_SPEC: None
PRE_EXECUTABLE_SPEC: None
|
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2024-05-03 UTC.