訓練後の整数量子化

TensorFlow.org で表示 Google Colab で実行 GitHub でソースを表示 ノートブックをダウンロード

概要

整数量子化は、32 ビット浮動小数点数(重みや活性化出力など)を最も近い 8 ビット固定小数点数に変換する最適化ストラテジーです。これにより、より小さなモデルが生成され、推論速度が増加するため、マイクロコントローラーといった性能の低いデバイスにとって貴重となります。このデータ形式は、Edge TPU などの整数のみのアクセラレータでも必要とされています。

このチュートリアルでは、MNIST モデルを新規にトレーニングし、それを TensorFlow Lite ファイルに変換して、トレーニング後量子化を使用して量子化します。最後に、変換されたモデルの精度を確認し、元の浮動小数点モデルと比較します。

モデルをどれくらい量子化するかについてのオプションには、実際いくつかあります。他のストラテジーでは、一部のデータが浮動小数点数のままとなることがありますが、このチュートリアルでは、すべての重みと活性化出力を 8 ビット整数データに変換する「全整数量子化」を実行します。

さまざまな量子化ストラテジーについての詳細は、TensorFlow Lite モデルの最適化をご覧ください。

セットアップ

入力テンソルと出力テンソルの両方を量子化するには、TensorFlow 2.3 で追加された API を使用する必要があります。

import logging
logging.getLogger("tensorflow").setLevel(logging.DEBUG)

import tensorflow as tf
import numpy as np
print("TensorFlow version: ", tf.__version__)

TensorFlow モデルを生成する

MNIST データセットから、数字を分類する単純なモデルを構築します。

モデルのトレーニングは 5 エポックしか行わないため、時間はかかりません。およそ 98% の精度に達します。

# Load MNIST dataset
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input image so that each pixel value is between 0 to 1.
train_images = train_images.astype(np.float32) / 255.0
test_images = test_images.astype(np.float32) / 255.0

# Define the model architecture
model = tf.keras.Sequential([
  tf.keras.layers.InputLayer(input_shape=(28, 28)),
  tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
  tf.keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
  tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(10)
])

# Train the digit classification model
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(
                  from_logits=True),
              metrics=['accuracy'])
model.fit(
  train_images,
  train_labels,
  epochs=5,
  validation_data=(test_images, test_labels)
)

TensorFlow Lite モデルに変換する

次に、TensorFlow Lite Converter を使用して、トレーニング済みのモデルを TensorFlow Lite 形式に変換し、様々な程度で量子化を適用できます。

量子化のバージョンの中には、一部のデータを浮動小数点数のフォーマットに残すものもあることに注意してください。そのため、以下のセクションでは、完全に int8 または unit8 データのモデルを得るまで、各オプションの量子化の量を増加しています。(各セクションのコードは、オプションごとにすべての量子化のステップを確認できるように、重複していることに注意してください。)

まず、量子化なしで変換されたモデルです。

converter = tf.lite.TFLiteConverter.from_keras_model(model)

tflite_model = converter.convert()

TensorFlow Lite モデルになってはいますが、すべてのパラメータデータには 32 ビット浮動小数点値が使用されています。

ダイナミックレンジ量子化による変換

では、デフォルトの optimizations フラグを有効にして、すべての固定パラメータ(重みなど)を量子化しましょう。

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]

tflite_model_quant = converter.convert()

重みが量子化されたためモデルは多少小さくなりましたが、他の変数データはまだ浮動小数点数フォーマットのままです。

浮動小数点数フォールバック量子化による変換

変数データ(モデル入力/出力やレイヤー間の中間データ)を量子化するには、RepresentativeDataset を指定する必要があります。これは、代表値を示すのに十分な大きさのある一連の入力データを提供するジェネレータ関数です。コンバータがすべての変数データのダイナミックレンジを推測できるようにします。(トレーニングや評価データセットとは異なり、このデータセットは一意である必要はありません。)複数の入力をサポートするために、それぞれの代表的なデータポイントはリストで、リストの要素はインデックスに従ってモデルに供給されます。

def representative_data_gen():
  for input_value in tf.data.Dataset.from_tensor_slices(train_images).batch(1).take(100):
    # Model has only one input so each data point has one element.
    yield [input_value]

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_data_gen

tflite_model_quant = converter.convert()

すべての重みと変数データが量子化されたため、元の TensorFlow Lite モデルにくらべてはるかに小さくなりました。

ただし、従来的に浮動小数点数モデルの入力テンソルと出力テンソルを使用するアプリケーションとの互換性を維持するために、TensorFlow Lite Converter は、モデルの入力テンソルと出力テンソルを浮動小数点数に残しています。

interpreter = tf.lite.Interpreter(model_content=tflite_model_quant)
input_type = interpreter.get_input_details()[0]['dtype']
print('input: ', input_type)
output_type = interpreter.get_output_details()[0]['dtype']
print('output: ', output_type)

互換性を考慮すれば、大抵においては良いことではありますが、Edge TPU など、整数ベースの演算のみを実行するデバイスには対応していません。

さらに、上記のプロセスでは、TensorFlow Lite が演算用の量子化の実装を含まない場合、その演算を浮動小数点数フォーマットに残す可能性があります。このストラテジーでは、より小さく効率的なモデルを得られるように変換を完了することが可能ですが、やはり、整数のみのハードウェアには対応しません。(この MNIST モデルのすべての op には量子化された実装が含まれています。)

そこで、エンドツーエンドの整数限定モデルを確実に得られるよう、パラメータをいくつか追加する必要があります。

整数限定量子化による変換

入力テンソルと出力テンソルを量子化し、量子化できない演算に遭遇したらコンバーターがエラーをスローするようにするには、追加パラメータをいくつか使用して、モデルを変換し直します。

def representative_data_gen():
  for input_value in tf.data.Dataset.from_tensor_slices(train_images).batch(1).take(100):
    yield [input_value]

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_data_gen
# Ensure that if any ops can't be quantized, the converter throws an error
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# Set the input and output tensors to uint8 (APIs added in r2.3)
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8

tflite_model_quant = converter.convert()

内部の量子化は上記と同じままですが、入力テンソルと出力テンソルが整数フォーマットになっているのがわかります。

interpreter = tf.lite.Interpreter(model_content=tflite_model_quant)
input_type = interpreter.get_input_details()[0]['dtype']
print('input: ', input_type)
output_type = interpreter.get_output_details()[0]['dtype']
print('output: ', output_type)

これで、モデルの入力テンソルと出力テンソルに整数データを強いようする整数量子化モデルを得られました。Edge TPU などの整数限定ハードウェアに対応しています。

モデルをファイルとして保存する

モデルを他のデバイスにデプロイするには、.tflite ファイルが必要となります。そこで、変換されたモデルをファイルに保存して、以下の推論を実行する際に読み込んでみましょう。

import pathlib

tflite_models_dir = pathlib.Path("/tmp/mnist_tflite_models/")
tflite_models_dir.mkdir(exist_ok=True, parents=True)

# Save the unquantized/float model:
tflite_model_file = tflite_models_dir/"mnist_model.tflite"
tflite_model_file.write_bytes(tflite_model)
# Save the quantized model:
tflite_model_quant_file = tflite_models_dir/"mnist_model_quant.tflite"
tflite_model_quant_file.write_bytes(tflite_model_quant)

TensorFlow Lite モデルを実行する

では、TensorFlow Lite Interpreter を使用して推論を実行し、モデルの精度を比較しましょう。

まず、特定のモデルと画像を使って推論を実行し、予測を返す関数が必要です。

# Helper function to run inference on a TFLite model
def run_tflite_model(tflite_file, test_image_indices):
  global test_images

  # Initialize the interpreter
  interpreter = tf.lite.Interpreter(model_path=str(tflite_file))
  interpreter.allocate_tensors()

  input_details = interpreter.get_input_details()[0]
  output_details = interpreter.get_output_details()[0]

  predictions = np.zeros((len(test_image_indices),), dtype=int)
  for i, test_image_index in enumerate(test_image_indices):
    test_image = test_images[test_image_index]
    test_label = test_labels[test_image_index]

    # Check if the input type is quantized, then rescale input data to uint8
    if input_details['dtype'] == np.uint8:
      input_scale, input_zero_point = input_details["quantization"]
      test_image = test_image / input_scale + input_zero_point

    test_image = np.expand_dims(test_image, axis=0).astype(input_details["dtype"])
    interpreter.set_tensor(input_details["index"], test_image)
    interpreter.invoke()
    output = interpreter.get_tensor(output_details["index"])[0]

    predictions[i] = output.argmax()

  return predictions

1つの画像に対してモデルを検証する

次に、浮動小数点数モデルと量子化モデルのパフォーマンスを比較します。

  • tflite_model_file は、浮動小数点数データを持つ元の TensorFlow Lite モデルです。
  • tflite_model_quant_file は、整数限定量子化を使用して変換した最後のモデルです(入力と出力に unit8 データを使用します)。

もう一つ、予測を出力する関数を作成しましょう。

import matplotlib.pylab as plt

# Change this to test a different image
test_image_index = 1

## Helper function to test the models on one image
def test_model(tflite_file, test_image_index, model_type):
  global test_labels

  predictions = run_tflite_model(tflite_file, [test_image_index])

  plt.imshow(test_images[test_image_index])
  template = model_type + " Model \n True:{true}, Predicted:{predict}"
  _ = plt.title(template.format(true= str(test_labels[test_image_index]), predict=str(predictions[0])))
  plt.grid(False)

では、浮動小数点数モデルをテストします。

test_model(tflite_model_file, test_image_index, model_type="Float")

今度は量子化されたモデル(uint8データを使用する)を検証します:

test_model(tflite_model_quant_file, test_image_index, model_type="Quantized")

モデルを評価する

このチュートリアルの冒頭で読み込んだすテスト画像をすべて使用して、両方のモデルを実行しましょう。

# Helper function to evaluate a TFLite model on all images
def evaluate_model(tflite_file, model_type):
  global test_images
  global test_labels

  test_image_indices = range(test_images.shape[0])
  predictions = run_tflite_model(tflite_file, test_image_indices)

  accuracy = (np.sum(test_labels== predictions) * 100) / len(test_images)

  print('%s model accuracy is %.4f%% (Number of test samples=%d)' % (
      model_type, accuracy, len(test_images)))

浮動小数点数モデルを評価します。

evaluate_model(tflite_model_file, model_type="Float")

uint8データを使用した完全に量子化されたモデルで評価を繰り返します:

evaluate_model(tflite_model_quant_file, model_type="Quantized")

これで、浮動小数点数モデルと比較し、ほぼ同じ制度を持つ整数量子化モデルが得られました。

他の量子化ストラテジーについての詳細は、TensorFlow Lite モデルの量子化をご覧ください。