Die Pipeline-Komponente Trainer TFX trainiert ein TensorFlow-Modell.
Trainer und TensorFlow
Der Trainer nutzt die Python TensorFlow- API in großem Umfang für Trainingsmodelle.
Komponente
Trainer nimmt:
- Beispiele für Schulung und Bewertung.
- Eine vom Benutzer bereitgestellte Moduldatei, die die Trainerlogik definiert.
- Protobuf- Definition von Zugargumenten und Bewertungsargumenten.
- (Optional) Ein Datenschema, das von einer SchemaGen-Pipelinekomponente erstellt und optional vom Entwickler geändert wurde.
- (Optionaler) Transformationsgraph, der von einer vorgeschalteten Transformationskomponente erzeugt wird.
- (Optional) vorab trainierte Modelle für Szenarien wie Warmstart.
- (Optionale) Hyperparameter, die an die Funktion des Benutzermoduls übergeben werden. Details zur Integration mit Tuner finden Sie hier .
Der Trainer gibt Folgendes aus: Mindestens ein Modell für Inferenz / Serving (normalerweise in SavedModelFormat) und optional ein anderes Modell für eval (normalerweise ein EvalSavedModel).
Wir bieten Unterstützung für alternative Modellformate wie TFLite über die Model Rewriting Library . Beispiele zum Konvertieren von Estimator- und Keras-Modellen finden Sie unter dem Link zur Model Rewriting Library.
Schätzer-basierter Trainer
Informationen zur Verwendung eines Estimator- basierten Modells mit TFX und Trainer finden Sie unter Entwerfen von TensorFlow-Modellierungscode mit tf.Estimator für TFX .
Konfigurieren einer Trainerkomponente
Der typische Python-DSL-Pipeline-Code sieht folgendermaßen aus:
from tfx.components import Trainer
...
trainer = Trainer(
module_file=module_file,
examples=transform.outputs['transformed_examples'],
schema=infer_schema.outputs['schema'],
base_model=latest_model_resolver.outputs['latest_model'],
transform_graph=transform.outputs['transform_graph'],
train_args=trainer_pb2.TrainArgs(num_steps=10000),
eval_args=trainer_pb2.EvalArgs(num_steps=5000))
Der Trainer ruft ein Trainingsmodul auf, das im Parameter module_file
angegeben ist. Ein typisches Trainingsmodul sieht folgendermaßen aus:
# TFX will call this function
def trainer_fn(trainer_fn_args, schema):
"""Build the estimator using the high level API.
Args:
trainer_fn_args: Holds args used to train the model as name/value pairs.
schema: Holds the schema of the training examples.
Returns:
A dict of the following:
- estimator: The estimator that will be used for training and eval.
- train_spec: Spec for training.
- eval_spec: Spec for eval.
- eval_input_receiver_fn: Input function for eval.
"""
# Number of nodes in the first layer of the DNN
first_dnn_layer_size = 100
num_dnn_layers = 4
dnn_decay_factor = 0.7
train_batch_size = 40
eval_batch_size = 40
tf_transform_output = tft.TFTransformOutput(trainer_fn_args.transform_output)
train_input_fn = lambda: _input_fn( # pylint: disable=g-long-lambda
trainer_fn_args.train_files,
tf_transform_output,
batch_size=train_batch_size)
eval_input_fn = lambda: _input_fn( # pylint: disable=g-long-lambda
trainer_fn_args.eval_files,
tf_transform_output,
batch_size=eval_batch_size)
train_spec = tf.estimator.TrainSpec( # pylint: disable=g-long-lambda
train_input_fn,
max_steps=trainer_fn_args.train_steps)
serving_receiver_fn = lambda: _example_serving_receiver_fn( # pylint: disable=g-long-lambda
tf_transform_output, schema)
exporter = tf.estimator.FinalExporter('chicago-taxi', serving_receiver_fn)
eval_spec = tf.estimator.EvalSpec(
eval_input_fn,
steps=trainer_fn_args.eval_steps,
exporters=[exporter],
name='chicago-taxi-eval')
run_config = tf.estimator.RunConfig(
save_checkpoints_steps=999, keep_checkpoint_max=1)
run_config = run_config.replace(model_dir=trainer_fn_args.serving_model_dir)
warm_start_from = trainer_fn_args.base_model[
0] if trainer_fn_args.base_model else None
estimator = _build_estimator(
# Construct layers sizes with exponetial decay
hidden_units=[
max(2, int(first_dnn_layer_size * dnn_decay_factor**i))
for i in range(num_dnn_layers)
],
config=run_config,
warm_start_from=warm_start_from)
# Create an input receiver for TFMA processing
receiver_fn = lambda: _eval_input_receiver_fn( # pylint: disable=g-long-lambda
tf_transform_output, schema)
return {
'estimator': estimator,
'train_spec': train_spec,
'eval_spec': eval_spec,
'eval_input_receiver_fn': receiver_fn
}
Generischer Trainer
Mit dem generischen Trainer können Entwickler jede TensorFlow-Modell-API mit der Trainer-Komponente verwenden. Zusätzlich zu TensorFlow Estimators können Entwickler Keras-Modelle oder benutzerdefinierte Trainingsschleifen verwenden. Einzelheiten finden Sie im RFC für generische Trainer .
Konfigurieren der Trainerkomponente für die Verwendung des GenericExecutor
Der typische Pipeline-DSL-Code für den generischen Trainer sieht folgendermaßen aus:
from tfx.components import Trainer
from tfx.dsl.components.base import executor_spec
from tfx.components.trainer.executor import GenericExecutor
...
trainer = Trainer(
module_file=module_file,
custom_executor_spec=executor_spec.ExecutorClassSpec(GenericExecutor),
examples=transform.outputs['transformed_examples'],
transform_graph=transform.outputs['transform_graph'],
train_args=trainer_pb2.TrainArgs(num_steps=10000),
eval_args=trainer_pb2.EvalArgs(num_steps=5000))
Der Trainer ruft ein Trainingsmodul auf, das im Parameter module_file
angegeben ist. Anstelle von trainer_fn
ist in der Moduldatei ein run_fn
erforderlich, wenn der GenericExecutor
in custom_executor_spec
. Der trainer_fn
war für die Erstellung des Modells verantwortlich. Darüber hinaus muss run_fn
auch den Trainingsteil bearbeiten und das trainierte Modell an den von FnArgs angegebenen gewünschten Ort ausgeben :
from tfx.components.trainer.fn_args_utils import FnArgs
def run_fn(fn_args: FnArgs) -> None:
"""Build the TF model and train it."""
model = _build_keras_model()
model.fit(...)
# Save model to fn_args.serving_model_dir.
model.save(fn_args.serving_model_dir, ...)
Hier ist eine Beispielmoduldatei mit run_fn
.
Beachten Sie, dass der Trainer die Beispiele direkt aus ExampleGen übernimmt, wenn die Transformationskomponente nicht in der Pipeline verwendet wird:
trainer = Trainer(
module_file=module_file,
custom_executor_spec=executor_spec.ExecutorClassSpec(GenericExecutor),
examples=example_gen.outputs['examples'],
schema=infer_schema.outputs['schema'],
train_args=trainer_pb2.TrainArgs(num_steps=10000),
eval_args=trainer_pb2.EvalArgs(num_steps=5000))