TFX for Mobile

序章

このガイドでは、Tensorflow Extended(TFX)がデバイスにデプロイされる機械学習モデルを作成および評価する方法を示します。 TFXは、 TFLiteのネイティブサポートを提供するようになりました。これにより、モバイルデバイスで非常に効率的な推論を実行できます。

このガイドでは、TFLiteモデルを生成および評価するためにパイプラインに加えることができる変更について説明します。ここでは完全な例を示し、TFXがMNISTデータセットからトレーニングされたTFLiteモデルをトレーニングおよび評価する方法を示します。さらに、同じパイプラインを使用して、標準のKerasベースのSavedModelとTFLiteモデルの両方を同時にエクスポートし、ユーザーが2つの品質を比較できるようにする方法を示します。

TFX、コンポーネント、パイプラインに精通していることを前提としています。そうでない場合は、このチュートリアルを参照してください。

ステップ

TFXでTFLiteモデルを作成および評価するには、2つの手順のみが必要です。最初のステップは、 TFX Trainerのコンテキスト内でTFLiteリライターを呼び出して、トレーニングされたTensorFlowモデルをTFLiteモデルに変換することです。 2番目のステップは、TFLiteモデルを評価するようにEvaluatorを構成することです。次に、それぞれについて順番に説明します。

トレーナー内でTFLiteリライターを呼び出します。

TFXトレーナーは、ユーザー定義のrun_fnがモジュールファイルで指定されていることを想定しています。このrun_fnは、トレーニングするモデルを定義し、指定された反復回数でモデルをトレーニングし、トレーニングされたモデルをエクスポートします。

このセクションの残りの部分では、TFLiteリライターを呼び出してTFLiteモデルをエクスポートするために必要な変更を示すコードスニペットを提供します。このコードはすべて、 MNISTTFLiteモジュールのrun_fn run_fnあります。

以下のコードに示すように、最初に、すべての機能のTensorを入力として受け取る署名を作成する必要があります。これは、シリアル化されたtf.Exampleprotosを入力として受け取るTFXのほとんどの既存のモデルからの逸脱であることに注意してください。

 signatures = {
      'serving_default':
          _get_serve_tf_examples_fn(
              model, tf_transform_output).get_concrete_function(
                  tf.TensorSpec(
                      shape=[None, 784],
                      dtype=tf.float32,
                      name='image_floats'))
  }

次に、Kerasモデルは通常と同じ方法でSavedModelとして保存されます。

  temp_saving_model_dir = os.path.join(fn_args.serving_model_dir, 'temp')
  model.save(temp_saving_model_dir, save_format='tf', signatures=signatures)

最後に、TFLiteリライター( tfrw )のインスタンスを作成し、それをSavedModelで呼び出して、TFLiteモデルを取得します。このTFLiteモデルは、 serving_model_dirの呼び出し元によって提供されたrun_fnに格納されます。このように、TFLiteモデルは、すべてのダウンストリームTFXコンポーネントがモデルを見つけることを期待する場所に格納されます。

  tfrw = rewriter_factory.create_rewriter(
      rewriter_factory.TFLITE_REWRITER, name='tflite_rewriter')
  converters.rewrite_saved_model(temp_saving_model_dir,
                                 fn_args.serving_model_dir,
                                 tfrw,
                                 rewriter.ModelType.TFLITE_MODEL)

TFLiteモデルの評価。

TFX Evaluatorは、トレーニングされたモデルを分析して、さまざまなメトリックにわたる品質を理解する機能を提供します。 SavedModelsの分析に加えて、TFXEvaluatorはTFLiteモデルも分析できるようになりました。

次のコードスニペット( MNISTパイプラインから複製)は、TFLiteモデルを分析するエバリュエーターを構成する方法を示しています。

  # Informs the evaluator that the model is a TFLite model.
  eval_config_lite.model_specs[0].model_type = 'tf_lite'

  ...

  # Uses TFMA to compute the evaluation statistics over features of a TFLite
  # model.
  model_analyzer_lite = Evaluator(
      examples=example_gen.outputs['examples'],
      model=trainer_lite.outputs['model'],
      eval_config=eval_config_lite,
  ).with_id('mnist_lite')

上に示したように、行う必要がある唯一の変更は、 model_typeフィールドをtf_liteに設定することです。 TFLiteモデルを分析するために、他の構成変更は必要ありません。 TFLiteモデルまたはSavedModelのどちらが分析されるかに関係なく、 Evaluatorの出力はまったく同じ構造になります。

ただし、エバリュエーターは、TFLiteモデルがtrainer_lite.outputs ['model']内のtfliteという名前のファイルに保存されていると想定していることに注意してください。