Keras の例による量子化認識トレーニング

TensorFlow.orgで表示 Google Colab で実行 GitHub でソースを表示 ノートブックをダウンロード

概要

量子化認識トレーニングのエンドツーエンドの例へようこそ。

その他のページ

量子化認識トレーニングの紹介、および認識トレーニングを使用すべきかどうかの判定(サポート情報も含む)については、概要ページをご覧ください。

ユースケースに合った API を素早く特定するには(8 ビットのモデルの完全量子化を超えるユースケース)、総合ガイドをご覧ください。

要約

このチュートリアルでは、次について説明しています。

  1. MNIST の tf.keras モデルを最初からトレーニングする。
  2. 量子化認識トレーニング API を適用してモデルをファインチューニングし、精度を確認して量子化認識モデルをエクスポートする。
  3. このモデルを使用して、TFLite バックエンドのために実際に量子化されたモデルを作成する。
  4. TFLite および 1/4 のモデルの精度の永続性を確認する。モバイルでのレイテンシーのメリットを確認するには、TFLite アプリリポジトリ内の TFLite の例を試してみてください。

セットアップ

 pip uninstall -y tensorflow
 pip install -q tf-nightly
 pip install -q tensorflow-model-optimization
import tempfile
import os

import tensorflow as tf

from tensorflow import keras

量子化認識トレーニングを使用せずに、MNIST のモデルをトレーニングする

# Load MNIST dataset
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input image so that each pixel value is between 0 to 1.
train_images = train_images / 255.0
test_images = test_images / 255.0

# Define the model architecture.
model = keras.Sequential([
  keras.layers.InputLayer(input_shape=(28, 28)),
  keras.layers.Reshape(target_shape=(28, 28, 1)),
  keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Flatten(),
  keras.layers.Dense(10)
])

# Train the digit classification model
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(
  train_images,
  train_labels,
  epochs=1,
  validation_split=0.1,
)
1688/1688 [==============================] - 6s 2ms/step - loss: 0.5304 - accuracy: 0.8498 - val_loss: 0.1238 - val_accuracy: 0.9680
<tensorflow.python.keras.callbacks.History at 0x7f46023997f0>

量子化認識トレーニングを使用して、事前トレーニング済みモデルをクローンおよびファインチューニングする

モデルを定義する

量子化認識トレーニングをモデル全体に適用し、これをモデルの要約で確認します。すべてのレイヤーにプレフィックス "quant" が付いているはずです。

結果のモデルは量子化認識モデルですが、量子化はされていないので注意してください(例えば、重みは int8 ではなく float32 です)。次のセクションでは、量子化認識モデルから量子化モデルを作成する方法を示します。

総合ガイドでは、モデルの精度を改善するために、一部のレイヤーを量子化する方法をご覧いただけます。

import tensorflow_model_optimization as tfmot

quantize_model = tfmot.quantization.keras.quantize_model

# q_aware stands for for quantization aware.
q_aware_model = quantize_model(model)

# `quantize_model` requires a recompile.
q_aware_model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

q_aware_model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
quantize_layer (QuantizeLaye (None, 28, 28)            3         
_________________________________________________________________
quant_reshape (QuantizeWrapp (None, 28, 28, 1)         1         
_________________________________________________________________
quant_conv2d (QuantizeWrappe (None, 26, 26, 12)        147       
_________________________________________________________________
quant_max_pooling2d (Quantiz (None, 13, 13, 12)        1         
_________________________________________________________________
quant_flatten (QuantizeWrapp (None, 2028)              1         
_________________________________________________________________
quant_dense (QuantizeWrapper (None, 10)                20295     
=================================================================
Total params: 20,448
Trainable params: 20,410
Non-trainable params: 38
_________________________________________________________________

モデルをベースラインに対してトレーニングおよび評価する

モデルを 1 エポックだけトレーニングした後のファインチューニングを実証するために、トレーニングデータのサブセットに対して、量子化認識トレーニングを使用してファインチューニングを行います。

train_images_subset = train_images[0:1000] # out of 60000
train_labels_subset = train_labels[0:1000]

q_aware_model.fit(train_images_subset, train_labels_subset,
                  batch_size=500, epochs=1, validation_split=0.1)
2/2 [==============================] - 1s 210ms/step - loss: 0.1565 - accuracy: 0.9544 - val_loss: 0.1604 - val_accuracy: 0.9700
<tensorflow.python.keras.callbacks.History at 0x7f469db3e208>

この例では、ベースラインと比較し、量子化認識トレーニング後のテスト精度の損失は、最小限あるいはゼロです。

_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)

_, q_aware_model_accuracy = q_aware_model.evaluate(
   test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)
print('Quant test accuracy:', q_aware_model_accuracy)
Baseline test accuracy: 0.9642000198364258
Quant test accuracy: 0.9656000137329102

TFLite バックエンドの量子化モデルを作成する

この後に、重み int8 と活性化関数 uint8 を持つ、実際に量子化されたモデルが出来上がります。

converter = tf.lite.TFLiteConverter.from_keras_model(q_aware_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]

quantized_tflite_model = converter.convert()
WARNING:absl:Found untraced functions such as reshape_layer_call_fn, reshape_layer_call_and_return_conditional_losses, conv2d_layer_call_fn, conv2d_layer_call_and_return_conditional_losses, max_pooling2d_layer_call_fn while saving (showing 5 of 25). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/tmpupf9tsyx/assets
INFO:tensorflow:Assets written to: /tmp/tmpupf9tsyx/assets

TF から TFLite への精度の永続性を確認する

テストデータセットで TFLite モデルを評価するヘルパー関数を定義します。

import numpy as np

def evaluate_model(interpreter):
  input_index = interpreter.get_input_details()[0]["index"]
  output_index = interpreter.get_output_details()[0]["index"]

  # Run predictions on every image in the "test" dataset.
  prediction_digits = []
  for i, test_image in enumerate(test_images):
    if i % 1000 == 0:
      print('Evaluated on {n} results so far.'.format(n=i))
    # Pre-processing: add batch dimension and convert to float32 to match with
    # the model's input data format.
    test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
    interpreter.set_tensor(input_index, test_image)

    # Run inference.
    interpreter.invoke()

    # Post-processing: remove batch dimension and find the digit with highest
    # probability.
    output = interpreter.tensor(output_index)
    digit = np.argmax(output()[0])
    prediction_digits.append(digit)

  print('\n')
  # Compare prediction results with ground truth labels to calculate accuracy.
  prediction_digits = np.array(prediction_digits)
  accuracy = (prediction_digits == test_labels).mean()
  return accuracy

量子化されたモデルを評価し、TensorFlow の精度が TFLite バックエンドに持続されていることを確認します。

interpreter = tf.lite.Interpreter(model_content=quantized_tflite_model)
interpreter.allocate_tensors()

test_accuracy = evaluate_model(interpreter)

print('Quant TFLite test_accuracy:', test_accuracy)
print('Quant TF test accuracy:', q_aware_model_accuracy)
Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.


Quant TFLite test_accuracy: 0.9656
Quant TF test accuracy: 0.9656000137329102

量子化でモデルが 1/4 になることを確認する

浮動小数点数の TFLite モデルを作成して、量子化された TFLite モデルが 1/4 になっていることを確認します。

# Create float TFLite model.
float_converter = tf.lite.TFLiteConverter.from_keras_model(model)
float_tflite_model = float_converter.convert()

# Measure sizes of models.
_, float_file = tempfile.mkstemp('.tflite')
_, quant_file = tempfile.mkstemp('.tflite')

with open(quant_file, 'wb') as f:
  f.write(quantized_tflite_model)

with open(float_file, 'wb') as f:
  f.write(float_tflite_model)

print("Float model in Mb:", os.path.getsize(float_file) / float(2**20))
print("Quantized model in Mb:", os.path.getsize(quant_file) / float(2**20))
INFO:tensorflow:Assets written to: /tmp/tmpx96l01vr/assets
INFO:tensorflow:Assets written to: /tmp/tmpx96l01vr/assets
Float model in Mb: 0.08058547973632812
Quantized model in Mb: 0.0234527587890625

結論

このチュートリアルでは、TensorFlow Model Optimization Toolkit API を使用して量子化認識モデルを作成し、TFLite バックエンドの量子化モデルを作成する方法を紹介しました。

MNIST のモデルでは、精度の違いを最小限に抑えながらモデルサイズを 1/4 に圧縮できることを示しましたた。モバイルでのレイテンシーのメリットを確認するには、TFLite アプリリポジトリ内の TFLite の例を試してみてください。

この新しい機能をぜひお試しください。リソースが制限される環境でのデプロイにおいて、特に重要となります。