Trim insignificant weights

This document provides an overview of model pruning to help you determine how it fits your use case.

Overview

Magnitude-based weight pruning gradually zeroes out model weights during the training process to achieve model sparsity. Sparse models are easier to compress, and we can skip the zeroes during inference for latency improvements.
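As a rough illustration of the idea, the sketch below zeroes the smallest-magnitude entries of a flat list of weights and raises the target sparsity in stages. It uses plain Python rather than the library itself. The function name `prune_by_magnitude` and the sample weights are hypothetical, and the gradient updates that would run between pruning steps in real training are left out.

```python
def prune_by_magnitude(weights, sparsity):
    """Return a copy of `weights` with the smallest-magnitude fraction set to 0.0."""
    if not 0.0 <= sparsity <= 1.0:
        raise ValueError("sparsity must be in [0, 1]")
    num_to_prune = int(round(sparsity * len(weights)))
    # Indices ordered by absolute value, smallest first.
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned_indices = set(order[:num_to_prune])
    return [0.0 if i in pruned_indices else w for i, w in enumerate(weights)]


# Hypothetical weights for one small layer.
weights = [0.9, -0.05, 0.3, 0.01, -0.7, 0.2, -0.02, 0.6]

# Raise the target sparsity step by step, as a training loop would do
# between gradient updates (the updates themselves are omitted here).
for target in (0.25, 0.5, 0.75):
    weights = prune_by_magnitude(weights, target)

print(weights)  # [0.9, 0.0, 0.0, 0.0, -0.7, 0.0, 0.0, 0.0]
```

Weights that were zeroed in an earlier stage have magnitude zero, so they stay pruned in later stages of this sketch.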

Today, the benefit of this technique comes from model compression: we've seen up to 6x improvements in model compression with minimal loss of accuracy. In the future, framework support for this technique will also provide latency improvements.
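To see why zeroed weights help compression, the following sketch compares the zlib-compressed size of a random float32 weight vector before and after its smallest 80% of entries are set to zero. The data is synthetic and the 80% figure is an arbitrary choice for illustration, so the resulting ratio says nothing about the results reported for real models.

```python
import random
import struct
import zlib

random.seed(0)

# Synthetic stand-in for a layer's weights (an assumption, not real model data).
dense = [random.gauss(0.0, 1.0) for _ in range(10_000)]

# Zero out the 80% of weights with the smallest magnitude.
threshold = sorted(abs(w) for w in dense)[int(0.8 * len(dense))]
sparse = [w if abs(w) >= threshold else 0.0 for w in dense]


def compressed_size(values):
    """Serialize values as float32 and return the zlib-compressed byte count."""
    raw = struct.pack(f"{len(values)}f", *values)
    return len(zlib.compress(raw))


dense_size = compressed_size(dense)
sparse_size = compressed_size(sparse)
print(f"dense: {dense_size} bytes, sparse: {sparse_size} bytes")
```

The long runs of zero bytes in the pruned vector are what a general-purpose compressor exploits. The dense vector of random floats has almost no redundancy to remove.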

The technique is being evaluated in various speech applications, such as speech recognition and text-to-speech, and has been tested across various vision and translation models.

API Compatibility Matrix

Users can apply pruning with the following APIs:

  • Model building: keras with only Sequential and Functional models
  • TensorFlow versions: TF 1.x for versions 1.14+ and 2.x.
    • tf.compat.v1 with a TF 2.X package and tf.compat.v2 with a TF 1.X package are not supported.
  • TensorFlow execution mode: both graph and eager
  • Distributed training: tf.distribute with only graph execution

It is on our roadmap to add support in the following areas:

Results

Image Classification

Model            Non-sparse Top-1 Accuracy  Random Sparse Accuracy  Random Sparsity  Structured Sparse Accuracy  Structured Sparsity
InceptionV3      78.1%                      78.0%                   50%              75.8%                       2 by 4
                                            76.1%                   75%
                                            74.6%                   87.5%
MobilenetV1 224  71.04%                     70.84%                  50%              67.35%                      2 by 4
MobilenetV2 224  71.77%                     69.64%                  50%              66.75%                      2 by 4

The models were tested on ImageNet.
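The "2 by 4" structured sparsity reported in this section means that in every block of four consecutive weights, two are zero. A minimal sketch of that pattern, assuming a flat list whose length is a multiple of four and a hypothetical helper name `prune_2_by_4`, could look like this:

```python
def prune_2_by_4(weights):
    """Zero the two smallest-magnitude values in each block of four weights."""
    if len(weights) % 4 != 0:
        raise ValueError("length must be a multiple of 4 for this sketch")
    pruned = list(weights)
    for start in range(0, len(weights), 4):
        block = range(start, start + 4)
        # Indices of the two smallest-magnitude entries in this block.
        for i in sorted(block, key=lambda j: abs(weights[j]))[:2]:
            pruned[i] = 0.0
    return pruned


# Hypothetical weights: two blocks of four.
row = [0.9, -0.05, 0.3, 0.01, -0.7, 0.2, -0.02, 0.6]
print(prune_2_by_4(row))  # [0.9, 0.0, 0.3, 0.0, -0.7, 0.0, 0.0, 0.6]
```

Unlike random sparsity, where the zeros can fall anywhere, this pattern fixes how many zeros each block contains. That constraint is why the two columns are reported separately.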

Translation

Model       Non-sparse BLEU  Sparse BLEU  Sparsity
GNMT EN-DE  26.77            26.86        80%
                             26.52        85%
                             26.19        90%
GNMT DE-EN  29.47            29.50        80%
                             29.24        85%
                             28.81        90%

The models use the WMT16 German and English dataset, with news-test2013 as the dev set and news-test2015 as the test set.

Keyword spotting model

DS-CNN-L is a keyword spotting model created for edge devices. It can be found in ARM software's examples repository.

Model     Non-sparse Accuracy  Structured Sparse Accuracy (2 by 4 pattern)  Random Sparse Accuracy (target sparsity 50%)
DS-CNN-L  95.23                94.33                                        94.84

Examples

In addition to the Prune with Keras tutorial, see the following examples:

  • Train a CNN model on the MNIST handwritten digit classification task with pruning: code
  • Train an LSTM on the IMDB sentiment classification task with pruning: code

For background, see To prune, or not to prune: exploring the efficacy of pruning for model compression [paper].
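That paper proposes raising sparsity gradually along a cubic curve, so that weights are pruned quickly at first and more slowly as the target sparsity approaches. A minimal sketch of such a schedule in plain Python is shown here. The function and parameter names are chosen for illustration and are not taken from the library.

```python
def gradual_sparsity(step, initial_sparsity, final_sparsity, begin_step, end_step):
    """Target sparsity at `step`, following a cubic ramp between two steps."""
    if step <= begin_step:
        return initial_sparsity
    if step >= end_step:
        return final_sparsity
    progress = (step - begin_step) / (end_step - begin_step)
    # Sparsity moves from the initial to the final value; the cubic term
    # makes the increase steep early on and flat near the end.
    return final_sparsity + (initial_sparsity - final_sparsity) * (1.0 - progress) ** 3


# Example: ramp from 0% to 80% sparsity over 1000 training steps.
for step in (0, 250, 500, 750, 1000):
    print(step, round(gradual_sparsity(step, 0.0, 0.8, 0, 1000), 4))
```

Halfway through this example schedule, the target is already 70% of the weights, which leaves the remaining steps for the network to recover accuracy while only a few more weights are removed.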