![]() | ![]() | ![]() | ![]() |
概要
整数量子化は、32ビット浮動小数点数(重みやアクティブ化出力など)を最も近い8ビット固定小数点数に変換する最適化戦略です。この小さなモデルにおける結果とのような低電力デバイスのための価値があるの増加推論速度、マイクロコントローラ。このデータ形式は、整数のみのようなアクセラレータで必要とされるエッジTPU 。
このチュートリアルでは、あなたは、最初からMNISTモデルを訓練しますTensorflow Liteのファイルに変換し、使用してそれを量子化訓練後の量子化を。最後に、変換されたモデルの精度を確認し、元のフロートモデルと比較します。
モデルを量子化する量については、実際にはいくつかのオプションがあります。このチュートリアルでは、すべての重みとアクティベーション出力を8ビット整数データに変換する「完全整数量子化」を実行しますが、他の戦略では、浮動小数点にある程度のデータが残る場合があります。
様々な量子化戦略についての詳細については、読んTensorFlow Liteのモデルの最適化。
設定
入力テンソルと出力テンソルの両方を量子化するには、TensorFlowr2.3で追加されたAPIを使用する必要があります。
import logging
logging.getLogger("tensorflow").setLevel(logging.DEBUG)
import tensorflow as tf
import numpy as np
assert float(tf.__version__[:3]) >= 2.3
TensorFlowモデルを生成する
私たちは、から分類番号に単純なモデルを構築しますMNISTデータセット。
モデルをわずか5エポックでトレーニングしているため、このトレーニングにはそれほど時間はかかりません。これは、約98%の精度でトレーニングされます。
# Load MNIST dataset
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
# Normalize the input image so that each pixel value is between 0 to 1.
train_images = train_images.astype(np.float32) / 255.0
test_images = test_images.astype(np.float32) / 255.0
# Define the model architecture
model = tf.keras.Sequential([
tf.keras.layers.InputLayer(input_shape=(28, 28)),
tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
tf.keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(10)
])
# Train the digit classification model
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True),
metrics=['accuracy'])
model.fit(
train_images,
train_labels,
epochs=5,
validation_data=(test_images, test_labels)
)
Epoch 1/5 1875/1875 [==============================] - 5s 2ms/step - loss: 0.2519 - accuracy: 0.9311 - val_loss: 0.1106 - val_accuracy: 0.9664 Epoch 2/5 1875/1875 [==============================] - 3s 2ms/step - loss: 0.0984 - accuracy: 0.9724 - val_loss: 0.0828 - val_accuracy: 0.9743 Epoch 3/5 1875/1875 [==============================] - 3s 2ms/step - loss: 0.0746 - accuracy: 0.9785 - val_loss: 0.0640 - val_accuracy: 0.9795 Epoch 4/5 1875/1875 [==============================] - 3s 2ms/step - loss: 0.0620 - accuracy: 0.9814 - val_loss: 0.0620 - val_accuracy: 0.9793 Epoch 5/5 1875/1875 [==============================] - 3s 2ms/step - loss: 0.0540 - accuracy: 0.9837 - val_loss: 0.0624 - val_accuracy: 0.9795 <keras.callbacks.History at 0x7fb44c988c90>
TensorFlowLiteモデルに変換する
今、あなたは使用してTensorFlow Liteの形式に訓練されたモデルを変換することができTFLiteConverter
APIを、そして量子化の様々な程度を適用します。
一部のバージョンの量子化では、一部のデータが浮動小数点形式のままになることに注意してください。したがって、次のセクションでは、完全にint8またはuint8データであるモデルを取得するまで、量子化の量を増やしながら各オプションを示します。 (各セクションでいくつかのコードを複製しているため、各オプションのすべての量子化ステップを確認できます。)
まず、量子化されていない変換モデルを次に示します。
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
2021-10-30 12:04:56.623151: W tensorflow/python/util/util.cc:348] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them. INFO:tensorflow:Assets written to: /tmp/tmp3os2tr3n/assets 2021-10-30 12:04:57.031317: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:351] Ignored output_format. 2021-10-30 12:04:57.031355: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:354] Ignored drop_control_dependency.
現在はTensorFlowLiteモデルですが、すべてのパラメーターデータに32ビットのfloat値を使用しています。
ダイナミックレンジ量子化を使用して変換
今、聞かせてのは、デフォルトの有効optimizations
(このような重量など)、すべての固定パラメータを量子化するためのフラグを:
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model_quant = converter.convert()
INFO:tensorflow:Assets written to: /tmp/tmpi7xibvaj/assets INFO:tensorflow:Assets written to: /tmp/tmpi7xibvaj/assets 2021-10-30 12:04:57.597982: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:351] Ignored output_format. 2021-10-30 12:04:57.598020: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:354] Ignored drop_control_dependency.
モデルは量子化された重みで少し小さくなりましたが、他の変数データはまだ浮動小数点形式です。
フロートフォールバック量子化を使用して変換
(このような層の間のモデル入力/出力および中間体として)可変データを量子化するために、あなたが提供する必要RepresentativeDataset
。これは、典型的な値を表すのに十分な大きさの入力データのセットを提供するジェネレーター関数です。これにより、コンバーターはすべての可変データのダイナミックレンジを推定できます。 (データセットは、トレーニングまたは評価データセットと比較して一意である必要はありません。)複数の入力をサポートするために、各代表的なデータポイントはリストであり、リスト内の要素はインデックスに従ってモデルに供給されます。
def representative_data_gen():
for input_value in tf.data.Dataset.from_tensor_slices(train_images).batch(1).take(100):
# Model has only one input so each data point has one element.
yield [input_value]
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_data_gen
tflite_model_quant = converter.convert()
INFO:tensorflow:Assets written to: /tmp/tmp3gwloj7n/assets INFO:tensorflow:Assets written to: /tmp/tmp3gwloj7n/assets 2021-10-30 12:04:58.159142: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:351] Ignored output_format. 2021-10-30 12:04:58.159181: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:354] Ignored drop_control_dependency. fully_quantize: 0, inference_type: 6, input_inference_type: 0, output_inference_type: 0
これで、すべての重みと変数データが量子化され、モデルは元のTensorFlowLiteモデルと比較して大幅に小さくなりました。
ただし、従来フロートモデルの入力テンソルと出力テンソルを使用するアプリケーションとの互換性を維持するために、TensorFlow LiteConverterはモデルの入力テンソルと出力テンソルをfloatのままにします。
interpreter = tf.lite.Interpreter(model_content=tflite_model_quant)
input_type = interpreter.get_input_details()[0]['dtype']
print('input: ', input_type)
output_type = interpreter.get_output_details()[0]['dtype']
print('output: ', output_type)
input: <class 'numpy.float32'> output: <class 'numpy.float32'>
これは通常、互換性には適していますが、EdgeTPUなどの整数ベースの操作のみを実行するデバイスとは互換性がありません。
さらに、TensorFlow Liteにその操作の量子化された実装が含まれていない場合、上記のプロセスはfloat形式で操作を残す可能性があります。この戦略により、変換を完了できるため、より小さく、より効率的なモデルが得られますが、整数のみのハードウェアとは互換性がありません。 (このMNISTモデルのすべての操作には、量子化された実装があります。)
したがって、エンドツーエンドの整数のみのモデルを確保するには、さらにいくつかのパラメーターが必要です...
整数のみの量子化を使用して変換する
入力テンソルと出力テンソルを量子化し、量子化できない操作が発生した場合にコンバーターがエラーをスローするようにするには、いくつかの追加パラメーターを使用してモデルを再度変換します。
def representative_data_gen():
for input_value in tf.data.Dataset.from_tensor_slices(train_images).batch(1).take(100):
yield [input_value]
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_data_gen
# Ensure that if any ops can't be quantized, the converter throws an error
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# Set the input and output tensors to uint8 (APIs added in r2.3)
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8
tflite_model_quant = converter.convert()
INFO:tensorflow:Assets written to: /tmp/tmp8ygc2_3y/assets INFO:tensorflow:Assets written to: /tmp/tmp8ygc2_3y/assets 2021-10-30 12:04:59.308505: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:351] Ignored output_format. 2021-10-30 12:04:59.308542: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:354] Ignored drop_control_dependency. fully_quantize: 0, inference_type: 6, input_inference_type: 3, output_inference_type: 3 WARNING:absl:For model inputs containing unsupported operations which cannot be quantized, the `inference_input_type` attribute will default to the original type.
内部量子化は上記と同じままですが、入力テンソルと出力テンソルが整数形式になっていることがわかります。
interpreter = tf.lite.Interpreter(model_content=tflite_model_quant)
input_type = interpreter.get_input_details()[0]['dtype']
print('input: ', input_type)
output_type = interpreter.get_output_details()[0]['dtype']
print('output: ', output_type)
input: <class 'numpy.uint8'> output: <class 'numpy.uint8'>
今、あなたは整数量子化モデルを持っているモデルの入力と出力のテンソルのためのデータの整数用途、それはのような整数のみのハードウェアと互換性がありますので、エッジTPU 。
モデルをファイルとして保存
あなたは必要があります.tflite
他のデバイス上であなたのモデルを展開するために、ファイルを。それでは、変換されたモデルをファイルに保存してから、以下の推論を実行するときにそれらをロードしましょう。
import pathlib
tflite_models_dir = pathlib.Path("/tmp/mnist_tflite_models/")
tflite_models_dir.mkdir(exist_ok=True, parents=True)
# Save the unquantized/float model:
tflite_model_file = tflite_models_dir/"mnist_model.tflite"
tflite_model_file.write_bytes(tflite_model)
# Save the quantized model:
tflite_model_quant_file = tflite_models_dir/"mnist_model_quant.tflite"
tflite_model_quant_file.write_bytes(tflite_model_quant)
24280
TensorFlowLiteモデルを実行する
今、私たちはTensorFlow Liteは使用して推論を実行しますInterpreter
モデル精度を比較します。
まず、特定のモデルと画像で推論を実行し、次に予測を返す関数が必要です。
# Helper function to run inference on a TFLite model
def run_tflite_model(tflite_file, test_image_indices):
global test_images
# Initialize the interpreter
interpreter = tf.lite.Interpreter(model_path=str(tflite_file))
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()[0]
output_details = interpreter.get_output_details()[0]
predictions = np.zeros((len(test_image_indices),), dtype=int)
for i, test_image_index in enumerate(test_image_indices):
test_image = test_images[test_image_index]
test_label = test_labels[test_image_index]
# Check if the input type is quantized, then rescale input data to uint8
if input_details['dtype'] == np.uint8:
input_scale, input_zero_point = input_details["quantization"]
test_image = test_image / input_scale + input_zero_point
test_image = np.expand_dims(test_image, axis=0).astype(input_details["dtype"])
interpreter.set_tensor(input_details["index"], test_image)
interpreter.invoke()
output = interpreter.get_tensor(output_details["index"])[0]
predictions[i] = output.argmax()
return predictions
1つの画像でモデルをテストします
次に、フロートモデルと量子化モデルのパフォーマンスを比較します。
-
tflite_model_file
浮動小数点データと元TensorFlow Liteのモデルです。 -
tflite_model_quant_file
、我々は、整数のみの量子化を用いて変換し、最後のモデルである(それは、入力および出力のためUINT8データを使用します)。
予測を出力する別の関数を作成しましょう。
import matplotlib.pylab as plt
# Change this to test a different image
test_image_index = 1
## Helper function to test the models on one image
def test_model(tflite_file, test_image_index, model_type):
global test_labels
predictions = run_tflite_model(tflite_file, [test_image_index])
plt.imshow(test_images[test_image_index])
template = model_type + " Model \n True:{true}, Predicted:{predict}"
_ = plt.title(template.format(true= str(test_labels[test_image_index]), predict=str(predictions[0])))
plt.grid(False)
次に、フロートモデルをテストします。
test_model(tflite_model_file, test_image_index, model_type="Float")
そして、量子化されたモデルをテストします。
test_model(tflite_model_quant_file, test_image_index, model_type="Quantized")
すべての画像でモデルを評価する
次に、このチュートリアルの最初にロードしたすべてのテストイメージを使用して、両方のモデルを実行しましょう。
# Helper function to evaluate a TFLite model on all images
def evaluate_model(tflite_file, model_type):
global test_images
global test_labels
test_image_indices = range(test_images.shape[0])
predictions = run_tflite_model(tflite_file, test_image_indices)
accuracy = (np.sum(test_labels== predictions) * 100) / len(test_images)
print('%s model accuracy is %.4f%% (Number of test samples=%d)' % (
model_type, accuracy, len(test_images)))
フロートモデルを評価します。
evaluate_model(tflite_model_file, model_type="Float")
Float model accuracy is 97.9500% (Number of test samples=10000)
量子化されたモデルを評価します。
evaluate_model(tflite_model_quant_file, model_type="Quantized")
Quantized model accuracy is 97.9300% (Number of test samples=10000)
これで、floatモデルと比較して、精度にほとんど差がない整数量子化モデルができました。
他の量子化戦略の詳細については、読んTensorFlow Liteのモデルの最適化。