Inspecting Quantization Errors with Quantization Debugger


Although full-integer quantization improves model size and latency, the quantized model won't always work as expected. Model quality (e.g. accuracy, mAP, WER) is usually expected to be slightly lower than that of the original float model. However, in some cases the quality can fall below your expectations, or the model can generate completely wrong results.

When this happens, it can be tricky and painful to spot the root cause of the quantization error, and even more difficult to fix it. To assist with this inspection process, the quantization debugger can be used to identify problematic layers, and selective quantization can leave those layers in float so that model accuracy can be recovered at the cost of a reduced benefit from quantization.

Quantization Debugger

The quantization debugger makes it possible to analyze quantization quality metrics for an existing model. It automates running the model on a debug dataset and collecting quantization quality metrics for each tensor.

Prerequisites

If you already have a pipeline to quantize a model, you have all necessary pieces to run quantization debugger!

  • Model to quantize
  • Representative dataset

In addition to model and data, you will need to use a data processing framework (e.g. pandas, Google Sheets) to analyze the exported results.

Setup

This section prepares the libraries, the MobileNet v3 model, and a test dataset of 100 images.

# Quantization debugger is available from TensorFlow 2.7.0
pip uninstall -y tensorflow
pip install tf-nightly
pip install tensorflow_datasets --upgrade  # imagenet_v2 needs latest checksum
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_hub as hub

Boilerplates and helpers
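The definitions behind model, ds, representative_dataset, and eval_tflite are not included in this text. As a rough sketch of the two helpers, assuming images are available as float arrays and that eval_tflite reports top-5 accuracy: a representative dataset is a zero-argument callable whose generator yields one list of input arrays per example, and top-5 accuracy counts how often the label is among the five highest logits. The NumPy-only input format and the name top5_accuracy are our assumptions, not the notebook's code.

```python
import numpy as np

def representative_dataset(images):
    """Sketch: wrap an iterable of HxWxC images as a generator factory.

    TFLiteConverter.representative_dataset expects a zero-argument callable
    whose generator yields, per example, a list with one array per model input.
    """
    def _data_gen():
        for image in images:
            # Add a batch dimension of 1 and make sure the dtype is float32.
            yield [np.expand_dims(np.asarray(image, dtype=np.float32), 0)]
    return _data_gen

def top5_accuracy(logits, labels):
    """Fraction of rows whose label is among the five largest logits."""
    top5 = np.argsort(logits, axis=1)[:, -5:]
    hits = np.any(top5 == np.asarray(labels).reshape(-1, 1), axis=1)
    return float(np.mean(hits))

# Usage with toy data: three 4x4 RGB "images" and 10-class logits.
images = [np.zeros((4, 4, 3)) for _ in range(3)]
first = next(representative_dataset(images)())
print(first[0].shape, first[0].dtype)  # (1, 4, 4, 3) float32

logits = np.tile(np.arange(10, dtype=np.float32), (2, 1))  # classes 5..9 rank highest
print(top5_accuracy(logits, [9, 0]))  # 0.5
```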

test_ds = ds.map(lambda data: (data['image'], data['label'] + 1)).batch(16)
loss, acc = model.evaluate(test_ds)
print(f'Top-5 accuracy (float): {acc * 100:.2f}%')
eval_tflite(quantized_model, ds)

We can see that the original model has a much higher top-5 accuracy for our small dataset, while the quantized model has a significant accuracy loss.

Step 1. Debugger preparation

The easiest way to use the quantization debugger is to provide the tf.lite.TFLiteConverter that you have been using to quantize the model.

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset(ds)

# The debug dataset should have the same format as the representative dataset.
debugger = tf.lite.experimental.QuantizationDebugger(
    converter=converter, debug_dataset=representative_dataset(ds))
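Since the debug dataset should match the representative dataset's format, a quick sanity check is to compare the first example each one yields. The sketch below works on plain generator factories; check_same_format and the toy factories are hypothetical helpers, not part of the TFLite API, and only the first example is inspected.

```python
import numpy as np

def first_sample_signature(dataset_fn):
    """Return (shape, dtype) for each input of the first yielded example."""
    sample = next(iter(dataset_fn()))
    return [(np.shape(x), np.asarray(x).dtype) for x in sample]

def check_same_format(representative_fn, debug_fn):
    """Raise ValueError if the two datasets disagree on input shapes or dtypes."""
    rep = first_sample_signature(representative_fn)
    dbg = first_sample_signature(debug_fn)
    if rep != dbg:
        raise ValueError(f'format mismatch: {rep} vs {dbg}')
    return rep

# Toy generator factories standing in for representative_dataset(ds).
def rep_fn():
    yield [np.zeros((1, 8, 8, 3), dtype=np.float32)]

def debug_fn():
    yield [np.ones((1, 8, 8, 3), dtype=np.float32)]

print(check_same_format(rep_fn, debug_fn))
```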

Step 2. Running the debugger and getting the results

When you call QuantizationDebugger.run(), the debugger logs the differences between float tensors and quantized tensors at the same op location, and processes them with the given metrics.

debugger.run()
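Conceptually, for every quantized op the debugger takes the difference between the quantized output and the corresponding float output, reduces it to one number per metric, and averages each metric across debug examples. The NumPy sketch below mimics that flow with the five default metrics described under Custom metrics. The metric keys other than mean_squared_error (which appears in the dump), the exact formulas, and the sign of the difference are our assumptions, not the library's source.

```python
import numpy as np

# Illustrative default metrics: each reduces one example's difference tensor
# to a single float. Key names other than mean_squared_error are assumptions.
DEFAULT_METRICS = {
    'num_elements': lambda diff: float(diff.size),
    'stddev': lambda diff: float(np.std(diff)),
    'mean_error': lambda diff: float(np.mean(diff)),
    'max_abs_error': lambda diff: float(np.max(np.abs(diff))),
    'mean_squared_error': lambda diff: float(np.mean(diff ** 2)),
}

def layer_metrics(float_outputs, dequant_outputs, metrics=DEFAULT_METRICS):
    """Compute every metric per example, then average each one over examples."""
    per_example = {name: [] for name in metrics}
    for f, q in zip(float_outputs, dequant_outputs):
        diff = q - f  # sign convention is an assumption
        for name, fn in metrics.items():
            per_example[name].append(fn(diff))
    return {name: float(np.mean(vals)) for name, vals in per_example.items()}

# Toy data: "quantized" outputs simulated by snapping float outputs to a grid.
rng = np.random.default_rng(0)
scale = 0.05
float_outputs = [rng.normal(size=(1, 8)) for _ in range(4)]
dequant_outputs = [np.round(f / scale) * scale for f in float_outputs]
stats = layer_metrics(float_outputs, dequant_outputs)
print(stats)
```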

The processed metrics can be accessed with QuantizationDebugger.layer_statistics, or can be dumped to a text file in CSV format with QuantizationDebugger.layer_statistics_dump().

RESULTS_FILE = '/tmp/debugger_results.csv'
with open(RESULTS_FILE, 'w') as f:
  debugger.layer_statistics_dump(f)
head /tmp/debugger_results.csv

For each row in the dump, the op name and index come first, followed by quantization parameters and error metrics (including user-defined error metrics, if any). The resulting CSV file can be used to pick out problematic layers with large quantization error metrics.

With pandas or other data processing libraries, we can inspect detailed per-layer error metrics.

layer_stats = pd.read_csv(RESULTS_FILE)
layer_stats.head()
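If pandas isn't available, the standard library csv module is enough to rank layers by an error metric. The sketch below assumes only the op_name, tensor_name, scale, and mean_squared_error columns that this tutorial already uses; top_layers is a hypothetical helper, and the sample rows are made up.

```python
import csv
import io

def top_layers(csv_file, metric='mean_squared_error', k=5):
    """Return the k rows with the largest value of `metric`, largest first."""
    rows = list(csv.DictReader(csv_file))
    rows.sort(key=lambda row: float(row[metric]), reverse=True)
    return rows[:k]

# Toy dump with a subset of the columns used in this tutorial.
sample = io.StringIO(
    'op_name,tensor_name,scale,mean_squared_error\n'
    'CONV_2D,conv_a,0.02,0.0001\n'
    'MEAN,mean_b,0.5,0.09\n'
    'ADD,add_c,0.1,0.002\n'
)
for row in top_layers(sample, k=2):
    print(row['op_name'], row['tensor_name'], row['mean_squared_error'])
# MEAN mean_b 0.09
# ADD add_c 0.002
```

With the real dump, open RESULTS_FILE with `open(RESULTS_FILE, newline='')` and pass the file object to top_layers.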

Step 3. Data analysis

There are various ways to analyze the results. First, let's add some useful metrics derived from the debugger's outputs (scale is the quantization scale factor of each tensor):

  • Range (255 * scale)
  • RMSE / scale (sqrt(mean_squared_error) / scale)

The RMSE / scale is close to 1 / sqrt(12) (~ 0.289) when the quantized distribution is similar to the original float distribution, indicating a well-quantized layer. The larger the value, the more likely it is that the layer is not quantized well.
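The 1 / sqrt(12) figure comes from rounding error: if values fall uniformly within a quantization step of width scale, the rounding error is uniform on [-scale/2, scale/2], whose mean square is scale² / 12, so RMSE / scale = 1 / sqrt(12). The NumPy sketch below checks this numerically with a hand-rolled int8 affine quantizer (an illustrative stand-in for TFLite's kernels) on uniform values that stay inside the representable range.

```python
import numpy as np

def quantize(x, scale, zero_point):
    """Affine int8 quantization: q = clip(round(x / scale) + zero_point)."""
    q = np.round(x / scale) + zero_point
    return np.clip(q, -128, 127).astype(np.int8)

def dequantize(q, scale, zero_point):
    """Map int8 values back to float: x ~ (q - zero_point) * scale."""
    return (q.astype(np.float32) - zero_point) * scale

rng = np.random.default_rng(0)
x = rng.uniform(-1.0, 1.0, size=100_000).astype(np.float32)
scale, zero_point = 1.0 / 127, 0  # [-1, 1] fits without clipping

diff = dequantize(quantize(x, scale, zero_point), scale, zero_point) - x
rmse_over_scale = float(np.sqrt(np.mean(diff ** 2)) / scale)
print(f'range      = {255 * scale:.3f}')      # same formula as layer_stats['range']
print(f'rmse/scale = {rmse_over_scale:.3f}')  # close to 1/sqrt(12) ~ 0.289
```

Values much larger than 0.289 suggest error beyond plain rounding, for example clipping of outliers or a range too wide for the tensor's actual distribution.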

layer_stats['range'] = 255.0 * layer_stats['scale']
layer_stats['rmse/scale'] = layer_stats.apply(
    lambda row: np.sqrt(row['mean_squared_error']) / row['scale'], axis=1)
layer_stats[['op_name', 'range', 'rmse/scale']].head()
plt.figure(figsize=(15, 5))
ax1 = plt.subplot(121)
ax1.bar(np.arange(len(layer_stats)), layer_stats['range'])
ax1.set_ylabel('range')
ax2 = plt.subplot(122)
ax2.bar(np.arange(len(layer_stats)), layer_stats['rmse/scale'])
ax2.set_ylabel('rmse/scale')
plt.show()

There are many layers with wide ranges, and some layers that have high RMSE/scale values. Let's get the layers with high error metrics.

layer_stats[layer_stats['rmse/scale'] > 0.7][[
    'op_name', 'range', 'rmse/scale', 'tensor_name'
]]

With these layers, you can try selective quantization to see if not quantizing those layers improves model quality.

suspected_layers = list(
    layer_stats[layer_stats['rmse/scale'] > 0.7]['tensor_name'])

In addition to these, skipping quantization for the first few layers also helps improve the quantized model's quality.

suspected_layers.extend(list(layer_stats[:5]['tensor_name']))
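A tensor can meet both criteria, so suspected_layers may list the same name twice. If you want a clean list to inspect or log, a small pure-Python helper (hypothetical, not part of the debugger) can merge the selections while preserving their order. The tensor names below are made up.

```python
def merge_denylists(*selections):
    """Concatenate node-name lists, dropping repeats but keeping first-seen order."""
    merged = []
    for selection in selections:
        merged.extend(selection)
    # dict preserves insertion order, so this keeps the first occurrence.
    return list(dict.fromkeys(merged))

high_error = ['block_3/add', 'head/mean']  # hypothetical tensor names
first_layers = ['input_conv', 'block_1/conv', 'block_3/add']
print(merge_denylists(high_error, first_layers))
# ['block_3/add', 'head/mean', 'input_conv', 'block_1/conv']
```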

Selective Quantization

Selective quantization skips quantization for some nodes, so that the calculation can happen in the original floating-point domain. When correct layers are skipped, we can expect some model quality recovery at the cost of increased latency and model size.

However, if you're planning to run quantized models on integer-only accelerators (e.g. Hexagon DSP, EdgeTPU), selective quantization would fragment the model and result in higher inference latency, mainly due to the cost of transferring data between the CPU and those accelerators. To prevent this, you can consider running quantization-aware training to keep all the layers in integer while preserving model accuracy.

QuantizationDebugOptions accepts the denylisted_nodes and denylisted_ops options for skipping quantization of specific layers, or of all instances of specific ops. Using the suspected_layers list prepared in the previous step, we can use the quantization debugger to get a selectively quantized model.

debug_options = tf.lite.experimental.QuantizationDebugOptions(
    denylisted_nodes=suspected_layers)
debugger = tf.lite.experimental.QuantizationDebugger(
    converter=converter,
    debug_dataset=representative_dataset(ds),
    debug_options=debug_options)
selective_quantized_model = debugger.get_nondebug_quantized_model()
eval_tflite(selective_quantized_model, ds)

The accuracy is still lower than that of the original float model, but skipping quantization for ~10 of the 111 layers gives a notable improvement over the fully quantized model.

You can also skip quantization for all ops of the same type. For example, to skip quantization for all mean ops, pass MEAN to denylisted_ops.

debug_options = tf.lite.experimental.QuantizationDebugOptions(
    denylisted_ops=['MEAN'])
debugger = tf.lite.experimental.QuantizationDebugger(
    converter=converter,
    debug_dataset=representative_dataset(ds),
    debug_options=debug_options)
selective_quantized_model = debugger.get_nondebug_quantized_model()
eval_tflite(selective_quantized_model, ds)
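The two options differ in granularity: denylisted_nodes names individual tensors (the tensor_name column of the dump), while denylisted_ops names op types such as MEAN and therefore matches every instance. The pure-Python sketch below models which entries would stay in float if both options were given. Treating the options as a union is our assumption rather than documented behavior, and the op/tensor pairs are hypothetical.

```python
def stays_float(op_name, tensor_name, denylisted_nodes=(), denylisted_ops=()):
    """Sketch: an op is skipped if its node is listed or its op type is listed."""
    return tensor_name in set(denylisted_nodes) or op_name in set(denylisted_ops)

# Hypothetical (op_name, tensor_name) pairs as they might appear in the dump.
ops = [('CONV_2D', 'conv_a'), ('MEAN', 'pool_b'), ('ADD', 'add_c'), ('MEAN', 'pool_d')]
skipped = [tensor for op, tensor in ops
           if stays_float(op, tensor, denylisted_nodes=['add_c'], denylisted_ops=['MEAN'])]
print(skipped)  # ['pool_b', 'add_c', 'pool_d']
```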

With these techniques, we are able to improve the quantized MobileNet V3 model accuracy. Next we'll explore advanced techniques to improve the model accuracy even more.

Advanced usage

With the following features, you can further customize your debugging pipeline.

Custom metrics

By default, the quantization debugger emits five metrics for each float-quant difference: tensor size, standard deviation, mean error, max absolute error, and mean squared error. You can add custom metrics by passing them to the debug options. Each metric should produce a single float value, and the reported metric is the average over all examples.

  • layer_debug_metrics: calculates a metric from the difference between the float and quantized outputs of each op.
  • layer_direct_compare_metrics: rather than using only the difference, calculates a metric from the raw float and quantized tensors and their quantization parameters (scale, zero point).
  • model_debug_metrics: only used when float_model_(path|content) is passed to the debugger. In addition to the op-level metrics, the final model output is compared to the reference output from the original float model.

debug_options = tf.lite.experimental.QuantizationDebugOptions(
    layer_debug_metrics={
        'mean_abs_error': (lambda diff: np.mean(np.abs(diff)))
    },
    layer_direct_compare_metrics={
        # Dequantize the raw quantized tensor as (q - zp) * s before correlating.
        'correlation':
            lambda f, q, s, zp: (np.corrcoef(f.flatten(),
                                             (q.flatten() - zp) * s)[0, 1])
    },
    model_debug_metrics={
        'argmax_accuracy': (lambda f, q: np.mean(np.argmax(f) == np.argmax(q)))
    })

debugger = tf.lite.experimental.QuantizationDebugger(
    converter=converter,
    debug_dataset=representative_dataset(ds),
    debug_options=debug_options)
debugger.run()
CUSTOM_RESULTS_FILE = '/tmp/debugger_custom_results.csv'
with open(CUSTOM_RESULTS_FILE, 'w') as f:
  debugger.layer_statistics_dump(f)

custom_layer_stats = pd.read_csv(CUSTOM_RESULTS_FILE)
custom_layer_stats[['op_name', 'mean_abs_error', 'correlation']].tail()
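Metric callables run inside the debugger, so a typo or a wrong shape assumption only shows up mid-run. One option is to call them on small synthetic tensors first. The sketch below restates the correlation and argmax_accuracy metrics from above as named functions and feeds them a toy tensor quantized with a hand-rolled int8 scheme (zero point 0, an illustrative assumption).

```python
import numpy as np

def correlation(f, q, s, zp):
    """Direct-compare metric: correlate float values with dequantized values."""
    dequantized = (q.astype(np.float32).flatten() - zp) * s
    return float(np.corrcoef(f.flatten(), dequantized)[0, 1])

def argmax_accuracy(f, q):
    """Model metric: 1.0 when float and quantized outputs agree on the top class."""
    return float(np.mean(np.argmax(f) == np.argmax(q)))

# Toy int8 quantization with zero point 0 (an illustrative assumption).
rng = np.random.default_rng(0)
f = rng.normal(size=(1, 16)).astype(np.float32)
s, zp = float(np.abs(f).max()) / 127, 0
q = np.clip(np.round(f / s) + zp, -128, 127).astype(np.int8)

print(correlation(f, q, s, zp))  # close to 1.0
print(argmax_accuracy(f, f))     # 1.0
```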

The results of model_debug_metrics are reported separately in debugger.model_statistics.

debugger.model_statistics

Using (internal) mlir_quantize API to access in-depth features

from tensorflow.lite.python import convert

Whole model verify mode

The default behavior for debug model generation is per-layer verify. In this mode, the input for each float and quantized op pair comes from the same source (the previous quantized op). The other mode is whole-model verify, where the float and quantized models run separately. This mode is useful for observing how the error propagates down the model. To enable it, pass enable_whole_model_verify=True to convert.mlir_quantize when generating the debug model manually.

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.representative_dataset = representative_dataset(ds)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter._experimental_calibrate_only = True
calibrated_model = converter.convert()
# Note that enable_numeric_verify and enable_whole_model_verify are set.
quantized_model = convert.mlir_quantize(
    calibrated_model,
    enable_numeric_verify=True,
    enable_whole_model_verify=True)
debugger = tf.lite.experimental.QuantizationDebugger(
    quant_debug_model_content=quantized_model,
    debug_dataset=representative_dataset(ds))
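To see why the two modes can report different numbers, the NumPy toy below runs a small stack of dense+ReLU layers with a fake-quantization function (an illustrative assumption, not TFLite's kernels). In per-layer mode each float reference op consumes the same quantized input as its quantized counterpart, so the error reflects that op alone; in whole-model mode the float reference runs on its own float path, so each layer's error also includes what earlier layers introduced.

```python
import numpy as np

def fake_quant(x, scale):
    """Snap values to an int8 grid with zero point 0 and return them as floats."""
    return np.clip(np.round(x / scale), -128, 127) * scale

def dense_relu(x, w):
    return np.maximum(x @ w, 0.0)

def verify_errors(x, weights, scale):
    """Return max-abs errors per layer under per-layer and whole-model verify."""
    per_layer, whole_model = [], []
    float_act = x                     # whole-model: float path never sees quantized values
    quant_act = fake_quant(x, scale)  # input to the quantized path
    for w in weights:
        quant_out = fake_quant(dense_relu(quant_act, fake_quant(w, scale)), scale)
        # Per-layer verify: the float op consumes the same quantized input.
        per_layer.append(float(np.abs(dense_relu(quant_act, w) - quant_out).max()))
        # Whole-model verify: compare against an independent float path.
        float_act = dense_relu(float_act, w)
        whole_model.append(float(np.abs(float_act - quant_out).max()))
        quant_act = quant_out
    return per_layer, whole_model

rng = np.random.default_rng(0)
weights = [rng.normal(scale=0.4, size=(16, 16)) for _ in range(4)]
x = rng.normal(size=(1, 16))
per_layer, whole_model = verify_errors(x, weights, scale=0.05)
print('per-layer  :', np.round(per_layer, 4))
print('whole-model:', np.round(whole_model, 4))
```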

Selective quantization from an already calibrated model

You can directly call convert.mlir_quantize to get a selectively quantized model from an already calibrated model. This is particularly useful when you want to calibrate the model once and experiment with various denylist combinations.

selective_quantized_model = convert.mlir_quantize(
    calibrated_model, denylisted_nodes=suspected_layers)
eval_tflite(selective_quantized_model, ds)
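Since the point is to calibrate once, one way to compare several denylists is to loop over them against the same calibrated_model. The helper below takes the quantization and evaluation functions as arguments, so you could pass convert.mlir_quantize as used above together with an evaluation function of your own that returns a score (eval_tflite in this tutorial prints its result instead). sweep_denylists, fake_quantize, and fake_evaluate are hypothetical names; the stand-ins let the sketch run without TensorFlow.

```python
def sweep_denylists(calibrated_model, denylists, quantize_fn, evaluate_fn):
    """Quantize one calibrated model once per denylist and collect a score for each.

    quantize_fn: e.g. convert.mlir_quantize, called with denylisted_nodes=...
    evaluate_fn: hypothetical callable that returns a number for a model.
    """
    scores = {}
    for name, nodes in denylists.items():
        model = quantize_fn(calibrated_model, denylisted_nodes=list(nodes))
        scores[name] = evaluate_fn(model)
    return scores

# Stand-ins so the sketch runs without TensorFlow.
def fake_quantize(model, denylisted_nodes):
    return {'base': model, 'float_nodes': denylisted_nodes}

def fake_evaluate(model):
    return len(model['float_nodes'])  # placeholder score, not a real accuracy

scores = sweep_denylists(b'calibrated', {'none': [], 'suspected': ['a', 'b']},
                         fake_quantize, fake_evaluate)
print(scores)  # {'none': 0, 'suspected': 2}
```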