Overview
Welcome to an end-to-end example for quantization aware training.
Other pages
For an introduction to what quantization aware training is and to determine if you should use it (including what's supported), see the overview page.
To quickly find the APIs you need for your use case (beyond fully-quantizing a model with 8-bits), see the comprehensive guide.
Summary
In this tutorial, you will:
- Train a tf.keras model for MNIST from scratch.
- Fine-tune the model by applying the quantization aware training API, see the accuracy, and export a quantization aware model.
- Use the model to create an actually quantized model for the TFLite backend.
- See the persistence of accuracy in TFLite and a 4x smaller model. To see the latency benefits on mobile, try out the TFLite examples in the TFLite app repository.
Setup
pip install -q tensorflow
pip install -q tensorflow-model-optimization
import tempfile
import os
import tensorflow as tf
from tensorflow import keras
Train a model for MNIST without quantization aware training
# Load MNIST dataset
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
# Normalize the input image so that each pixel value is between 0 to 1.
train_images = train_images / 255.0
test_images = test_images / 255.0
# Define the model architecture.
model = keras.Sequential([
  keras.layers.InputLayer(input_shape=(28, 28)),
  keras.layers.Reshape(target_shape=(28, 28, 1)),
  keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Flatten(),
  keras.layers.Dense(10)
])
# Train the digit classification model
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(
  train_images,
  train_labels,
  epochs=1,
  validation_split=0.1,
)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step
1688/1688 [==============================] - 8s 5ms/step - loss: 0.2936 - accuracy: 0.9166 - val_loss: 0.1126 - val_accuracy: 0.9687
<tensorflow.python.keras.callbacks.History at 0x7f0e50d95150>
Clone and fine-tune pre-trained model with quantization aware training
Define the model
You will apply quantization aware training to the whole model and see this in the model summary. All layers are now prefixed by "quant".
Note that the resulting model is quantization aware but not quantized (e.g. the weights are float32 instead of int8). The sections after show how to create a quantized model from the quantization aware one.
In the comprehensive guide, you can see how to quantize some layers for model accuracy improvements.
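As a taste of that API (a sketch only, not used in the rest of this tutorial), you can annotate just the layers you want quantized and let quantize_apply make only those layers quantization aware. The helper name apply_quantization_to_dense is illustrative; the pattern follows the comprehensive guide.

import tensorflow_model_optimization as tfmot

# Annotate only the Dense layers for quantization; other layers stay float.
def apply_quantization_to_dense(layer):
  if isinstance(layer, keras.layers.Dense):
    return tfmot.quantization.keras.quantize_annotate_layer(layer)
  return layer

# Clone the model, wrapping Dense layers with quantization annotations.
annotated_model = keras.models.clone_model(
    model, clone_function=apply_quantization_to_dense)

# `quantize_apply` makes only the annotated layers quantization aware.
partially_quant_model = tfmot.quantization.keras.quantize_apply(annotated_model)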
import tensorflow_model_optimization as tfmot
quantize_model = tfmot.quantization.keras.quantize_model
# q_aware stands for quantization aware.
q_aware_model = quantize_model(model)
# `quantize_model` requires a recompile.
q_aware_model.compile(optimizer='adam',
                      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                      metrics=['accuracy'])
q_aware_model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
quantize_layer (QuantizeLaye (None, 28, 28)            3
_________________________________________________________________
quant_reshape (QuantizeWrapp (None, 28, 28, 1)         1
_________________________________________________________________
quant_conv2d (QuantizeWrappe (None, 26, 26, 12)        147
_________________________________________________________________
quant_max_pooling2d (Quantiz (None, 13, 13, 12)        1
_________________________________________________________________
quant_flatten (QuantizeWrapp (None, 2028)              1
_________________________________________________________________
quant_dense (QuantizeWrapper (None, 10)                20295
=================================================================
Total params: 20,448
Trainable params: 20,410
Non-trainable params: 38
_________________________________________________________________
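To confirm that this model is quantization aware but not yet quantized, you can print a few of its variables; every stored weight, including the quantizer min/max variables, is still float32. A quick sanity check (a sketch, assuming the q_aware_model built above):

# All variables are float32; quantization is simulated in the forward pass,
# not baked into the stored weights.
for w in q_aware_model.weights[:6]:
  print(w.name, w.dtype)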
Train and evaluate the model against baseline
To demonstrate fine-tuning after training the model for just a single epoch, fine-tune with quantization aware training on a subset of the training data.
train_images_subset = train_images[0:1000] # out of 60000
train_labels_subset = train_labels[0:1000]
q_aware_model.fit(train_images_subset, train_labels_subset,
                  batch_size=500, epochs=1, validation_split=0.1)
2/2 [==============================] - 1s 163ms/step - loss: 0.1440 - accuracy: 0.9578 - val_loss: 0.1361 - val_accuracy: 0.9800
<tensorflow.python.keras.callbacks.History at 0x7f0e5dd825d0>
For this example, there is minimal to no loss in test accuracy after quantization aware training, compared to the baseline.
_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)

_, q_aware_model_accuracy = q_aware_model.evaluate(
    test_images, test_labels, verbose=0)
print('Baseline test accuracy:', baseline_model_accuracy)
print('Quant test accuracy:', q_aware_model_accuracy)
Baseline test accuracy: 0.9643999934196472
Quant test accuracy: 0.9660000205039978
Create quantized model for TFLite backend
Converting the quantization aware model with the TFLite converter gives you an actually quantized model with int8 weights and uint8 activations.
converter = tf.lite.TFLiteConverter.from_keras_model(q_aware_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
quantized_tflite_model = converter.convert()
WARNING:absl:Found untraced functions such as reshape_layer_call_fn, reshape_layer_call_and_return_conditional_losses, conv2d_layer_call_fn, conv2d_layer_call_and_return_conditional_losses, flatten_layer_call_fn while saving (showing 5 of 20). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/tmpkj38vsde/assets
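To spot-check that the converted flatbuffer really contains int8 weight tensors, you can tally tensor dtypes with the TFLite interpreter. This is a sketch; the exact tensor breakdown varies with the converter version.

import collections

# Count the dtypes of all tensors in the quantized model; the weight
# tensors should show up as int8.
check_interpreter = tf.lite.Interpreter(model_content=quantized_tflite_model)
check_interpreter.allocate_tensors()
dtype_counts = collections.Counter(
    d['dtype'].__name__ for d in check_interpreter.get_tensor_details())
print(dtype_counts)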
See persistence of accuracy from TF to TFLite
Define a helper function to evaluate the TF Lite model on the test dataset.
import numpy as np

def evaluate_model(interpreter):
  input_index = interpreter.get_input_details()[0]["index"]
  output_index = interpreter.get_output_details()[0]["index"]

  # Run predictions on every image in the "test" dataset.
  prediction_digits = []
  for i, test_image in enumerate(test_images):
    if i % 1000 == 0:
      print('Evaluated on {n} results so far.'.format(n=i))
    # Pre-processing: add batch dimension and convert to float32 to match with
    # the model's input data format.
    test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
    interpreter.set_tensor(input_index, test_image)

    # Run inference.
    interpreter.invoke()

    # Post-processing: remove batch dimension and find the digit with highest
    # probability.
    output = interpreter.tensor(output_index)
    digit = np.argmax(output()[0])
    prediction_digits.append(digit)

  print('\n')
  # Compare prediction results with ground truth labels to calculate accuracy.
  prediction_digits = np.array(prediction_digits)
  accuracy = (prediction_digits == test_labels).mean()
  return accuracy
You evaluate the quantized model and see that the accuracy from TensorFlow persists to the TFLite backend.
interpreter = tf.lite.Interpreter(model_content=quantized_tflite_model)
interpreter.allocate_tensors()
test_accuracy = evaluate_model(interpreter)
print('Quant TFLite test_accuracy:', test_accuracy)
print('Quant TF test accuracy:', q_aware_model_accuracy)
Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.

Quant TFLite test_accuracy: 0.966
Quant TF test accuracy: 0.9660000205039978
See 4x smaller model from quantization
You create a float TFLite model and then see that the quantized TFLite model is 4x smaller.
# Create float TFLite model.
float_converter = tf.lite.TFLiteConverter.from_keras_model(model)
float_tflite_model = float_converter.convert()
# Measure sizes of models.
_, float_file = tempfile.mkstemp('.tflite')
_, quant_file = tempfile.mkstemp('.tflite')
with open(quant_file, 'wb') as f:
  f.write(quantized_tflite_model)

with open(float_file, 'wb') as f:
  f.write(float_tflite_model)
print("Float model in Mb:", os.path.getsize(float_file) / float(2**20))
print("Quantized model in Mb:", os.path.getsize(quant_file) / float(2**20))
INFO:tensorflow:Assets written to: /tmp/tmpqcn0s43i/assets
Float model in Mb: 0.08063507080078125
Quantized model in Mb: 0.0235137939453125
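The roughly 4x ratio follows directly from the storage format: the model's roughly 20K weights take 4 bytes each as float32 but 1 byte each as int8, and weight storage dominates this small model. A back-of-the-envelope check (a sketch):

# Estimate weight storage in bytes for float32 vs. int8.
num_params = model.count_params()
print('Approx. float32 weight bytes:', num_params * 4)
print('Approx. int8 weight bytes:', num_params)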
Conclusion
In this tutorial, you saw how to create quantization aware models with the TensorFlow Model Optimization Toolkit API and then quantize them for the TFLite backend.
You saw a 4x model size compression benefit for a model for MNIST, with minimal accuracy difference. To see the latency benefits on mobile, try out the TFLite examples in the TFLite app repository.
We encourage you to try this new capability, which can be particularly important for deployment in resource-constrained environments.