TFX pour mobile

Introduction

Ce guide montre comment Tensorflow Extended (TFX) peut créer et évaluer des modèles d'apprentissage automatique qui seront déployés sur l'appareil. TFX fournit désormais une prise en charge native de TFLite , ce qui permet d'effectuer une inférence très efficace sur les appareils mobiles.

Ce guide vous guide à travers les modifications qui peuvent être apportées à n'importe quel pipeline pour générer et évaluer des modèles TFLite. Nous fournissons ici un exemple complet, démontrant comment TFX peut former et évaluer des modèles TFLite formés à partir de l'ensemble de données MNIST . De plus, nous montrons comment le même pipeline peut être utilisé pour exporter simultanément à la fois le SavedModel standard basé sur Keras et celui de TFLite, permettant aux utilisateurs de comparer la qualité des deux.

Nous supposons que vous connaissez TFX, nos composants et nos pipelines. Si ce n'est pas le cas, veuillez consulter ce tutoriel .

Pas

Seules deux étapes sont nécessaires pour créer et évaluer un modèle TFLite dans TFX. La première étape consiste à invoquer le réécrivain TFLite dans le contexte de TFX Trainer pour convertir le modèle TensorFlow formé en modèle TFLite. La deuxième étape consiste à configurer l'évaluateur pour évaluer les modèles TFLite. Nous discutons maintenant de chacun tour à tour.

Invocation du réécrivain TFLite dans le formateur.

Le TFX Trainer s'attend à ce qu'un run_fn défini par l'utilisateur soit spécifié dans un fichier de module. Ce run_fn définit le modèle à entraîner, l'entraîne pour le nombre d'itérations spécifié et exporte le modèle entraîné.

Dans le reste de cette section, nous fournissons des extraits de code qui montrent les modifications requises pour appeler le réécrivain TFLite et exporter un modèle TFLite. Tout ce code se trouve dans le run_fn du module MNIST TFLite .

Comme le montre le code ci-dessous, nous devons d'abord créer une signature qui prend en entrée un Tensor pour chaque fonctionnalité. Notez qu’il s’agit d’une différence par rapport à la plupart des modèles existants dans TFX, qui prennent des protos tf.Example sérialisés en entrée.

 signatures = {
      'serving_default':
          _get_serve_tf_examples_fn(
              model, tf_transform_output).get_concrete_function(
                  tf.TensorSpec(
                      shape=[None, 784],
                      dtype=tf.float32,
                      name='image_floats'))
  }

Ensuite, le modèle Keras est enregistré en tant que SavedModel de la même manière qu'il l'est normalement.

  temp_saving_model_dir = os.path.join(fn_args.serving_model_dir, 'temp')
  model.save(temp_saving_model_dir, save_format='tf', signatures=signatures)

Enfin, nous créons une instance du réécrivain TFLite ( tfrw ) et l'invoquons sur le SavedModel pour obtenir le modèle TFLite. Nous stockons ce modèle TFLite dans le serving_model_dir fourni par l'appelant du run_fn . De cette façon, le modèle TFLite est stocké à l'emplacement où tous les composants TFX en aval s'attendront à trouver le modèle.

  tfrw = rewriter_factory.create_rewriter(
      rewriter_factory.TFLITE_REWRITER, name='tflite_rewriter')
  converters.rewrite_saved_model(temp_saving_model_dir,
                                 fn_args.serving_model_dir,
                                 tfrw,
                                 rewriter.ModelType.TFLITE_MODEL)

Évaluation du modèle TFLite.

L' évaluateur TFX offre la possibilité d'analyser les modèles formés pour comprendre leur qualité sur un large éventail de mesures. En plus d'analyser les SavedModels, TFX Evaluator est désormais également capable d'analyser les modèles TFLite.

L'extrait de code suivant (reproduit à partir du pipeline MNIST ) montre comment configurer un évaluateur qui analyse un modèle TFLite.

  # Informs the evaluator that the model is a TFLite model.
  eval_config_lite.model_specs[0].model_type = 'tf_lite'

  ...

  # Uses TFMA to compute the evaluation statistics over features of a TFLite
  # model.
  model_analyzer_lite = Evaluator(
      examples=example_gen.outputs['examples'],
      model=trainer_lite.outputs['model'],
      eval_config=eval_config_lite,
  ).with_id('mnist_lite')

Comme indiqué ci-dessus, le seul changement que nous devons apporter est de définir le champ model_type sur tf_lite . Aucune autre modification de configuration n’est requise pour analyser le modèle TFLite. Indépendamment du fait qu'un modèle TFLite ou un SavedModel soit analysé, la sortie de l' Evaluator aura exactement la même structure.

Cependant, veuillez noter que l'évaluateur suppose que le modèle TFLite est enregistré dans un fichier nommé tflite dans trainer_lite.outputs['model'].