Le composant Pipeline TFX Trainer

Le composant de pipeline Trainer TFX entraîne un modèle TensorFlow.

Formateur et TensorFlow

Formateur fait un usage intensif de la Python tensorflow API pour les modèles de formation.

Composant

Le formateur prend :

  • tf.Exemples utilisés pour la formation et l'évaluation.
  • Un fichier de module fourni par l'utilisateur qui définit la logique du formateur.
  • Protobuf définition des args de train et args eval.
  • (Facultatif) Un schéma de données créé par un composant de pipeline SchemaGen et éventuellement modifié par le développeur.
  • (Facultatif) graphe de transformation produit par un composant Transform en amont.
  • (Facultatif) modèles pré-entraînés utilisés pour des scénarios tels que le démarrage à chaud.
  • (Facultatif) hyperparamètres, qui seront transmis à la fonction du module utilisateur. Les détails de l'intégration avec Tuner se trouvent ici .

Le formateur émet : au moins un modèle pour l'inférence/la diffusion (généralement dans SavedModelFormat) et éventuellement un autre modèle pour eval (généralement un EvalSavedModel).

Nous fournissons le support des formats de modèles alternatifs tels que TFLite à travers le modèle Réécriture Library . Voir le lien vers la bibliothèque de réécriture de modèles pour des exemples de conversion des modèles Estimator et Keras.

Formateur générique

L'entraîneur générique permet aux développeurs d'utiliser n'importe quelle API de modèle TensorFlow avec le composant Trainer. En plus des estimateurs TensorFlow, les développeurs peuvent utiliser des modèles Keras ou des boucles d'entraînement personnalisées. Pour plus d' informations, s'il vous plaît voir le RFC pour l' entraîneur générique .

Configuration du composant Trainer

Le code DSL de pipeline typique pour le formateur générique ressemblerait à ceci :

from tfx.components import Trainer

...

trainer = Trainer(
    module_file=module_file,
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Formateur appelle un module de formation, qui est spécifié dans le module_file paramètre. Au lieu de trainer_fn , un run_fn est nécessaire dans le fichier du module si le GenericExecutor est spécifié dans le custom_executor_spec . Le trainer_fn était responsable de la création du modèle. En plus de cela, run_fn doit également gérer la partie de la formation et de la production du modèle formé à l'endroit désiré donné par FnArgs :

from tfx.components.trainer.fn_args_utils import FnArgs

def run_fn(fn_args: FnArgs) -> None:
  """Build the TF model and train it."""
  model = _build_keras_model()
  model.fit(...)
  # Save model to fn_args.serving_model_dir.
  model.save(fn_args.serving_model_dir, ...)

Voici un exemple de run_fn fichier de module avec run_fn .

Notez que si le composant Transform n'est pas utilisé dans le pipeline, le formateur prendra directement les exemples de ExampleGen :

trainer = Trainer(
    module_file=module_file,
    examples=example_gen.outputs['examples'],
    schema=infer_schema.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Plus de détails sont disponibles dans la référence API formateur .