Le composant du pipeline Trainer TFX

Le composant de pipeline Trainer TFX entraîne un modèle TensorFlow.

Entraîneur et TensorFlow

Trainer utilise largement l'API Python TensorFlow pour les modèles de formation.

Composant

Le formateur prend :

  • tf.Exemples utilisés pour la formation et l'évaluation.
  • Un fichier de module fourni par l'utilisateur qui définit la logique du formateur.
  • Définition Protobuf des arguments de train et des arguments d'évaluation.
  • (Facultatif) Un schéma de données créé par un composant de pipeline SchemaGen et éventuellement modifié par le développeur.
  • (Facultatif) graphique de transformation produit par un composant Transform en amont.
  • (Facultatif) modèles pré-entraînés utilisés pour des scénarios tels que le démarrage à chaud.
  • Hyperparamètres (facultatifs), qui seront transmis à la fonction du module utilisateur. Les détails de l'intégration avec Tuner peuvent être trouvés ici .

Le formateur émet : au moins un modèle pour l'inférence/le service (généralement dans SavedModelFormat) et éventuellement un autre modèle pour l'évaluation (généralement un EvalSavedModel).

Nous prenons en charge d'autres formats de modèles tels que TFLite via la bibliothèque de réécriture de modèles . Consultez le lien vers la bibliothèque de réécriture de modèles pour obtenir des exemples de conversion des modèles Estimator et Keras.

Entraîneur générique

L'entraîneur générique permet aux développeurs d'utiliser n'importe quelle API de modèle TensorFlow avec le composant Trainer. En plus des estimateurs TensorFlow, les développeurs peuvent utiliser des modèles Keras ou des boucles de formation personnalisées. Pour plus de détails, veuillez consulter la RFC pour l'entraîneur générique .

Configuration du composant formateur

Le code DSL de pipeline typique pour le formateur générique ressemblerait à ceci :

from tfx.components import Trainer

...

trainer = Trainer(
    module_file=module_file,
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Trainer appelle un module de formation, spécifié dans le paramètre module_file . Au lieu de trainer_fn , un run_fn est requis dans le fichier de module si GenericExecutor est spécifié dans custom_executor_spec . Le trainer_fn était responsable de la création du modèle. En plus de cela, run_fn doit également gérer la partie formation et afficher le modèle formé à l'emplacement souhaité indiqué par FnArgs :

from tfx.components.trainer.fn_args_utils import FnArgs

def run_fn(fn_args: FnArgs) -> None:
  """Build the TF model and train it."""
  model = _build_keras_model()
  model.fit(...)
  # Save model to fn_args.serving_model_dir.
  model.save(fn_args.serving_model_dir, ...)

Voici un exemple de fichier de module avec run_fn .

Notez que si le composant Transform n'est pas utilisé dans le pipeline, alors le formateur prendra directement les exemples de SampleGen :

trainer = Trainer(
    module_file=module_file,
    examples=example_gen.outputs['examples'],
    schema=infer_schema.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Plus de détails sont disponibles dans la référence de l'API Trainer .