Maintained by Arm ML Tooling
This document provides an overview of experimental APIs for combining various techniques to optimize machine learning models for deployment.
Collaborative optimization is an overarching process that encompasses various techniques to produce a model that, at deployment, exhibits the best balance of target characteristics such as inference speed, model size and accuracy.
Collaborative optimization builds on individual techniques by applying them one after another to achieve their accumulated optimization effect. Various combinations of the following optimizations are possible:
- Weight pruning
- Weight clustering
- Quantization
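
As a concrete reference point for the first technique, here is a minimal, framework-free sketch of magnitude-based weight pruning on a flat list of weights. It is not the pruning API itself. The helpers `prune_by_magnitude` and `sparsity` are hypothetical, and pruning the smallest absolute values to a fixed target sparsity is an assumption made for illustration.

```python
def prune_by_magnitude(weights, target_sparsity):
    """Zero out the smallest-magnitude weights so that roughly
    `target_sparsity` (a fraction in [0, 1]) of the entries become zero."""
    n_prune = round(target_sparsity * len(weights))
    # Indices ordered from smallest to largest absolute value.
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned = list(weights)
    for i in order[:n_prune]:
        pruned[i] = 0.0
    return pruned


def sparsity(weights):
    """Fraction of entries that are exactly zero."""
    return sum(1 for w in weights if w == 0.0) / len(weights)


kernel = [0.8, -0.05, 0.3, 0.02, -0.6, 0.1, -0.25, 0.9]
pruned = prune_by_magnitude(kernel, 0.5)
print(pruned)            # [0.8, 0.0, 0.3, 0.0, -0.6, 0.0, 0.0, 0.9]
print(sparsity(pruned))  # 0.5
```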
The problem with chaining these techniques is that applying one typically destroys the results of the preceding one, spoiling the benefit of combining them. For example, clustering does not preserve the sparsity introduced by the pruning API. To solve this problem, we introduce the following experimental collaborative optimization techniques:
- Sparsity-preserving clustering
- Sparsity-preserving quantization-aware training (PQAT)
- Cluster-preserving quantization-aware training (CQAT)
- Sparsity- and cluster-preserving quantization-aware training
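
The clustering problem and the idea behind sparsity-preserving clustering can be illustrated on a 50%-sparse example kernel. This sketch uses plain 1-D k-means with centroids initialized linearly between the minimum and maximum weight. The helper names are hypothetical. Excluding zeros from clustering and restoring them afterwards is one simple way to preserve sparsity, not necessarily how the experimental API implements it.

```python
def kmeans_1d(values, k, iterations=20):
    """Plain 1-D k-means with linearly spaced initial centroids.
    Returns (centroids, cluster index of each value)."""
    lo, hi = min(values), max(values)
    if k == 1:
        centroids = [sum(values) / len(values)]
    else:
        centroids = [lo + (hi - lo) * j / (k - 1) for j in range(k)]

    def nearest(v):
        return min(range(k), key=lambda j: abs(v - centroids[j]))

    for _ in range(iterations):
        assignment = [nearest(v) for v in values]
        for j in range(k):
            members = [v for v, a in zip(values, assignment) if a == j]
            if members:  # leave empty clusters where they are
                centroids[j] = sum(members) / len(members)
    return centroids, [nearest(v) for v in values]


def cluster_naive(weights, k):
    """Replace every weight, zeros included, by its nearest centroid."""
    centroids, assignment = kmeans_1d(weights, k)
    return [centroids[a] for a in assignment]


def cluster_preserving_sparsity(weights, k):
    """Cluster only the nonzero weights and keep zeros at exactly 0.0."""
    nonzero = [w for w in weights if w != 0.0]
    if not nonzero:
        return list(weights)
    centroids, assignment = kmeans_1d(nonzero, k)
    indices = iter(assignment)
    return [0.0 if w == 0.0 else centroids[next(indices)] for w in weights]


def sparsity(weights):
    """Fraction of entries that are exactly zero."""
    return sum(1 for w in weights if w == 0.0) / len(weights)


pruned = [0.8, 0.0, 0.3, 0.0, -0.6, 0.0, 0.0, 0.9]  # 50% sparse
naive = cluster_naive(pruned, 3)
preserved = cluster_preserving_sparsity(pruned, 3)
print(sparsity(naive))      # 0.0 (the zeros were pulled onto a nonzero centroid)
print(sparsity(preserved))  # 0.5
```

With sparsity preserved, the zeros act as an extra implicit cluster, so the kernel holds at most `k` nonzero values plus zero.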
These techniques provide several deployment paths that can be used to compress a machine learning model and to take advantage of hardware acceleration at inference time. The diagram below shows deployment paths that can be explored in search of a model with the desired deployment characteristics. The leaf nodes are deployment-ready models, meaning they are partially or fully quantized and in TFLite format. A green fill indicates steps where retraining or fine-tuning is required, and a dashed red border highlights the collaborative optimization steps. Each node's label names the technique used to obtain the model at that node.
The direct, quantization-only (post-training or QAT) deployment path is omitted from the figure above.
The goal is to reach the fully optimized model at the third level of the deployment tree above. However, any of the other levels of optimization could prove satisfactory and achieve the required inference latency/accuracy trade-off, in which case no further optimization is needed. The recommended training process is to iterate through the levels of the deployment tree that apply to the target deployment scenario. At each level, check whether the model meets the inference latency requirements; if it does not, use the corresponding collaborative optimization technique to compress it further. Repeat until the requirements are met or the model is fully optimized (pruned, clustered, and quantized).
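
The iterative process described above can be sketched as a simple loop. This is a hypothetical helper, not part of any API. `apply_stage` and `is_good_enough` stand in for running the next optimization technique (with whatever fine-tuning it needs) and for checking the deployment requirements. The stage list is one assumed path through the deployment tree.

```python
def optimize_until_satisfied(model, stages, apply_stage, is_good_enough):
    """Walk down one path of the deployment tree, applying the next
    optimization only while the current model misses its requirements.
    Returns the final model and the list of stages that were applied."""
    applied = []
    for stage in stages:
        if is_good_enough(model):
            break  # this level already meets the target; stop here
        model = apply_stage(model, stage)
        applied.append(stage)
    return model, applied


# One assumed path: pruning, then sparsity-preserving clustering, then
# sparsity- and cluster-preserving QAT (the fully optimized level).
stages = [
    "pruning",
    "sparsity-preserving clustering",
    "sparsity- and cluster-preserving QAT",
]

# Toy stand-ins: a "model" is the tuple of techniques applied so far, and
# the requirement is arbitrarily taken to be met after two techniques.
model, applied = optimize_until_satisfied(
    (),
    stages,
    apply_stage=lambda m, stage: m + (stage,),
    is_good_enough=lambda m: len(m) >= 2,
)
print(applied)  # ['pruning', 'sparsity-preserving clustering']
```

Checking the requirements before each step means that no further optimization, and no further retraining, is spent once a level is good enough.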
The figure below shows density plots of a sample weight kernel as it goes through the collaborative optimization pipeline.
The result is a quantized deployment model with a reduced number of unique weight values and a significant proportion of zero (sparse) weights, depending on the target sparsity specified at training time. Beyond the significant model compression benefits, hardware with specific support for sparse, clustered models can take advantage of them to significantly reduce inference latency.
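
A small sketch shows why the quantized model keeps both properties. It assumes symmetric per-tensor int8 quantization with the zero point fixed at 0, and the function name is hypothetical. Quantization maps each weight independently, so weights that share a cluster value share an integer value, and 0.0 maps exactly to 0.

```python
def quantize_symmetric_int8(weights):
    """Symmetric per-tensor int8 quantization: q = round(w / scale),
    with scale = max|w| / 127 and the zero point fixed at 0."""
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    q = [max(-127, min(127, round(w / scale))) for w in weights]
    return q, scale


# A pruned (50% sparse) and clustered kernel: three nonzero values plus zero.
kernel = [0.85, 0.0, 0.3, 0.0, -0.6, 0.0, 0.0, 0.85]
q, scale = quantize_symmetric_int8(kernel)
print(q)  # [127, 0, 45, 0, -90, 0, 0, 127]

# Identical inputs give identical outputs, so unique values cannot increase,
# and every zero weight stays exactly zero.
assert len(set(q)) <= len(set(kernel))
assert all(v == 0 for w, v in zip(kernel, q) if w == 0.0)
```

With per-channel scales, the same argument holds within each channel.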
Below are some accuracy and compression results we obtained when experimenting with PQAT and CQAT collaborative optimization paths.
Sparsity-preserving quantization-aware training (PQAT)

| Model | Items | Baseline | Pruned Model (50% sparsity) | QAT Model | PQAT Model |
|---|---|---|---|---|---|
| DS-CNN-L | FP32 Top-1 Accuracy | 95.23% | 94.80% | (Fake INT8) 94.721% | (Fake INT8) 94.128% |
| | INT8 full integer quantization | 94.48% | 93.80% | 94.72% | 94.13% |
| | Compression | 528,128 → 434,879 (17.66%) | 528,128 → 334,154 (36.73%) | 512,224 → 403,261 (21.27%) | 512,032 → 303,997 (40.63%) |
| Mobilenet_v1-224 | FP32 Top-1 Accuracy | 70.99% | 70.11% | (Fake INT8) 70.67% | (Fake INT8) 70.29% |
| | INT8 full integer quantization | 69.37% | 67.82% | 70.67% | 70.29% |
| | Compression | 4,665,520 → 3,880,331 (16.83%) | 4,665,520 → 2,939,734 (37.00%) | 4,569,416 → 3,808,781 (16.65%) | 4,569,416 → 2,869,600 (37.20%) |

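
The percentages in the Compression rows are consistent with reading each entry as size before → size after, with the percentage giving the relative reduction, 100 × (before - after) / before. The table does not state the unit. Here is a quick check against the DS-CNN-L row, assuming that reading:

```python
def reduction_percent(before, after):
    """Relative size reduction as a percentage of the original size."""
    return 100.0 * (before - after) / before


# (before, after) pairs copied from the DS-CNN-L "Compression" row.
ds_cnn_l = {
    "Baseline": (528_128, 434_879),
    "Pruned Model (50% sparsity)": (528_128, 334_154),
    "QAT Model": (512_224, 403_261),
    "PQAT Model": (512_032, 303_997),
}

for name, (before, after) in ds_cnn_l.items():
    print(f"{name}: {reduction_percent(before, after):.2f}%")
# Baseline: 17.66%
# Pruned Model (50% sparsity): 36.73%
# QAT Model: 21.27%
# PQAT Model: 40.63%
```

Under this reading, the PQAT model gives the largest reduction of the four DS-CNN-L variants.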

Cluster-preserving quantization-aware training (CQAT)

| Model | Items | Baseline | Clustered Model | QAT Model | CQAT Model |
|---|---|---|---|---|---|
| Mobilenet_v1 on CIFAR-10 | FP32 Top-1 Accuracy | 94.88% | 94.48% | (Fake INT8) 94.80% | (Fake INT8) 94.60% |
| | INT8 full integer quantization | 94.65% | 94.41% | 94.77% | 94.52% |
| | Size | 3.00 MB | 2.00 MB | 2.84 MB | 1.94 MB |
| Mobilenet_v1 on ImageNet | FP32 Top-1 Accuracy | 71.07% | 65.30% | (Fake INT8) 70.39% | (Fake INT8) 65.35% |
| | INT8 full integer quantization | 69.34% | 60.60% | 70.35% | 65.42% |
| | Compression | 4,665,568 → 3,886,277 (16.7%) | 4,665,568 → 3,035,752 (34.9%) | 4,569,416 → 3,804,871 (16.7%) | 4,569,472 → 2,912,655 (36.25%) |
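
Plain fine-tuning updates each weight independently, which gradually undoes clustering. A common way to keep weights clustered during training is to hold the cluster assignment fixed and train only the shared centroids. Because each weight equals its centroid, the chain rule gives each centroid a gradient equal to the sum of its member weights' gradients. The sketch below compares one such SGD step with an unconstrained one. It illustrates the idea behind cluster preservation only: it omits the fake quantization that QAT adds and is not a description of the CQAT implementation. Function names and gradient values are hypothetical.

```python
def cluster_preserving_step(centroids, assignment, weight_grads, lr):
    """One SGD step on shared centroids. Each weight i equals
    centroids[assignment[i]], so dL/dc_j is the sum of dL/dw_i over the
    weights assigned to cluster j."""
    centroid_grads = [0.0] * len(centroids)
    for j, g in zip(assignment, weight_grads):
        centroid_grads[j] += g
    new_centroids = [c - lr * g for c, g in zip(centroids, centroid_grads)]
    new_weights = [new_centroids[j] for j in assignment]
    return new_centroids, new_weights


def unconstrained_step(weights, weight_grads, lr):
    """Ordinary per-weight SGD step, shown for comparison."""
    return [w - lr * g for w, g in zip(weights, weight_grads)]


centroids = [-0.6, 0.3, 0.85]
assignment = [2, 1, 1, 0, 2]          # cluster index of each weight
weights = [centroids[j] for j in assignment]
grads = [0.1, -0.2, 0.05, 0.0, 0.3]   # made-up gradients for illustration
lr = 0.1

new_centroids, clustered = cluster_preserving_step(centroids, assignment, grads, lr)
free = unconstrained_step(weights, grads, lr)
print(len(set(clustered)))  # 3: still one value per cluster
print(len(set(free)))       # 5: clustering is lost
```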