Treten Sie der SIG TFX-Addons-Community bei und helfen Sie mit, TFX noch besser zu machen! SIG TFX-Addons beitreten

Die Trainer TFX Pipeline-Komponente

Die Trainer TFX-Pipelinekomponente trainiert ein TensorFlow-Modell.

Trainer und TensorFlow

Trainer nutzt umfassend die Python TensorFlow- API zum Trainieren von Modellen.

Komponente

Trainer nimmt:

  • tf.Beispiele für Training und Evaluation.
  • Eine vom Benutzer bereitgestellte Moduldatei, die die Trainerlogik definiert.
  • Protobuf- Definition von Zugargumenten und Auswertungsargumenten.
  • (Optional) Ein Datenschema, das von einer SchemaGen-Pipelinekomponente erstellt und optional vom Entwickler geändert wird.
  • (Optional) Transformationsdiagramm, das von einer vorgelagerten Transformationskomponente erstellt wurde.
  • (Optional) vortrainierte Modelle, die für Szenarien wie Warmstart verwendet werden.
  • (Optional) Hyperparameter, die an die Benutzermodulfunktion übergeben werden. Details zur Integration mit Tuner finden Sie hier .

Der Trainer gibt aus: Mindestens ein Modell für Inferenz/Serving (normalerweise im SavedModelFormat) und optional ein weiteres Modell für die Bewertung (normalerweise ein EvalSavedModel).

Wir bieten Unterstützung für alternative Modellformate wie TFLite über die Model Rewriting Library . Unter dem Link zur Model Rewriting Library finden Sie Beispiele zum Konvertieren von Estimator- und Keras-Modellen.

Generische Trainer

Der generische Trainer ermöglicht Entwicklern die Verwendung einer beliebigen TensorFlow-Modell-API mit der Trainer-Komponente. Zusätzlich zu TensorFlow Estimators können Entwickler Keras-Modelle oder benutzerdefinierte Trainingsschleifen verwenden. Weitere Informationen finden Sie im RFC für generische Trainer .

Konfiguration der Trainerkomponente

Ein typischer Pipeline-DSL-Code für den generischen Trainer würde so aussehen:

from tfx.components import Trainer

...

trainer = Trainer(
    module_file=module_file,
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Trainer ruft ein Trainingsmodul auf, das im Parameter module_file angegeben ist. Anstelle von trainer_fn ist in der Moduldatei ein run_fn erforderlich, wenn der GenericExecutor in der custom_executor_spec . Der trainer_fn war für die Erstellung des Modells verantwortlich. Darüber hinaus muss run_fn auch den Trainingsteil verarbeiten und das trainierte Modell an den von FnArgs angegebenen gewünschten Ort ausgeben :

from tfx.components.trainer.fn_args_utils import FnArgs

def run_fn(fn_args: FnArgs) -> None:
  """Build the TF model and train it."""
  model = _build_keras_model()
  model.fit(...)
  # Save model to fn_args.serving_model_dir.
  model.save(fn_args.serving_model_dir, ...)

Hier ist eineBeispiel-Moduldatei mit run_fn .

Beachten Sie, dass der Trainer die Beispiele direkt aus ExampleGen übernimmt, wenn die Transform-Komponente in der Pipeline nicht verwendet wird:

trainer = Trainer(
    module_file=module_file,
    examples=example_gen.outputs['examples'],
    schema=infer_schema.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Weitere Details finden Sie in der Trainer-API-Referenz .