Collaborative Optimization

Maintained by Arm ML Tooling

This document provides an overview of experimental APIs for combining various techniques to optimize machine learning models for deployment.


Collaborative optimization is an overarching process that encompasses various techniques to produce a model that, at deployment, exhibits the best balance of target characteristics such as inference speed, model size and accuracy.

The idea of collaborative optimization is to build on individual techniques by applying them one after another to achieve a cumulative optimization effect. Various combinations of the following optimizations are possible:

- Pruning
- Weight clustering
- Quantization (post-training quantization or quantization aware training)

The issue that arises when chaining these techniques is that applying one typically destroys the results of the preceding technique, which spoils the combined benefit of applying all of them. For example, clustering does not preserve the sparsity introduced by the pruning API. To solve this problem, we introduce the following experimental collaborative optimization techniques:

- Sparsity-preserving clustering
- Sparsity-preserving quantization aware training (PQAT)
- Cluster-preserving quantization aware training (CQAT)
- Sparsity and cluster preserving quantization aware training (PCQAT)
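To make the problem concrete, here is a toy, pure-Python sketch. It is not the tfmot implementation, and `prune`, `kmeans_1d`, `cluster` and `sparsity` are hypothetical helpers. Plain k-means over a pruned vector lets the zeros join a cluster whose centroid is nonzero, so the sparsity is lost. Excluding the zeros from clustering and pinning them at zero, which is the behaviour sparsity-preserving clustering aims for, keeps them.

```python
# Toy illustration only: these helpers are hypothetical, not the tfmot API.


def prune(weights, sparsity):
    """Magnitude pruning: zero out the `sparsity` fraction of smallest-|w| weights."""
    k = int(len(weights) * sparsity)
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned = list(weights)
    for i in order[:k]:
        pruned[i] = 0.0
    return pruned


def kmeans_1d(values, k, iters=20):
    """Plain 1-D k-means with linearly spaced initial centroids (requires k >= 2)."""
    lo, hi = min(values), max(values)
    centroids = [lo + (hi - lo) * j / (k - 1) for j in range(k)]
    for _ in range(iters):
        buckets = [[] for _ in range(k)]
        for v in values:
            nearest = min(range(k), key=lambda c: abs(v - centroids[c]))
            buckets[nearest].append(v)
        # Move each centroid to the mean of its bucket; keep it if the bucket is empty.
        centroids = [sum(b) / len(b) if b else centroids[j] for j, b in enumerate(buckets)]
    return centroids


def cluster(weights, k, preserve_sparsity=False):
    """Replace every weight by its nearest centroid.

    With preserve_sparsity=True, zeros are left out of clustering and stay zero.
    """
    if preserve_sparsity:
        centroids = kmeans_1d([w for w in weights if w != 0.0], k)
        return [0.0 if w == 0.0 else min(centroids, key=lambda c: abs(w - c))
                for w in weights]
    centroids = kmeans_1d(weights, k)
    return [min(centroids, key=lambda c: abs(w - c)) for w in weights]


def sparsity(weights):
    """Fraction of weights that are exactly zero."""
    return sum(w == 0.0 for w in weights) / len(weights)


dense = [0.01, -0.02, 0.03, -0.04, 0.1, 0.2, 0.9, 1.0]
pruned = prune(dense, 0.5)           # the four smallest weights become 0.0
naive = cluster(pruned, k=2)         # zeros are pulled onto a nonzero centroid
kept = cluster(pruned, k=2, preserve_sparsity=True)
print(sparsity(pruned), sparsity(naive), sparsity(kept))  # 0.5 0.0 0.5
```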

These provide several deployment paths that can be used to compress a machine learning model and to take advantage of hardware acceleration at inference time. The diagram below shows several deployment paths that can be explored in search of a model with the desired deployment characteristics. The leaf nodes are deployment-ready models, meaning they are partially or fully quantized and in TFLite format. The green fill indicates steps where retraining/fine-tuning is required, and a dashed red border highlights the collaborative optimization steps. The technique used to obtain a model at a given node is indicated in the corresponding label.

[Figure: collaborative optimization deployment tree]

The direct, quantization-only (post-training or QAT) deployment path is omitted in the figure above.

The goal is to reach the fully optimized model at the third level of the deployment tree above. However, any of the other levels of optimization may prove satisfactory and achieve the required inference latency/accuracy trade-off, in which case no further optimization is needed. The recommended training process is to iterate through the levels of the deployment tree that apply to the target deployment scenario: check whether the model fulfils the inference latency requirements and, if it does not, use the corresponding collaborative optimization technique to compress the model further. Repeat until the requirements are met or, if needed, until the model is fully optimized (pruned, clustered, and quantized).
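The iterative process can be sketched as a loop over one branch of the deployment tree, for example pruning, then sparsity-preserving clustering, then sparsity and cluster preserving QAT (PCQAT). Everything here is an assumption for illustration: `optimize_until_satisfied` is a hypothetical name, each stage stands in for a real optimization plus fine-tuning run, and the latency figures in the usage lines are made-up toy numbers, not measurements.

```python
from typing import Callable, List, Tuple

# Hypothetical stand-ins: a "model" is any object, each stage is a callable that
# returns a further-optimized model (e.g. a fine-tuning run), and
# `meets_requirements` checks the latency/accuracy targets of the deployment.
Stage = Tuple[str, Callable[[object], object]]


def optimize_until_satisfied(model, stages: List[Stage], meets_requirements):
    """Walk down one branch of the deployment tree, stopping at the first
    level whose model already meets the deployment requirements."""
    applied = []
    for name, stage in stages:
        if meets_requirements(model):
            break  # good enough: no further optimization needed
        model = stage(model)
        applied.append(name)
    return model, applied


# Toy usage: the latency numbers are invented purely to exercise the loop.
branch = [
    ("pruning", lambda m: {**m, "latency_ms": m["latency_ms"] * 0.6}),
    ("sparsity-preserving clustering", lambda m: {**m, "latency_ms": m["latency_ms"] * 0.8}),
    ("PCQAT", lambda m: {**m, "latency_ms": m["latency_ms"] * 0.5}),
]
final, steps = optimize_until_satisfied(
    {"latency_ms": 100.0}, branch, lambda m: m["latency_ms"] <= 50.0)
print(steps)  # ['pruning', 'sparsity-preserving clustering']
```

Checking the requirement before each stage means a model that is already good enough is returned unchanged, and a model that never meets the target ends up fully optimized.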

The figure below shows the density plots of a sample weight kernel as it goes through the collaborative optimization pipeline.

[Figure: collaborative optimization density plot]

The result is a quantized deployment model with a reduced number of unique values and a significant number of sparse (zero) weights, depending on the target sparsity specified at training time. Besides the substantial model compression advantages, hardware with specific support can take advantage of these sparse, clustered models to significantly reduce inference latency.
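The following sketch reproduces the qualitative effect on a random toy kernel rather than a real model: magnitude pruning to 50% sparsity, k-means clustering of the nonzero weights into 8 clusters, and symmetric per-tensor int8 quantization. The helpers are hypothetical, zlib stands in for whatever compression is used at deployment, and the exact numbers printed depend on the toy data.

```python
import random
import struct
import zlib

# Toy sketch of prune -> sparsity-preserving clustering -> int8 quantization on a
# random "kernel". All helpers here are hypothetical, not the tfmot implementation.


def prune(w, sparsity):
    """Zero out the `sparsity` fraction of weights with the smallest magnitude."""
    order = sorted(range(len(w)), key=lambda i: abs(w[i]))
    out = list(w)
    for i in order[: int(len(w) * sparsity)]:
        out[i] = 0.0
    return out


def cluster_nonzero(w, k, iters=15):
    """1-D k-means over the nonzero weights only; zeros stay exactly zero."""
    nz = [x for x in w if x != 0.0]
    lo, hi = min(nz), max(nz)
    cents = [lo + (hi - lo) * j / (k - 1) for j in range(k)]
    for _ in range(iters):
        sums, counts = [0.0] * k, [0] * k
        for x in nz:
            j = min(range(k), key=lambda c: abs(x - cents[c]))
            sums[j] += x
            counts[j] += 1
        cents = [s / n if n else c for s, n, c in zip(sums, counts, cents)]
    return [0.0 if x == 0.0 else min(cents, key=lambda c: abs(x - c)) for x in w]


def quantize_int8(w):
    """Symmetric per-tensor int8 quantization; 0.0 always maps to 0."""
    max_abs = max(abs(x) for x in w)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    return [max(-127, min(127, round(x / scale))) for x in w]


def zipped_size(q):
    """Size in bytes of the int8 values after zlib compression."""
    return len(zlib.compress(struct.pack(f"{len(q)}b", *q), 9))


rng = random.Random(0)
kernel = [rng.gauss(0.0, 0.1) for _ in range(4096)]

q_only = quantize_int8(kernel)  # quantization alone
q_collab = quantize_int8(cluster_nonzero(prune(kernel, 0.5), k=8))  # full pipeline

print("unique values:", len(set(q_only)), "->", len(set(q_collab)))
print("zero fraction:", q_collab.count(0) / len(q_collab))
print("zlib bytes:", zipped_size(q_only), "->", zipped_size(q_collab))
```

The quantized-only kernel keeps many distinct int8 levels, while the collaboratively optimized one has at most nine (eight centroids plus zero) and at least half of its entries at zero, which is why it compresses much better.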


Below are some accuracy and compression results we obtained when experimenting with PQAT and CQAT collaborative optimization paths.
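In the Compression rows of the tables, each entry reads "size before → size after", and the percentage in parentheses appears to be the relative reduction (before − after) / before. For example, 528,128 → 434,879 gives 17.66%. The hypothetical helper below makes that reading explicit; it says nothing about how the sizes themselves were measured.

```python
def size_reduction_pct(before: int, after: int) -> float:
    """Percentage reduction from `before` to `after`, rounded to two decimals."""
    return round(100.0 * (before - after) / before, 2)


# DS-CNN-L baseline entry from the PQAT table: 528,128 -> 434,879.
print(size_reduction_pct(528_128, 434_879))  # 17.66
```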

Sparsity-preserving Quantization aware training (PQAT)

| Model | Metric | Baseline | Pruned Model (50% sparsity) | QAT Model | PQAT Model |
|---|---|---|---|---|---|
| DS-CNN-L | FP32 Top-1 Accuracy | 95.23% | 94.80% | (Fake INT8) 94.721% | (Fake INT8) 94.128% |
| | INT8 full integer quantization | 94.48% | 93.80% | 94.72% | 94.13% |
| | Compression | 528,128 → 434,879 (17.66%) | 528,128 → 334,154 (36.73%) | 512,224 → 403,261 (21.27%) | 512,032 → 303,997 (40.63%) |
| Mobilenet_v1-224 | FP32 Top-1 Accuracy | 70.99% | 70.11% | (Fake INT8) 70.67% | (Fake INT8) 70.29% |
| | INT8 full integer quantization | 69.37% | 67.82% | 70.67% | 70.29% |
| | Compression | 4,665,520 → 3,880,331 (16.83%) | 4,665,520 → 2,939,734 (37.00%) | 4,569,416 → 3,808,781 (16.65%) | 4,569,416 → 2,869,600 (37.20%) |
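The sketch below illustrates why plain QAT on a pruned model can lose sparsity and how keeping the pruning mask fixed avoids it. For illustration we assume that sparsity preservation amounts to re-applying a fixed mask after each weight update; the helpers and gradient values are hypothetical, and real QAT details such as straight-through gradient estimation are omitted.

```python
# Toy sketch of the sparsity-preserving idea behind PQAT. Hypothetical helpers;
# real QAT also uses straight-through gradients, per-channel scales, etc.


def fake_quant_int8(w):
    """Quantize to symmetric int8 and dequantize back to float (per tensor)."""
    max_abs = max(abs(x) for x in w)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    return [max(-127, min(127, round(x / scale))) * scale for x in w]


def qat_step(w, grad, lr, mask=None):
    """One SGD update of the float weights. If a pruning mask is given, pruned
    positions are forced back to zero so fine-tuning cannot undo the sparsity."""
    new_w = [x - lr * g for x, g in zip(w, grad)]
    if mask is not None:
        new_w = [x * m for x, m in zip(new_w, mask)]
    return new_w


def zeros(w):
    """Number of entries that are exactly zero."""
    return sum(x == 0.0 for x in w)


w = [0.0, 0.5, 0.0, -0.25]     # pruned weights (50% sparse)
mask = [0, 1, 0, 1]            # 1 = kept, 0 = pruned
grad = [0.1, 0.1, -0.2, 0.05]  # made-up gradient

plain = fake_quant_int8(qat_step(w, grad, lr=0.1))
pqat = fake_quant_int8(qat_step(w, grad, lr=0.1, mask=mask))
print(zeros(plain), zeros(pqat))  # 0 2
```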

Cluster-preserving Quantization aware training (CQAT)

| Model | Metric | Baseline | Clustered Model | QAT Model | CQAT Model |
|---|---|---|---|---|---|
| Mobilenet_v1 on CIFAR-10 | FP32 Top-1 Accuracy | 94.88% | 94.48% | (Fake INT8) 94.80% | (Fake INT8) 94.60% |
| | INT8 full integer quantization | 94.65% | 94.41% | 94.77% | 94.52% |
| | Size | 3.00 MB | 2.00 MB | 2.84 MB | 1.94 MB |
| Mobilenet_v1 on ImageNet | FP32 Top-1 Accuracy | 71.07% | 65.30% | (Fake INT8) 70.39% | (Fake INT8) 65.35% |
| | INT8 full integer quantization | 69.34% | 60.60% | 70.35% | 65.42% |
| | Compression | 4,665,568 → 3,886,277 (16.7%) | 4,665,568 → 3,035,752 (34.9%) | 4,569,416 → 3,804,871 (16.7%) | 4,569,472 → 2,912,655 (36.25%) |
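A similar toy sketch for CQAT, under the assumption that cluster preservation means keeping each weight's cluster assignment fixed and training only the shared centroid values. Summing the gradients of all weights tied to a centroid is the standard way to train shared parameters; whether tfmot does exactly this is an assumption, and the helpers and gradient values are hypothetical.

```python
# Toy sketch of the cluster-preserving idea behind CQAT. Hypothetical helpers.


def fake_quant_int8(w):
    """Quantize to symmetric int8 and dequantize back to float (per tensor)."""
    max_abs = max(abs(x) for x in w)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    return [max(-127, min(127, round(x / scale))) * scale for x in w]


def materialize(centroids, indices):
    """Expand shared centroid values into a full weight vector."""
    return [centroids[i] for i in indices]


def cqat_step(centroids, indices, grad, lr):
    """Tied-weight SGD: each centroid moves by the summed gradient of its members,
    so every weight in a cluster keeps sharing one value."""
    acc = [0.0] * len(centroids)
    for i, g in zip(indices, grad):
        acc[i] += g
    return [c - lr * a for c, a in zip(centroids, acc)]


centroids = [-0.3, 0.1, 0.4]              # 3 clusters
indices = [0, 1, 1, 2, 2, 0]              # cluster id of each weight
grad = [0.1, -0.2, 0.05, 0.3, -0.1, 0.0]  # made-up gradient

# Plain QAT updates each weight independently, so the clustering is lost.
w = materialize(centroids, indices)
plain = fake_quant_int8([x - 0.1 * g for x, g in zip(w, grad)])

# The CQAT-style update keeps one shared value per cluster.
cqat = fake_quant_int8(materialize(cqat_step(centroids, indices, grad, 0.1), indices))
print(len(set(plain)), len(set(cqat)))  # 6 3
```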


An end-to-end example of the different deployment paths will be added in the future.