Rejoignez la communauté SIG TFX-Addons et contribuez à rendre TFX encore meilleur ! Rejoignez SIG TFX-Addons

Le Composant Pipeline TFX Trainer

Le composant de pipeline Trainer TFX entraîne un modèle TensorFlow.

Formateur et TensorFlow

Trainer utilise largement l'API Python TensorFlow pour entraîner les modèles.

Composant

Le formateur prend :

  • tf.Exemples utilisés pour la formation et l'évaluation.
  • Un fichier de module fourni par l'utilisateur qui définit la logique du formateur.
  • Définition Protobuf des arguments de train et des arguments d'évaluation .
  • (Facultatif) Un schéma de données créé par un composant de pipeline SchemaGen et éventuellement modifié par le développeur.
  • (Facultatif) graphe de transformation produit par un composant Transform en amont.
  • (Facultatif) modèles pré-entraînés utilisés pour des scénarios tels que le démarrage à chaud.
  • (Facultatif) hyperparamètres, qui seront transmis à la fonction du module utilisateur. Les détails de l'intégration avec Tuner peuvent être trouvés ici .

Le formateur émet : au moins un modèle pour l'inférence/la diffusion (généralement dans SavedModelFormat) et éventuellement un autre modèle pour eval (généralement un EvalSavedModel).

Nous prenons en charge les formats de modèles alternatifs tels que TFLite via la bibliothèque de réécriture de modèles . Voir le lien vers la bibliothèque de réécriture de modèles pour des exemples de conversion des modèles Estimator et Keras.

Formateur générique

L'entraîneur générique permet aux développeurs d'utiliser n'importe quelle API de modèle TensorFlow avec le composant Trainer. En plus des estimateurs TensorFlow, les développeurs peuvent utiliser des modèles Keras ou des boucles d'entraînement personnalisées. Pour plus de détails, veuillez consulter la RFC pour le formateur générique .

Configuration du composant Trainer

Le code DSL de pipeline typique pour le formateur générique ressemblerait à ceci :

from tfx.components import Trainer

...

trainer = Trainer(
    module_file=module_file,
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Trainer appelle un module de formation, qui est spécifié dans le paramètre module_file . Au lieu de trainer_fn , un run_fn est requis dans le fichier du module si GenericExecutor est spécifié dans custom_executor_spec . Le trainer_fn était responsable de la création du modèle. En plus de cela, run_fn doit également gérer la partie formation et générer le modèle formé à l'emplacement souhaité donné par FnArgs :

from tfx.components.trainer.fn_args_utils import FnArgs

def run_fn(fn_args: FnArgs) -> None:
  """Build the TF model and train it."""
  model = _build_keras_model()
  model.fit(...)
  # Save model to fn_args.serving_model_dir.
  model.save(fn_args.serving_model_dir, ...)

Voici unexemple de fichier de module avec run_fn .

Notez que si le composant Transform n'est pas utilisé dans le pipeline, le formateur prendra directement les exemples de ExampleGen :

trainer = Trainer(
    module_file=module_file,
    examples=example_gen.outputs['examples'],
    schema=infer_schema.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Plus de détails sont disponibles dans la référence de l'API Trainer .