Welcome to the guide on Keras weight pruning for improving the latency of on-device inference via XNNPACK.
This guide presents the usage of the newly introduced tfmot.sparsity.keras.PruningPolicy API and demonstrates how it can be used to accelerate mostly convolutional models on modern CPUs using XNNPACK sparse inference.
The guide covers the following steps of the model creation process:
- Build and train the dense baseline
- Fine-tune model with pruning
- Convert to TFLite
- On-device benchmark
The guide doesn't cover best practices for fine-tuning with pruning. For more detailed information on this topic, please check out our comprehensive guide.
Setup
pip install -q tensorflow
pip install -q tensorflow-model-optimization
import tempfile
import tensorflow as tf
import numpy as np
from tensorflow import keras
import tensorflow_datasets as tfds
import tensorflow_model_optimization as tfmot
%load_ext tensorboard
Build and train the dense model
We build and train a simple baseline CNN for classification task on CIFAR10 dataset.
# Load CIFAR10 dataset.
(ds_train, ds_val, ds_test), ds_info = tfds.load(
'cifar10',
split=['train[:90%]', 'train[90%:]', 'test'],
as_supervised=True,
with_info=True,
)
# Normalize the input image so that each pixel value is between 0 and 1.
def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.image.convert_image_dtype(image, tf.float32), label
# Load the data in batches of 128 images.
batch_size = 128
def prepare_dataset(ds, buffer_size=None):
  ds = ds.map(normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  ds = ds.cache()
  if buffer_size:
    ds = ds.shuffle(buffer_size)
  ds = ds.batch(batch_size)
  ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
  return ds
ds_train = prepare_dataset(ds_train,
buffer_size=ds_info.splits['train'].num_examples)
ds_val = prepare_dataset(ds_val)
ds_test = prepare_dataset(ds_test)
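As an optional sanity check, we can peek at one prepared batch to confirm the shapes and dtypes the model will see (the expected values are noted as comments):
# Optional sanity check: inspect one batch from the prepared training pipeline.
images, labels = next(iter(ds_train))
print('images:', images.shape, images.dtype)  # expected: (128, 32, 32, 3) float32
print('labels:', labels.shape, labels.dtype)  # expected: (128,) int64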
# Build the dense baseline model.
dense_model = keras.Sequential([
keras.layers.InputLayer(input_shape=(32, 32, 3)),
keras.layers.ZeroPadding2D(padding=1),
keras.layers.Conv2D(
filters=8,
kernel_size=(3, 3),
strides=(2, 2),
padding='valid'),
keras.layers.BatchNormalization(),
keras.layers.ReLU(),
keras.layers.DepthwiseConv2D(kernel_size=(3, 3), padding='same'),
keras.layers.BatchNormalization(),
keras.layers.ReLU(),
keras.layers.Conv2D(filters=16, kernel_size=(1, 1)),
keras.layers.BatchNormalization(),
keras.layers.ReLU(),
keras.layers.ZeroPadding2D(padding=1),
keras.layers.DepthwiseConv2D(
kernel_size=(3, 3), strides=(2, 2), padding='valid'),
keras.layers.BatchNormalization(),
keras.layers.ReLU(),
keras.layers.Conv2D(filters=32, kernel_size=(1, 1)),
keras.layers.BatchNormalization(),
keras.layers.ReLU(),
keras.layers.GlobalAveragePooling2D(),
keras.layers.Flatten(),
keras.layers.Dense(10)
])
# Compile and train the dense model for 10 epochs.
dense_model.compile(
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer='adam',
metrics=['accuracy'])
dense_model.fit(
ds_train,
epochs=10,
validation_data=ds_val)
# Evaluate the dense model.
_, dense_model_accuracy = dense_model.evaluate(ds_test, verbose=0)
2022-12-14 12:32:41.944099: E tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:267] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
Epoch 1/10
352/352 [==============================] - 14s 24ms/step - loss: 1.9439 - accuracy: 0.2778 - val_loss: 1.9882 - val_accuracy: 0.2444
Epoch 2/10
352/352 [==============================] - 6s 18ms/step - loss: 1.6829 - accuracy: 0.3774 - val_loss: 1.8590 - val_accuracy: 0.3166
Epoch 3/10
352/352 [==============================] - 6s 18ms/step - loss: 1.5914 - accuracy: 0.4163 - val_loss: 1.8210 - val_accuracy: 0.3432
Epoch 4/10
352/352 [==============================] - 6s 18ms/step - loss: 1.5381 - accuracy: 0.4406 - val_loss: 1.8508 - val_accuracy: 0.3128
Epoch 5/10
352/352 [==============================] - 6s 18ms/step - loss: 1.5026 - accuracy: 0.4527 - val_loss: 1.7111 - val_accuracy: 0.3870
Epoch 6/10
352/352 [==============================] - 6s 18ms/step - loss: 1.4753 - accuracy: 0.4648 - val_loss: 1.6088 - val_accuracy: 0.4136
Epoch 7/10
352/352 [==============================] - 6s 18ms/step - loss: 1.4568 - accuracy: 0.4734 - val_loss: 1.7585 - val_accuracy: 0.3810
Epoch 8/10
352/352 [==============================] - 6s 18ms/step - loss: 1.4368 - accuracy: 0.4793 - val_loss: 1.4221 - val_accuracy: 0.4836
Epoch 9/10
352/352 [==============================] - 6s 18ms/step - loss: 1.4178 - accuracy: 0.4888 - val_loss: 1.8803 - val_accuracy: 0.3684
Epoch 10/10
352/352 [==============================] - 6s 18ms/step - loss: 1.4023 - accuracy: 0.4925 - val_loss: 1.6101 - val_accuracy: 0.4168
Build the sparse model
Using the instructions from the comprehensive guide, we apply the tfmot.sparsity.keras.prune_low_magnitude function with parameters that target on-device acceleration via pruning, i.e. the tfmot.sparsity.keras.PruneForLatencyOnXNNPack policy.
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude
# Compute the end step to finish pruning after 5 epochs.
end_epoch = 5
num_iterations_per_epoch = len(ds_train)
end_step = num_iterations_per_epoch * end_epoch
# Define parameters for pruning.
pruning_params = {
'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.25,
final_sparsity=0.75,
begin_step=0,
end_step=end_step),
'pruning_policy': tfmot.sparsity.keras.PruneForLatencyOnXNNPack()
}
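To build intuition for how the schedule ramps up sparsity during fine-tuning, the short standalone sketch below mimics a polynomial decay from 25% to 75% sparsity. It does not call the tfmot API; the cubic exponent matches our understanding of PolynomialDecay's default power=3 and should be treated as an assumption.
# Standalone sketch (not the tfmot implementation): approximate the sparsity level
# at a given training step for a polynomial (assumed cubic) decay schedule.
def approx_sparsity(step, begin=0, end=end_step,
                    initial_sparsity=0.25, final_sparsity=0.75, power=3):
  progress = min(max((step - begin) / (end - begin), 0.0), 1.0)
  return final_sparsity + (initial_sparsity - final_sparsity) * (1.0 - progress) ** power

for fraction in (0.0, 0.25, 0.5, 0.75, 1.0):
  step = int(fraction * end_step)
  print(f'step {step}: ~{approx_sparsity(step):.2f} sparsity')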
# Try to apply pruning wrapper with pruning policy parameter.
try:
  model_for_pruning = prune_low_magnitude(dense_model, **pruning_params)
except ValueError as e:
  print(e)
The call to prune_low_magnitude results in a ValueError with the message Could not find a GlobalAveragePooling2D layer with keepdims = True in all output branches. The message indicates that the model isn't supported for pruning with the tfmot.sparsity.keras.PruneForLatencyOnXNNPack policy; specifically, the GlobalAveragePooling2D layer requires the parameter keepdims = True. Let's fix that and reapply the prune_low_magnitude function.
fixed_dense_model = keras.Sequential([
keras.layers.InputLayer(input_shape=(32, 32, 3)),
keras.layers.ZeroPadding2D(padding=1),
keras.layers.Conv2D(
filters=8,
kernel_size=(3, 3),
strides=(2, 2),
padding='valid'),
keras.layers.BatchNormalization(),
keras.layers.ReLU(),
keras.layers.DepthwiseConv2D(kernel_size=(3, 3), padding='same'),
keras.layers.BatchNormalization(),
keras.layers.ReLU(),
keras.layers.Conv2D(filters=16, kernel_size=(1, 1)),
keras.layers.BatchNormalization(),
keras.layers.ReLU(),
keras.layers.ZeroPadding2D(padding=1),
keras.layers.DepthwiseConv2D(
kernel_size=(3, 3), strides=(2, 2), padding='valid'),
keras.layers.BatchNormalization(),
keras.layers.ReLU(),
keras.layers.Conv2D(filters=32, kernel_size=(1, 1)),
keras.layers.BatchNormalization(),
keras.layers.ReLU(),
keras.layers.GlobalAveragePooling2D(keepdims=True),
keras.layers.Flatten(),
keras.layers.Dense(10)
])
# Use the pretrained model for pruning instead of training from scratch.
fixed_dense_model.set_weights(dense_model.get_weights())
# Try to reapply pruning wrapper.
model_for_pruning = prune_low_magnitude(fixed_dense_model, **pruning_params)
The invocation of prune_low_magnitude has finished without any errors, meaning that the model is fully supported by the tfmot.sparsity.keras.PruneForLatencyOnXNNPack policy and can be accelerated using XNNPACK sparse inference.
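As an aside, rebuilding the model by hand is not the only way to apply the keepdims fix: the sketch below uses keras.models.clone_model with a custom clone_function to patch the GlobalAveragePooling2D layer programmatically. The helper name replace_global_pooling is ours; the resulting model can be passed to prune_low_magnitude exactly like fixed_dense_model above.
# Alternative sketch: patch GlobalAveragePooling2D (keepdims=True) while cloning,
# instead of redefining the whole Sequential model by hand.
def replace_global_pooling(layer):
  if isinstance(layer, keras.layers.GlobalAveragePooling2D):
    return keras.layers.GlobalAveragePooling2D(keepdims=True)
  return layer.__class__.from_config(layer.get_config())

cloned_dense_model = keras.models.clone_model(
    dense_model, clone_function=replace_global_pooling)
cloned_dense_model.set_weights(dense_model.get_weights())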
Fine-tune the sparse model
Following the pruning example, we fine-tune the sparse model using the weights of the dense model. We start fine-tuning the model at 25% sparsity (25% of the weights are set to zero) and end at 75% sparsity.
logdir = tempfile.mkdtemp()
callbacks = [
tfmot.sparsity.keras.UpdatePruningStep(),
tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]
model_for_pruning.compile(
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer='adam',
metrics=['accuracy'])
model_for_pruning.fit(
ds_train,
epochs=15,
validation_data=ds_val,
callbacks=callbacks)
# Evaluate the pruned model.
_, pruned_model_accuracy = model_for_pruning.evaluate(ds_test, verbose=0)
print('Dense model test accuracy:', dense_model_accuracy)
print('Pruned model test accuracy:', pruned_model_accuracy)
Epoch 1/15
352/352 [==============================] - 10s 19ms/step - loss: 1.4103 - accuracy: 0.4889 - val_loss: 1.5619 - val_accuracy: 0.4310
Epoch 2/15
352/352 [==============================] - 6s 18ms/step - loss: 1.4365 - accuracy: 0.4812 - val_loss: 1.7069 - val_accuracy: 0.4032
Epoch 3/15
352/352 [==============================] - 6s 18ms/step - loss: 1.4512 - accuracy: 0.4780 - val_loss: 1.6233 - val_accuracy: 0.4114
Epoch 4/15
352/352 [==============================] - 7s 19ms/step - loss: 1.4495 - accuracy: 0.4774 - val_loss: 1.5601 - val_accuracy: 0.4342
Epoch 5/15
352/352 [==============================] - 7s 19ms/step - loss: 1.4360 - accuracy: 0.4830 - val_loss: 1.5823 - val_accuracy: 0.4072
Epoch 6/15
352/352 [==============================] - 7s 18ms/step - loss: 1.4230 - accuracy: 0.4885 - val_loss: 1.4669 - val_accuracy: 0.4564
Epoch 7/15
352/352 [==============================] - 6s 18ms/step - loss: 1.4140 - accuracy: 0.4916 - val_loss: 1.9092 - val_accuracy: 0.3582
Epoch 8/15
352/352 [==============================] - 6s 18ms/step - loss: 1.4054 - accuracy: 0.4966 - val_loss: 1.4328 - val_accuracy: 0.4820
Epoch 9/15
352/352 [==============================] - 6s 18ms/step - loss: 1.4010 - accuracy: 0.4986 - val_loss: 1.4243 - val_accuracy: 0.4848
Epoch 10/15
352/352 [==============================] - 6s 18ms/step - loss: 1.3951 - accuracy: 0.4994 - val_loss: 1.6031 - val_accuracy: 0.4106
Epoch 11/15
352/352 [==============================] - 6s 18ms/step - loss: 1.3905 - accuracy: 0.5025 - val_loss: 1.4120 - val_accuracy: 0.4918
Epoch 12/15
352/352 [==============================] - 6s 18ms/step - loss: 1.3882 - accuracy: 0.5001 - val_loss: 1.5359 - val_accuracy: 0.4378
Epoch 13/15
352/352 [==============================] - 6s 18ms/step - loss: 1.3821 - accuracy: 0.5046 - val_loss: 1.5810 - val_accuracy: 0.4448
Epoch 14/15
352/352 [==============================] - 6s 18ms/step - loss: 1.3805 - accuracy: 0.5042 - val_loss: 1.6039 - val_accuracy: 0.4370
Epoch 15/15
352/352 [==============================] - 6s 18ms/step - loss: 1.3774 - accuracy: 0.5085 - val_loss: 1.4667 - val_accuracy: 0.4734
Dense model test accuracy: 0.4189999997615814
Pruned model test accuracy: 0.48069998621940613
The logs show the progression of sparsity on a per-layer basis.
#docs_infra: no_execute
%tensorboard --logdir={logdir}
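If TensorBoard isn't convenient, the achieved sparsity can also be checked directly by counting zero-valued entries in the kernels of the fine-tuned model. The sketch below assumes the prunable kernels carry 'kernel' in their variable names, which holds for the Conv2D, DepthwiseConv2D, and Dense layers used here.
# Sketch: report the fraction of zero weights in each kernel of the pruned model.
for weight in model_for_pruning.weights:
  if 'kernel' in weight.name:
    values = weight.numpy()
    sparsity = 1.0 - np.count_nonzero(values) / values.size
    print(f'{weight.name}: {sparsity:.0%} zeros')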
After fine-tuning with pruning, the test accuracy shows an improvement (from about 42% to 48%) compared to the dense model. Let's compare the on-device latency using the TFLite benchmark tool.
Model conversion and benchmarking
To convert the pruned model to TFLite, we need to replace the PruneLowMagnitude wrappers with the original layers via the strip_pruning function. Also, since the weights of the pruned model (model_for_pruning) are mostly zeros, we can apply the tf.lite.Optimize.EXPERIMENTAL_SPARSITY optimization to store the resulting TFLite model efficiently. This optimization flag is not required for the dense model.
converter = tf.lite.TFLiteConverter.from_keras_model(dense_model)
dense_tflite_model = converter.convert()
_, dense_tflite_file = tempfile.mkstemp('.tflite')
with open(dense_tflite_file, 'wb') as f:
  f.write(dense_tflite_model)
model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)
converter = tf.lite.TFLiteConverter.from_keras_model(model_for_export)
converter.optimizations = [tf.lite.Optimize.EXPERIMENTAL_SPARSITY]
pruned_tflite_model = converter.convert()
_, pruned_tflite_file = tempfile.mkstemp('.tflite')
with open(pruned_tflite_file, 'wb') as f:
  f.write(pruned_tflite_model)
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op, _jit_compiled_convolution_op, _jit_compiled_convolution_op, _jit_compiled_convolution_op, _jit_compiled_convolution_op while saving (showing 5 of 6). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpsgj89jpl/assets
2022-12-14 12:35:39.499415: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:362] Ignored output_format.
2022-12-14 12:35:39.499471: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:365] Ignored drop_control_dependency.
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op, _jit_compiled_convolution_op, _jit_compiled_convolution_op, _jit_compiled_convolution_op, _jit_compiled_convolution_op while saving (showing 5 of 5). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmp6u10wfgt/assets
2022-12-14 12:35:42.627814: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:362] Ignored output_format.
2022-12-14 12:35:42.627867: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:365] Ignored drop_control_dependency.
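Before benchmarking on the device, it is worth comparing the on-disk footprint of the two converted files. The quick sketch below reports the raw and gzip-compressed sizes; the mostly-zero weights of the pruned model are expected to compress noticeably better.
import gzip

# Compare raw and gzip-compressed sizes of the dense and pruned TFLite files.
def file_sizes(path):
  with open(path, 'rb') as f:
    data = f.read()
  return len(data), len(gzip.compress(data))

for name, path in [('dense', dense_tflite_file), ('pruned', pruned_tflite_file)]:
  raw, compressed = file_sizes(path)
  print(f'{name}: {raw} bytes raw, {compressed} bytes gzipped')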
Following the instructions of TFLite Model Benchmarking Tool, we build the tool, upload it to the Android device together with dense and pruned TFLite models, and benchmark both models on the device.
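The benchmark commands below assume that the benchmark_model binary and both TFLite files have already been pushed to the device; one way to do that from this notebook is sketched here (the local path to the built binary is illustrative).
! adb push /path/to/built/benchmark_model /data/local/tmp
! adb shell chmod +x /data/local/tmp/benchmark_model
! adb push {dense_tflite_file} /data/local/tmp/dense_model.tflite
! adb push {pruned_tflite_file} /data/local/tmp/pruned_model.tflite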
! adb shell /data/local/tmp/benchmark_model \
--graph=/data/local/tmp/dense_model.tflite \
--use_xnnpack=true \
--num_runs=100 \
--num_threads=1
/bin/bash: adb: command not found
! adb shell /data/local/tmp/benchmark_model \
--graph=/data/local/tmp/pruned_model.tflite \
--use_xnnpack=true \
--num_runs=100 \
--num_threads=1
/bin/bash: adb: command not found
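The adb invocations above require a connected Android device, so they don't run in this environment. As a local sanity check (this verifies correctness, not latency), the pruned TFLite model can be exercised with the Python tf.lite.Interpreter on a single test batch — a minimal sketch:
# Sketch: run the pruned TFLite model on one test batch and report its accuracy.
interpreter = tf.lite.Interpreter(model_path=pruned_tflite_file)
interpreter.allocate_tensors()
input_index = interpreter.get_input_details()[0]['index']
output_index = interpreter.get_output_details()[0]['index']

images, labels = next(iter(ds_test))
correct = 0
for image, label in zip(images.numpy(), labels.numpy()):
  interpreter.set_tensor(input_index, image[np.newaxis, ...])
  interpreter.invoke()
  prediction = np.argmax(interpreter.get_tensor(output_index))
  correct += int(prediction == label)
print('Pruned TFLite accuracy on one test batch:', correct / len(labels))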
Benchmarks on a Pixel 4 resulted in an average inference time of 17us for the dense model and 12us for the pruned model. The on-device benchmarks demonstrate a clear 5us, or roughly 30%, improvement in latency even for such a small model. In our experience, larger models based on MobileNetV3 or EfficientNet-lite show similar performance improvements. The speed-up varies based on the relative contribution of 1x1 convolutions to the overall model.
Conclusion
In this tutorial, we showed how to create sparse models for faster on-device performance using the new functionality introduced by the TF MOT API and XNNPACK. These sparse models are smaller and faster than their dense counterparts while retaining or even surpassing their quality.
We encourage you to try this new capability, which can be particularly important for deploying your models on device.