Magnitude-based weight pruning with Keras


Overview

Welcome to the tutorial for weight pruning, part of the TensorFlow Model Optimization Toolkit.

What is weight pruning?

Weight pruning means exactly what the name suggests: eliminating unnecessary values from the weight tensors. In practice, we set neural network parameters' values to zero to remove the low-magnitude connections between the layers of the network.
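For intuition, here is a minimal sketch of magnitude-based pruning applied to a single weight tensor with plain NumPy; it is only an illustration, not the API used later in this tutorial.

import numpy as np

weights = np.random.randn(4, 4).astype(np.float32)
target_sparsity = 0.5  # fraction of weights to prune away

# Zero out the weights with the smallest absolute values.
threshold = np.quantile(np.abs(weights), target_sparsity)
pruned_weights = np.where(np.abs(weights) < threshold, 0.0, weights)
print('Fraction of zeros:', np.mean(pruned_weights == 0.0))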

Why is it useful?

Tensors with many of their values set to zero can be considered sparse. This yields important benefits:

  • Compression. Sparse tensors are amenable to compression: only the non-zero values and their corresponding coordinates need to be stored.
  • Speed. Sparse tensors allow backends to skip otherwise unnecessary computations involving the zero values.
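As a rough, self-contained illustration of the compression benefit (assuming only NumPy and zlib, not any TensorFlow tooling), compare how well a dense tensor and a mostly-zero tensor compress:

import zlib
import numpy as np

dense = np.random.randn(1000, 1000).astype(np.float32)
sparse = dense.copy()
# Zero out the 90% of values with the smallest magnitudes.
sparse[np.abs(sparse) < np.quantile(np.abs(sparse), 0.9)] = 0.0

print('Dense compressed size: %d bytes' % len(zlib.compress(dense.tobytes())))
print('Sparse compressed size: %d bytes' % len(zlib.compress(sparse.tobytes())))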

How does it work?

Our Keras-based weight pruning API is designed to iteratively remove connections based on their magnitude during training. For more details on the usage of the API, please refer to the GitHub page.

In this tutorial, we'll walk you through an end-to-end example of using the weight pruning API on a simple MNIST model. We will show that simply applying a generic file compression algorithm (e.g. zip) reduces the size of the Keras model, and that this size reduction persists after conversion to the TensorFlow Lite format.

Two things worth clarifying:

  • The technique and API are not specific to TensorFlow Lite; we just show their application on the TensorFlow Lite backend, as it covers size-sensitive use cases.
  • By itself, a sparse model will not be faster to execute; sparsity merely enables backends with that capability. In the near future, however, TensorFlow Lite will take advantage of the sparsity to speed up computations.

To recap, in this tutorial we will:

  1. Train an MNIST model with Keras from scratch.
  2. Train a pruned MNIST model with the pruning API.
  3. Compare the sizes of the pruned and non-pruned models after compression.
  4. Convert the pruned model to the TensorFlow Lite format and verify that its accuracy persists.
  5. Show how the pruned model works with other optimization techniques, like post-training quantization.

Setup

To use the pruning API, install the tensorflow-model-optimization package. See the TensorFlow model optimization repo for compatible API versions.

Since you will train a few models in this tutorial, install the tensorflow-gpu package to speed things up. Enable the GPU with Runtime > Change runtime type > Hardware accelerator and make sure GPU is selected.

! pip uninstall -y tensorflow
! pip uninstall -y tf-nightly
! pip install -q -U tensorflow-gpu==1.14.0

! pip install -q tensorflow-model-optimization
WARNING: Skipping tensorflow as it is not installed.
WARNING: Skipping tf-nightly as it is not installed.
%load_ext tensorboard
import tensorboard
import tensorflow as tf
tf.enable_eager_execution()

import tempfile
import zipfile
import os

Prepare the training data

batch_size = 128
num_classes = 10
epochs = 10

# input image dimensions
img_rows, img_cols = 28, 28

# the data, shuffled and split between train and test sets
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

if tf.keras.backend.image_data_format() == 'channels_first':
  x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
  x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
  input_shape = (1, img_rows, img_cols)
else:
  x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
  x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
  input_shape = (img_rows, img_cols, 1)

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# convert class vectors to binary class matrices
y_train = tf.keras.utils.to_categorical(y_train, num_classes)
y_test = tf.keras.utils.to_categorical(y_test, num_classes)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step
x_train shape: (60000, 28, 28, 1)
60000 train samples
10000 test samples

Train a MNIST model without pruning

Build the MNIST model

l = tf.keras.layers

model = tf.keras.Sequential([
    l.Conv2D(
        32, 5, padding='same', activation='relu', input_shape=input_shape),
    l.MaxPooling2D((2, 2), (2, 2), padding='same'),
    l.BatchNormalization(),
    l.Conv2D(64, 5, padding='same', activation='relu'),
    l.MaxPooling2D((2, 2), (2, 2), padding='same'),
    l.Flatten(),
    l.Dense(1024, activation='relu'),
    l.Dropout(0.4),
    l.Dense(num_classes, activation='softmax')
])

model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              (None, 28, 28, 32)        832       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 14, 14, 32)        0         
_________________________________________________________________
batch_normalization (BatchNo (None, 14, 14, 32)        128       
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 14, 14, 64)        51264     
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 7, 7, 64)          0         
_________________________________________________________________
flatten (Flatten)            (None, 3136)              0         
_________________________________________________________________
dense (Dense)                (None, 1024)              3212288   
_________________________________________________________________
dropout (Dropout)            (None, 1024)              0         
_________________________________________________________________
dense_1 (Dense)              (None, 10)                10250     
=================================================================
Total params: 3,274,762
Trainable params: 3,274,698
Non-trainable params: 64
_________________________________________________________________

Train the model to reach an accuracy >99%

Load TensorBoard to monitor the training process

logdir = tempfile.mkdtemp()
print('Writing training logs to ' + logdir)
Writing training logs to /tmpfs/tmp/tmp4hh8qg38
%tensorboard --logdir={logdir}
callbacks = [tf.keras.callbacks.TensorBoard(log_dir=logdir, profile_batch=0)]

model.compile(
    loss=tf.keras.losses.categorical_crossentropy,
    optimizer='adam',
    metrics=['accuracy'])

model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=epochs,
          verbose=1,
          callbacks=callbacks,
          validation_data=(x_test, y_test))
score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])
Train on 60000 samples, validate on 10000 samples
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow/python/ops/math_grad.py:1250: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
Epoch 1/10
60000/60000 [==============================] - 82s 1ms/sample - loss: 0.2115 - acc: 0.9498 - val_loss: 0.1037 - val_acc: 0.9841
Epoch 2/10
60000/60000 [==============================] - 80s 1ms/sample - loss: 0.0468 - acc: 0.9858 - val_loss: 0.0314 - val_acc: 0.9894
Epoch 3/10
60000/60000 [==============================] - 81s 1ms/sample - loss: 0.0323 - acc: 0.9901 - val_loss: 0.0463 - val_acc: 0.9842
Epoch 4/10
60000/60000 [==============================] - 82s 1ms/sample - loss: 0.0259 - acc: 0.9921 - val_loss: 0.0254 - val_acc: 0.9918
Epoch 5/10
60000/60000 [==============================] - 81s 1ms/sample - loss: 0.0201 - acc: 0.9935 - val_loss: 0.0235 - val_acc: 0.9928
Epoch 6/10
60000/60000 [==============================] - 84s 1ms/sample - loss: 0.0184 - acc: 0.9943 - val_loss: 0.0336 - val_acc: 0.9906
Epoch 7/10
60000/60000 [==============================] - 84s 1ms/sample - loss: 0.0153 - acc: 0.9951 - val_loss: 0.0230 - val_acc: 0.9930
Epoch 8/10
60000/60000 [==============================] - 83s 1ms/sample - loss: 0.0157 - acc: 0.9954 - val_loss: 0.0292 - val_acc: 0.9922
Epoch 9/10
60000/60000 [==============================] - 84s 1ms/sample - loss: 0.0148 - acc: 0.9954 - val_loss: 0.0346 - val_acc: 0.9907
Epoch 10/10
60000/60000 [==============================] - 84s 1ms/sample - loss: 0.0124 - acc: 0.9961 - val_loss: 0.0313 - val_acc: 0.9915
Test loss: 0.03128007990108817
Test accuracy: 0.9915

Save the original model for size comparison later

# Backend agnostic way to save/restore models
_, keras_file = tempfile.mkstemp('.h5')
print('Saving model to: ', keras_file)
tf.keras.models.save_model(model, keras_file, include_optimizer=False)
Saving model to:  /tmpfs/tmp/tmpz2zxnsmm.h5

Train a pruned MNIST model

We provide a prune_low_magnitude() API to train models with removed connections. The Keras-based API can be applied at the level of individual layers or to the entire model. We will show both usages in the following sections.

At a high level, the technique works by iteratively removing (i.e. zeroing out) connections between layers, given a schedule and a target sparsity.

For example, a typical configuration will target 75% sparsity by pruning connections every 100 steps (i.e. training batches), starting from step 2,000. For more details on the possible configurations, please refer to the GitHub documentation.
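As a sketch, that configuration could be expressed with the ConstantSparsity schedule from the tensorflow_model_optimization package; treat this block as an assumed illustration of the configuration just described (the tutorial below uses PolynomialDecay instead).

from tensorflow_model_optimization.sparsity import keras as sparsity

pruning_params = {
    'pruning_schedule': sparsity.ConstantSparsity(target_sparsity=0.75,
                                                  begin_step=2000,
                                                  frequency=100)
}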

Build a pruned model layer by layer

In this example, we show how to use the API at the level of layers, and build a pruned MNIST solver model.

In this case, prune_low_magnitude() receives as a parameter the Keras layer whose weights we want to prune.

The function also requires pruning parameters, which configure the pruning algorithm during training. Please refer to our GitHub page for detailed documentation. The parameters used here mean the following:

  1. Sparsity. PolynomialDecay is used across the whole training process. We start at 50% sparsity and gradually train the model to reach 90% sparsity. X% sparsity means that X% of the weight tensor is going to be pruned away (see the sketch after this list).
  2. Schedule. Connections are pruned starting from step 2000 until the end of training, once every 100 steps. The reasoning is that we want to train the model without pruning for a few epochs to reach a reasonable accuracy first, which aids convergence. Furthermore, we give the model time to recover after each pruning step, so pruning does not happen on every step; setting the pruning frequency to 100 accomplishes this.
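For intuition about the gradual schedule, here is a small sketch of how a PolynomialDecay schedule ramps the sparsity target between begin_step and end_step. The cubic form below follows the schedule's documented default exponent of 3; treat it as an illustration, not the library's internal implementation.

def scheduled_sparsity(step, initial_sparsity=0.50, final_sparsity=0.90,
                       begin_step=2000, end_step=6000, power=3):
  if step < begin_step:
    return 0.0  # pruning has not started yet
  progress = min(1.0, float(step - begin_step) / (end_step - begin_step))
  return final_sparsity + (initial_sparsity - final_sparsity) * (
      1.0 - progress) ** power

for step in (2000, 3000, 4000, 5000, 6000):
  print(step, round(scheduled_sparsity(step), 3))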
from tensorflow_model_optimization.sparsity import keras as sparsity

To demonstrate how to save and restore a pruned Keras model, in the following example we first train the model for 10 epochs, save it to disk, and finally restore it and continue training for 2 epochs. With gradual sparsity, the four important parameters are initial_sparsity, final_sparsity, begin_step and end_step. The first three are straightforward. Let's calculate the end step given the number of training examples, the batch size, and the total number of epochs to train.

import numpy as np

epochs = 12
num_train_samples = x_train.shape[0]
end_step = np.ceil(1.0 * num_train_samples / batch_size).astype(np.int32) * epochs
print('End step: ' + str(end_step))
End step: 5628
pruning_params = {
      'pruning_schedule': sparsity.PolynomialDecay(initial_sparsity=0.50,
                                                   final_sparsity=0.90,
                                                   begin_step=2000,
                                                   end_step=end_step,
                                                   frequency=100)
}

pruned_model = tf.keras.Sequential([
    sparsity.prune_low_magnitude(
        l.Conv2D(32, 5, padding='same', activation='relu'),
        input_shape=input_shape,
        **pruning_params),
    l.MaxPooling2D((2, 2), (2, 2), padding='same'),
    l.BatchNormalization(),
    sparsity.prune_low_magnitude(
        l.Conv2D(64, 5, padding='same', activation='relu'), **pruning_params),
    l.MaxPooling2D((2, 2), (2, 2), padding='same'),
    l.Flatten(),
    sparsity.prune_low_magnitude(l.Dense(1024, activation='relu'),
                                 **pruning_params),
    l.Dropout(0.4),
    sparsity.prune_low_magnitude(l.Dense(num_classes, activation='softmax'),
                                 **pruning_params)
])

pruned_model.summary()
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_schedule.py:240: div (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Deprecated in favor of operator or tf.math.divide.
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
prune_low_magnitude_conv2d_2 (None, 28, 28, 32)        1634      
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 14, 14, 32)        0         
_________________________________________________________________
batch_normalization_1 (Batch (None, 14, 14, 32)        128       
_________________________________________________________________
prune_low_magnitude_conv2d_3 (None, 14, 14, 64)        102466    
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 7, 7, 64)          0         
_________________________________________________________________
flatten_1 (Flatten)          (None, 3136)              0         
_________________________________________________________________
prune_low_magnitude_dense_2  (None, 1024)              6423554   
_________________________________________________________________
dropout_1 (Dropout)          (None, 1024)              0         
_________________________________________________________________
prune_low_magnitude_dense_3  (None, 10)                20492     
=================================================================
Total params: 6,548,274
Trainable params: 3,274,698
Non-trainable params: 3,273,576
_________________________________________________________________

Load TensorBoard

logdir = tempfile.mkdtemp()
print('Writing training logs to ' + logdir)
Writing training logs to /tmpfs/tmp/tmpb62y13eb
%tensorboard --logdir={logdir}

Train the model

Pruning starts from step 2000, when the model accuracy is already >98%.

pruned_model.compile(
    loss=tf.keras.losses.categorical_crossentropy,
    optimizer='adam',
    metrics=['accuracy'])

# Add a pruning step callback to peg the pruning step to the optimizer's
# step. Also add a callback to add pruning summaries to tensorboard
callbacks = [
    sparsity.UpdatePruningStep(),
    sparsity.PruningSummaries(log_dir=logdir, profile_batch=0)
]

pruned_model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=10,
          verbose=1,
          callbacks=callbacks,
          validation_data=(x_test, y_test))

score = pruned_model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])
Train on 60000 samples, validate on 10000 samples
Epoch 1/10
60000/60000 [==============================] - 86s 1ms/sample - loss: 0.2192 - acc: 0.9448 - val_loss: 0.1579 - val_acc: 0.9862
Epoch 2/10
60000/60000 [==============================] - 84s 1ms/sample - loss: 0.0460 - acc: 0.9851 - val_loss: 0.0239 - val_acc: 0.9921
Epoch 3/10
60000/60000 [==============================] - 85s 1ms/sample - loss: 0.0308 - acc: 0.9904 - val_loss: 0.0341 - val_acc: 0.9889
Epoch 4/10
60000/60000 [==============================] - 86s 1ms/sample - loss: 0.0240 - acc: 0.9926 - val_loss: 0.0290 - val_acc: 0.9899
Epoch 5/10
60000/60000 [==============================] - 87s 1ms/sample - loss: 0.0188 - acc: 0.9940 - val_loss: 0.0263 - val_acc: 0.9905
Epoch 6/10
60000/60000 [==============================] - 87s 1ms/sample - loss: 0.0154 - acc: 0.9954 - val_loss: 0.0237 - val_acc: 0.9917
Epoch 7/10
60000/60000 [==============================] - 84s 1ms/sample - loss: 0.0143 - acc: 0.9953 - val_loss: 0.0180 - val_acc: 0.9941
Epoch 8/10
60000/60000 [==============================] - 84s 1ms/sample - loss: 0.0134 - acc: 0.9953 - val_loss: 0.0221 - val_acc: 0.9937
Epoch 9/10
60000/60000 [==============================] - 83s 1ms/sample - loss: 0.0135 - acc: 0.9958 - val_loss: 0.0242 - val_acc: 0.9916
Epoch 10/10
60000/60000 [==============================] - 83s 1ms/sample - loss: 0.0125 - acc: 0.9961 - val_loss: 0.0239 - val_acc: 0.9926
Test loss: 0.02387362446846637
Test accuracy: 0.9926

Save and restore the pruned model

Continue training for two epochs:

_, checkpoint_file = tempfile.mkstemp('.h5')
print('Saving pruned model to: ', checkpoint_file)
# save_model() sets include_optimizer to True by default. Spelling it out here
# to highlight it.
tf.keras.models.save_model(pruned_model, checkpoint_file, include_optimizer=True)

with sparsity.prune_scope():
  restored_model = tf.keras.models.load_model(checkpoint_file)

restored_model.fit(x_train, y_train,
                   batch_size=batch_size,
                   epochs=2,
                   verbose=1,
                   callbacks=callbacks,
                   validation_data=(x_test, y_test))

score = restored_model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])
Saving pruned model to:  /tmpfs/tmp/tmp_nnmiay1.h5
Train on 60000 samples, validate on 10000 samples
Epoch 1/2
60000/60000 [==============================] - 83s 1ms/sample - loss: 0.0107 - acc: 0.9964 - val_loss: 0.0202 - val_acc: 0.9934
Epoch 2/2
60000/60000 [==============================] - 81s 1ms/sample - loss: 0.0080 - acc: 0.9974 - val_loss: 0.0190 - val_acc: 0.9946
Test loss: 0.0190146552314136
Test accuracy: 0.9946

In the example above, a few things to note are:

  • When saving the model, include_optimizer must be set to True. We need to preserve the state of the optimizer across training sessions for pruning to work properly.
  • When loading the pruned model, you need the prune_scope() for deserialization.

Strip the pruning wrappers from the pruned model before export for serving

Before exporting a serving model, you need to call the strip_pruning API to strip the pruning wrappers from the model, as they are only needed during training.

final_model = sparsity.strip_pruning(pruned_model)
final_model.summary()
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_2 (Conv2D)            (None, 28, 28, 32)        832       
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 14, 14, 32)        0         
_________________________________________________________________
batch_normalization_1 (Batch (None, 14, 14, 32)        128       
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 14, 14, 64)        51264     
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 7, 7, 64)          0         
_________________________________________________________________
flatten_1 (Flatten)          (None, 3136)              0         
_________________________________________________________________
dense_2 (Dense)              (None, 1024)              3212288   
_________________________________________________________________
dropout_1 (Dropout)          (None, 1024)              0         
_________________________________________________________________
dense_3 (Dense)              (None, 10)                10250     
=================================================================
Total params: 3,274,762
Trainable params: 3,274,698
Non-trainable params: 64
_________________________________________________________________
_, pruned_keras_file = tempfile.mkstemp('.h5')
print('Saving pruned model to: ', pruned_keras_file)

# No need to save the optimizer with the graph for serving.
tf.keras.models.save_model(final_model, pruned_keras_file, include_optimizer=False)
Saving pruned model to:  /tmpfs/tmp/tmpcerod3mg.h5

Compare the size of the unpruned vs. pruned model after compression

_, zip1 = tempfile.mkstemp('.zip') 
with zipfile.ZipFile(zip1, 'w', compression=zipfile.ZIP_DEFLATED) as f:
  f.write(keras_file)
print("Size of the unpruned model before compression: %.2f Mb" % 
      (os.path.getsize(keras_file) / float(2**20)))
print("Size of the unpruned model after compression: %.2f Mb" % 
      (os.path.getsize(zip1) / float(2**20)))

_, zip2 = tempfile.mkstemp('.zip') 
with zipfile.ZipFile(zip2, 'w', compression=zipfile.ZIP_DEFLATED) as f:
  f.write(pruned_keras_file)
print("Size of the pruned model before compression: %.2f Mb" % 
      (os.path.getsize(pruned_keras_file) / float(2**20)))
print("Size of the pruned model after compression: %.2f Mb" % 
      (os.path.getsize(zip2) / float(2**20)))

Size of the unpruned model before compression: 12.52 Mb
Size of the unpruned model after compression: 11.59 Mb
Size of the pruned model before compression: 12.52 Mb
Size of the pruned model after compression: 2.51 Mb

Prune a whole model

The prune_low_magnitude function can also be applied to the entire Keras model.

In this case, the algorithm will be applied to all layers that are amenable to weight pruning (and that the API knows about). Layers that the API knows are not amenable to weight pruning will be ignored, and layers unknown to the API will cause an error.

If your model has layers whose weights the API does not know how to prune, but which are perfectly fine to leave "un-pruned", then apply the API on a per-layer basis instead.

Regarding pruning configuration, the same settings apply to all prunable layers in the model.

Also noteworthy is that pruning doesn't preserve the optimizer associated with the original model. As a result, it is necessary to re-compile the pruned model with a new optimizer.

Before we move forward with the example, let's address the common use case where you already have a serialized pre-trained Keras model to which you would like to apply weight pruning. We will take the original MNIST model trained previously to show how this works. In this case, you start by loading the model into memory like this:

# Load the serialized model
loaded_model = tf.keras.models.load_model(keras_file)
WARNING:tensorflow:No training configuration found in save file: the model was *not* compiled. Compile it manually.

Then you can prune the loaded model and compile the pruned model for training. In this case, training will restart from step 0. Given that the model we loaded has already reached a satisfactory accuracy, we can start pruning immediately. As a result, we set begin_step to 0 here, and only train for another four epochs.

epochs = 4
end_step = np.ceil(1.0 * num_train_samples / batch_size).astype(np.int32) * epochs
print(end_step)

new_pruning_params = {
      'pruning_schedule': sparsity.PolynomialDecay(initial_sparsity=0.50,
                                                   final_sparsity=0.90,
                                                   begin_step=0,
                                                   end_step=end_step,
                                                   frequency=100)
}

new_pruned_model = sparsity.prune_low_magnitude(model, **new_pruning_params)
new_pruned_model.summary()

new_pruned_model.compile(
    loss=tf.keras.losses.categorical_crossentropy,
    optimizer='adam',
    metrics=['accuracy'])
1876
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
prune_low_magnitude_conv2d ( (None, 28, 28, 32)        1634      
_________________________________________________________________
prune_low_magnitude_max_pool (None, 14, 14, 32)        1         
_________________________________________________________________
prune_low_magnitude_batch_no (None, 14, 14, 32)        129       
_________________________________________________________________
prune_low_magnitude_conv2d_1 (None, 14, 14, 64)        102466    
_________________________________________________________________
prune_low_magnitude_max_pool (None, 7, 7, 64)          1         
_________________________________________________________________
prune_low_magnitude_flatten  (None, 3136)              1         
_________________________________________________________________
prune_low_magnitude_dense (P (None, 1024)              6423554   
_________________________________________________________________
prune_low_magnitude_dropout  (None, 1024)              1         
_________________________________________________________________
prune_low_magnitude_dense_1  (None, 10)                20492     
=================================================================
Total params: 6,548,279
Trainable params: 3,274,698
Non-trainable params: 3,273,581
_________________________________________________________________

Load TensorBoard

logdir = tempfile.mkdtemp()
print('Writing training logs to ' + logdir)
Writing training logs to /tmpfs/tmp/tmps0kryv1c
%tensorboard --logdir={logdir}

Train the model for another four epochs

# Add a pruning step callback to peg the pruning step to the optimizer's
# step. Also add a callback to add pruning summaries to tensorboard
callbacks = [
    sparsity.UpdatePruningStep(),
    sparsity.PruningSummaries(log_dir=logdir, profile_batch=0)
]

new_pruned_model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=epochs,
          verbose=1,
          callbacks=callbacks,
          validation_data=(x_test, y_test))

score = new_pruned_model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])
Train on 60000 samples, validate on 10000 samples
Epoch 1/4
60000/60000 [==============================] - 83s 1ms/sample - loss: 0.0127 - acc: 0.9956 - val_loss: 0.0241 - val_acc: 0.9922
Epoch 2/4
60000/60000 [==============================] - 83s 1ms/sample - loss: 0.0175 - acc: 0.9943 - val_loss: 0.0287 - val_acc: 0.9911
Epoch 3/4
60000/60000 [==============================] - 83s 1ms/sample - loss: 0.0212 - acc: 0.9932 - val_loss: 0.0361 - val_acc: 0.9889
Epoch 4/4
60000/60000 [==============================] - 82s 1ms/sample - loss: 0.0168 - acc: 0.9947 - val_loss: 0.0221 - val_acc: 0.9929
Test loss: 0.022073098729321644
Test accuracy: 0.9929

Export the pruned model for serving

final_model = sparsity.strip_pruning(new_pruned_model)
final_model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              (None, 28, 28, 32)        832       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 14, 14, 32)        0         
_________________________________________________________________
batch_normalization (BatchNo (None, 14, 14, 32)        128       
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 14, 14, 64)        51264     
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 7, 7, 64)          0         
_________________________________________________________________
flatten (Flatten)            (None, 3136)              0         
_________________________________________________________________
dense (Dense)                (None, 1024)              3212288   
_________________________________________________________________
dropout (Dropout)            (None, 1024)              0         
_________________________________________________________________
dense_1 (Dense)              (None, 10)                10250     
=================================================================
Total params: 3,274,762
Trainable params: 3,274,698
Non-trainable params: 64
_________________________________________________________________
_, new_pruned_keras_file = tempfile.mkstemp('.h5')
print('Saving pruned model to: ', new_pruned_keras_file)
tf.keras.models.save_model(final_model, new_pruned_keras_file, 
                        include_optimizer=False)
Saving pruned model to:  /tmpfs/tmp/tmpb9ons69j.h5

The model size after compression is the same as that of the model pruned layer-by-layer

_, zip3 = tempfile.mkstemp('.zip')
with zipfile.ZipFile(zip3, 'w', compression=zipfile.ZIP_DEFLATED) as f:
  f.write(new_pruned_keras_file)
print("Size of the pruned model before compression: %.2f Mb" 
      % (os.path.getsize(new_pruned_keras_file) / float(2**20)))
print("Size of the pruned model after compression: %.2f Mb" 
      % (os.path.getsize(zip3) / float(2**20)))
Size of the pruned model before compression: 12.52 Mb
Size of the pruned model after compression: 2.51 Mb

Convert to TensorFlow Lite

Finally, you can convert the pruned model to a format that's runnable on your target backend. TensorFlow Lite is an example format you can use to deploy to mobile devices. To convert to a TensorFlow Lite graph, use the TFLiteConverter as below:

Convert the model with TFLiteConverter

tflite_model_file = '/tmp/sparse_mnist.tflite'
converter = tf.lite.TFLiteConverter.from_keras_model_file(pruned_keras_file)
tflite_model = converter.convert()
with open(tflite_model_file, 'wb') as f:
  f.write(tflite_model)
WARNING:tensorflow:No training configuration found in save file: the model was *not* compiled. Compile it manually.
INFO:tensorflow:Converted 12 variables to const ops.

Size of the TensorFlow Lite model after compression

_, zip_tflite = tempfile.mkstemp('.zip')
with zipfile.ZipFile(zip_tflite, 'w', compression=zipfile.ZIP_DEFLATED) as f:
  f.write(tflite_model_file)
print("Size of the tflite model before compression: %.2f Mb" 
      % (os.path.getsize(tflite_model_file) / float(2**20)))
print("Size of the tflite model after compression: %.2f Mb" 
      % (os.path.getsize(zip_tflite) / float(2**20)))
Size of the tflite model before compression: 12.49 Mb
Size of the tflite model after compression: 2.43 Mb

Evaluate the accuracy of the TensorFlow Lite model

import numpy as np

interpreter = tf.lite.Interpreter(model_path=str(tflite_model_file))
interpreter.allocate_tensors()
input_index = interpreter.get_input_details()[0]["index"]
output_index = interpreter.get_output_details()[0]["index"]

def eval_model(interpreter, x_test, y_test):
  total_seen = 0
  num_correct = 0

  for img, label in zip(x_test, y_test):
    inp = img.reshape((1, 28, 28, 1))
    total_seen += 1
    interpreter.set_tensor(input_index, inp)
    interpreter.invoke()
    predictions = interpreter.get_tensor(output_index)
    if np.argmax(predictions) == np.argmax(label):
      num_correct += 1

    if total_seen % 1000 == 0:
        print("Accuracy after %i images: %f" %
              (total_seen, float(num_correct) / float(total_seen)))

  return float(num_correct) / float(total_seen)

print(eval_model(interpreter, x_test, y_test))
Accuracy after 1000 images: 0.991000
Accuracy after 2000 images: 0.989000
Accuracy after 3000 images: 0.987000
Accuracy after 4000 images: 0.988250
Accuracy after 5000 images: 0.988400
Accuracy after 6000 images: 0.990167
Accuracy after 7000 images: 0.990857
Accuracy after 8000 images: 0.991875
Accuracy after 9000 images: 0.992556
Accuracy after 10000 images: 0.992600
0.9926

Post-training quantize the TensorFlow Lite model

You can combine pruning with other optimization techniques, like post-training quantization. As a recap, post-training quantization converts the weights to 8-bit precision as part of the conversion from a Keras model to TensorFlow Lite's flat buffer format, resulting in roughly a 4x reduction in model size.

In the following example, we take the pruned Keras model, convert it with post-training quantization, check the size reduction, and validate its accuracy.

converter = tf.lite.TFLiteConverter.from_keras_model_file(pruned_keras_file)

converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]

tflite_quant_model = converter.convert()

tflite_quant_model_file = '/tmp/sparse_mnist_quant.tflite'
with open(tflite_quant_model_file, 'wb') as f:
  f.write(tflite_quant_model)
WARNING:tensorflow:No training configuration found in save file: the model was *not* compiled. Compile it manually.
INFO:tensorflow:Converted 12 variables to const ops.
_, zip_tflite = tempfile.mkstemp('.zip')
with zipfile.ZipFile(zip_tflite, 'w', compression=zipfile.ZIP_DEFLATED) as f:
  f.write(tflite_quant_model_file)
print("Size of the tflite model before compression: %.2f Mb" 
      % (os.path.getsize(tflite_quant_model_file) / float(2**20)))
print("Size of the tflite model after compression: %.2f Mb" 
      % (os.path.getsize(zip_tflite) / float(2**20)))
Size of the tflite model before compression: 3.13 Mb
Size of the tflite model after compression: 0.61 Mb

The size of the quantized model is roughly 1/4 of the original one.

interpreter = tf.lite.Interpreter(model_path=str(tflite_quant_model_file))
interpreter.allocate_tensors()
input_index = interpreter.get_input_details()[0]["index"]
output_index = interpreter.get_output_details()[0]["index"]

print(eval_model(interpreter, x_test, y_test))
Accuracy after 1000 images: 0.991000
Accuracy after 2000 images: 0.989000
Accuracy after 3000 images: 0.987667
Accuracy after 4000 images: 0.988750
Accuracy after 5000 images: 0.988600
Accuracy after 6000 images: 0.990333
Accuracy after 7000 images: 0.991000
Accuracy after 8000 images: 0.992000
Accuracy after 9000 images: 0.992667
Accuracy after 10000 images: 0.992600
0.9926

Conclusion

In this tutorial, we showed you how to create sparse models with the TensorFlow Model Optimization Toolkit's weight pruning API. Right now, this allows you to create models that take significantly less space on disk. The resulting model can also be executed more efficiently by backends that skip the zeroed-out computations; in the future, TensorFlow Lite will provide such capabilities.

More specifically, we walked you through an end-to-end example of training a simple MNIST model using the weight pruning API. We showed you how to convert it to the TensorFlow Lite format for mobile deployment, and demonstrated that simple file compression reduced the model size roughly 5x.

We encourage you to try this new capability on your Keras models, which can be particularly important for deployment in resource-constrained environments.