TFX per dispositivi mobili

introduzione

Questa guida mostra come Tensorflow Extended (TFX) può creare e valutare modelli di machine learning che verranno distribuiti sul dispositivo. TFX ora fornisce supporto nativo per TFLite , che consente di eseguire inferenze altamente efficienti sui dispositivi mobili.

Questa guida illustra le modifiche che possono essere apportate a qualsiasi pipeline per generare e valutare i modelli TFLite. Forniamo un esempio completo qui , dimostrando come TFX può addestrare e valutare i modelli TFLite che vengono addestrati dal set di dati MNIST . Inoltre, mostriamo come la stessa pipeline può essere utilizzata per esportare simultaneamente sia il SavedModel standard basato su Keras sia quello TFLite, consentendo agli utenti di confrontare la qualità dei due.

Presumiamo che tu abbia familiarità con TFX, i nostri componenti e le nostre pipeline. In caso contrario, consulta questo tutorial .

Passi

Sono necessari solo due passaggi per creare e valutare un modello TFLite in TFX. Il primo passo è invocare il rewriter TFLite nel contesto di TFX Trainer per convertire il modello TensorFlow addestrato in uno TFLite. Il secondo passaggio è la configurazione dell'Evaluator per valutare i modelli TFLite. Ora discutiamo ciascuno a turno.

Invocare il rewriter di TFLite all'interno del Trainer.

TFX Trainer si aspetta che un run_fn definito dall'utente venga specificato in un file di modulo. Questo run_fn definisce il modello da addestrare, lo addestra per il numero di iterazioni specificato ed esporta il modello addestrato.

Nella parte restante di questa sezione vengono forniti frammenti di codice che mostrano le modifiche necessarie per richiamare il riscrittore TFLite ed esportare un modello TFLite. Tutto questo codice si trova nel run_fn del modulo MNIST TFLite .

Come mostrato nel codice seguente, dobbiamo prima creare una firma che prenda come input un Tensor per ogni caratteristica. Si noti che questa è una deviazione dalla maggior parte dei modelli esistenti in TFX, che accettano i prototipi serializzati di tf.Example come input.

 signatures = {
      'serving_default':
          _get_serve_tf_examples_fn(
              model, tf_transform_output).get_concrete_function(
                  tf.TensorSpec(
                      shape=[None, 784],
                      dtype=tf.float32,
                      name='image_floats'))
  }

Quindi il modello Keras viene salvato come SavedModel nello stesso modo in cui lo è normalmente.

  temp_saving_model_dir = os.path.join(fn_args.serving_model_dir, 'temp')
  model.save(temp_saving_model_dir, save_format='tf', signatures=signatures)

Infine, creiamo un'istanza del rewriter TFLite ( tfrw ) e la invochiamo su SavedModel per ottenere il modello TFLite. Memorizziamo questo modello TFLite nella serving_model_dir fornita dal chiamante di run_fn . In questo modo, il modello TFLite viene archiviato nella posizione in cui tutti i componenti TFX a valle si aspettano di trovare il modello.

  tfrw = rewriter_factory.create_rewriter(
      rewriter_factory.TFLITE_REWRITER, name='tflite_rewriter')
  converters.rewrite_saved_model(temp_saving_model_dir,
                                 fn_args.serving_model_dir,
                                 tfrw,
                                 rewriter.ModelType.TFLITE_MODEL)

Valutazione del modello TFLite.

Il valutatore TFX offre la possibilità di analizzare modelli addestrati per comprenderne la qualità in un'ampia gamma di metriche. Oltre ad analizzare i modelli salvati, TFX Evaluator è ora in grado di analizzare anche i modelli TFLite.

Il frammento di codice seguente (riprodotto dalla pipeline MNIST ), mostra come configurare un valutatore che analizzi un modello TFLite.

  # Informs the evaluator that the model is a TFLite model.
  eval_config_lite.model_specs[0].model_type = 'tf_lite'

  ...

  # Uses TFMA to compute the evaluation statistics over features of a TFLite
  # model.
  model_analyzer_lite = Evaluator(
      examples=example_gen.outputs['examples'],
      model=trainer_lite.outputs['model'],
      eval_config=eval_config_lite,
  ).with_id('mnist_lite')

Come mostrato sopra, l'unica modifica che dobbiamo apportare è impostare il campo model_type su tf_lite . Non sono necessarie altre modifiche alla configurazione per analizzare il modello TFLite. Indipendentemente dal fatto che venga analizzato un modello TFLite o un modello Evaluator , l'output dell'Evaluator avrà esattamente la stessa struttura.

Tuttavia, tieni presente che il valutatore presuppone che il modello TFLite sia salvato in un file chiamato tflite all'interno di trainer_lite.outputs['model'].