モバイル向け TFX

導入

このガイドでは、Tensorflow Extended (TFX) がデバイス上にデプロイされる機械学習モデルを作成および評価する方法を説明します。 TFX はTFLiteのネイティブ サポートを提供するようになり、モバイル デバイス上で高効率の推論を実行できるようになります。

このガイドでは、TFLite モデルを生成および評価するためにパイプラインに加えることができる変更について説明します。ここで完全な例を提供し、TFX がMNISTデータセットからトレーニングされた TFLite モデルをどのようにトレーニングおよび評価できるかを示します。さらに、同じパイプラインを使用して標準の Keras ベースのSavedModelと TFLite の両方を同時にエクスポートし、ユーザーが 2 つの品質を比較できるようにする方法を示します。

読者は TFX、コンポーネント、パイプラインに精通していることを前提としています。そうでない場合は、このチュートリアルを参照してください。

ステップ

TFX で TFLite モデルを作成して評価するために必要な手順は 2 つだけです。最初のステップは、 TFX Trainerのコンテキスト内で TFLite リライターを呼び出して、トレーニングされた TensorFlow モデルを TFLite モデルに変換することです。 2 番目のステップは、TFLite モデルを評価するようにエバリュエーターを構成することです。それでは、それぞれについて順番に説明していきます。

Trainer 内で TFLite リライターを呼び出します。

TFX トレーナーは、ユーザー定義のrun_fnがモジュール ファイルで指定されることを期待します。このrun_fn 、トレーニングするモデルを定義し、指定された反復回数でトレーニングし、トレーニングされたモデルをエクスポートします。

このセクションの残りの部分では、TFLite リライターを呼び出して TFLite モデルをエクスポートするために必要な変更を示すコード スニペットを提供します。このコードはすべて、 MNIST TFLite モジュールrun_fnにあります。

以下のコードに示すように、最初にすべての特徴のTensor入力として受け取る署名を作成する必要があります。これは、シリアル化されたtf.Exampleプロトを入力として受け取る TFX の既存のほとんどのモデルとは異なることに注意してください。

 signatures = {
      'serving_default':
          _get_serve_tf_examples_fn(
              model, tf_transform_output).get_concrete_function(
                  tf.TensorSpec(
                      shape=[None, 784],
                      dtype=tf.float32,
                      name='image_floats'))
  }

次に、Keras モデルは通常と同じ方法で SavedModel として保存されます。

  temp_saving_model_dir = os.path.join(fn_args.serving_model_dir, 'temp')
  model.save(temp_saving_model_dir, save_format='tf', signatures=signatures)

最後に、TFLite リライターのインスタンス ( tfrw ) を作成し、それを SavedModel 上で呼び出して TFLite モデルを取得します。この TFLite モデルは、 run_fnの呼び出し元によって提供されるserving_model_dirに保存されます。このようにして、TFLite モデルは、すべてのダウンストリーム TFX コンポーネントがモデルを見つけることを期待する場所に保存されます。

  tfrw = rewriter_factory.create_rewriter(
      rewriter_factory.TFLITE_REWRITER, name='tflite_rewriter')
  converters.rewrite_saved_model(temp_saving_model_dir,
                                 fn_args.serving_model_dir,
                                 tfrw,
                                 rewriter.ModelType.TFLITE_MODEL)

TFLite モデルの評価。

TFX Evaluator は、トレーニングされたモデルを分析して、幅広いメトリクスにわたる品質を理解する機能を提供します。 TFX Evaluator は、SavedModel の分析に加えて、TFLite モデルも分析できるようになりました。

次のコード スニペット ( MNIST パイプラインから再現) は、TFLite モデルを分析するエバリュエーターを構成する方法を示しています。

  # Informs the evaluator that the model is a TFLite model.
  eval_config_lite.model_specs[0].model_type = 'tf_lite'

  ...

  # Uses TFMA to compute the evaluation statistics over features of a TFLite
  # model.
  model_analyzer_lite = Evaluator(
      examples=example_gen.outputs['examples'],
      model=trainer_lite.outputs['model'],
      eval_config=eval_config_lite,
  ).with_id('mnist_lite')

上に示したように、必要な変更は、 model_typeフィールドをtf_liteに設定することだけです。 TFLite モデルを分析するために他の構成を変更する必要はありません。 TFLite モデルまたは SavedModel のどちらが分析されるかに関係なく、 Evaluatorの出力はまったく同じ構造になります。

ただし、評価者は、TFLite モデルが Trainer_lite.outputs['model'] 内のtfliteという名前のファイルに保存されていることを前提としていることに注意してください。