Welcome to the comprehensive guide for weight clustering, part of the TensorFlow Model Optimization toolkit.
This page documents various use cases and shows how to use the API for each one. Once you know which APIs you need, find the parameters and the low-level details in the API docs.
- If you want to see the benefits of weight clustering and what's supported, check the overview.
- For a single end-to-end example, see the weight clustering example.
In this guide, the following use cases are covered:
- Define a clustered model.
- Checkpoint and deserialize a clustered model.
- Improve the accuracy of the clustered model.
- Export the model for deployment so that the compression benefits are realized.
Setup
```shell
pip install -q tensorflow-model-optimization
```

```python
import os
import tempfile
import zipfile

import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot

input_dim = 20
output_dim = 20
# Dummy data: one random input row and a one-hot label for a random class index.
x_train = np.random.randn(1, input_dim).astype(np.float32)
y_train = tf.keras.utils.to_categorical(
    np.random.randint(0, output_dim, size=1), num_classes=output_dim)

def setup_model():
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(input_dim, input_shape=[input_dim]),
        tf.keras.layers.Flatten()
    ])
    return model

def train_model(model):
    model.compile(
        loss=tf.keras.losses.categorical_crossentropy,
        optimizer='adam',
        metrics=['accuracy']
    )
    model.summary()
    model.fit(x_train, y_train)
    return model

def save_model_weights(model):
    _, pretrained_weights = tempfile.mkstemp('.h5')
    model.save_weights(pretrained_weights)
    return pretrained_weights

def setup_pretrained_weights():
    model = setup_model()
    model = train_model(model)
    pretrained_weights = save_model_weights(model)
    return pretrained_weights

def setup_pretrained_model():
    model = setup_model()
    pretrained_weights = setup_pretrained_weights()
    model.load_weights(pretrained_weights)
    return model

def save_model_file(model):
    _, keras_file = tempfile.mkstemp('.h5')
    model.save(keras_file, include_optimizer=False)
    return keras_file

def get_gzipped_model_size(model):
    # Returns the size of the gzipped model in bytes.
    keras_file = save_model_file(model)
    _, zipped_file = tempfile.mkstemp('.zip')
    with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
        f.write(keras_file)
    return os.path.getsize(zipped_file)

setup_model()
pretrained_weights = setup_pretrained_weights()
```
Define a clustered model
Cluster a whole model (sequential and functional)
Tips for better model accuracy:
- You must pass a pre-trained model with acceptable accuracy to this API. Training models from scratch with clustering results in subpar accuracy.
- In some cases, clustering certain layers has a detrimental effect on model accuracy. Check "Cluster some layers" to see how to skip clustering the layers that affect accuracy the most.
To cluster all layers, apply tfmot.clustering.keras.cluster_weights to the model.
```python
import tensorflow_model_optimization as tfmot

cluster_weights = tfmot.clustering.keras.cluster_weights
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization

clustering_params = {
    'number_of_clusters': 3,
    'cluster_centroids_init': CentroidInitialization.DENSITY_BASED
}

model = setup_model()
model.load_weights(pretrained_weights)

clustered_model = cluster_weights(model, **clustering_params)

clustered_model.summary()
```
```
Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
cluster_dense_2 (ClusterWeig (None, 20)                423
_________________________________________________________________
cluster_flatten_2 (ClusterWe (None, 20)                0
=================================================================
Total params: 423
Trainable params: 23
Non-trainable params: 400
_________________________________________________________________
```
Cluster some layers (sequential and functional models)
Tips for better model accuracy:
- You must pass a pre-trained model with acceptable accuracy to this API. Training models from scratch with clustering results in subpar accuracy.
- Cluster later layers with more redundant parameters (e.g. tf.keras.layers.Dense, tf.keras.layers.Conv2D), as opposed to the early layers.
- During fine-tuning, freeze the early layers that precede the clustered layers. Treat the number of frozen layers as a hyperparameter. Empirically, freezing most early layers works best with the current clustering API.
- Avoid clustering critical layers (e.g. attention mechanism).
More: the tfmot.clustering.keras.cluster_weights API docs provide details on how to vary the clustering configuration per layer.
```python
# Create a base model
base_model = setup_model()
base_model.load_weights(pretrained_weights)

# Helper function uses `cluster_weights` to make only
# the Dense layers train with clustering
def apply_clustering_to_dense(layer):
    if isinstance(layer, tf.keras.layers.Dense):
        return cluster_weights(layer, **clustering_params)
    return layer

# Use `tf.keras.models.clone_model` to apply `apply_clustering_to_dense`
# to the layers of the model.
clustered_model = tf.keras.models.clone_model(
    base_model,
    clone_function=apply_clustering_to_dense,
)

clustered_model.summary()
```
```
Model: "sequential_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
cluster_dense_3 (ClusterWeig (None, 20)                423
_________________________________________________________________
flatten_3 (Flatten)          (None, 20)                0
=================================================================
Total params: 423
Trainable params: 23
Non-trainable params: 400
_________________________________________________________________
```
Checkpoint and deserialize a clustered model
Your use case: this code is only needed for the HDF5 model format (not HDF5 weights or other formats).
```python
# Define the model.
base_model = setup_model()
base_model.load_weights(pretrained_weights)
clustered_model = cluster_weights(base_model, **clustering_params)

# Save or checkpoint the model.
_, keras_model_file = tempfile.mkstemp('.h5')
clustered_model.save(keras_model_file, include_optimizer=True)

# `cluster_scope` is needed for deserializing HDF5 models.
with tfmot.clustering.keras.cluster_scope():
    loaded_model = tf.keras.models.load_model(keras_model_file)

loaded_model.summary()
```
```
WARNING:tensorflow:No training configuration found in the save file, so the model was *not* compiled. Compile it manually.
Model: "sequential_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
cluster_dense_4 (ClusterWeig (None, 20)                423
_________________________________________________________________
cluster_flatten_4 (ClusterWe (None, 20)                0
=================================================================
Total params: 423
Trainable params: 23
Non-trainable params: 400
_________________________________________________________________
```
Improve the accuracy of the clustered model
For your specific use case, there are tips you can consider:
- Centroid initialization plays a key role in the final accuracy of the optimized model. In general, linear initialization outperforms density-based and random initialization because it does not tend to miss large weights. However, density-based initialization has been observed to give better accuracy when very few clusters are used on weights with bimodal distributions.
- When fine-tuning the clustered model, set a learning rate lower than the one used in training.
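The intuition behind "linear initialization does not tend to miss large weights" can be illustrated with NumPy alone. This is an approximation under stated assumptions: linear initialization is modeled as evenly spaced centroids between the minimum and maximum weight, and density-based initialization is stood in for by evenly spaced quantiles of the weight distribution, which is not tfmot's exact algorithm. The sketch measures how far the largest-magnitude weight lands from its nearest centroid.

```python
import numpy as np


def linear_centroids(weights, k):
    # Evenly spaced centroids from the smallest to the largest weight.
    return np.linspace(weights.min(), weights.max(), k)


def quantile_centroids(weights, k):
    # Rough stand-in for density-based init: centroids at evenly spaced
    # quantiles, so they crowd where most weights are.
    return np.quantile(weights, (np.arange(k) + 0.5) / k)


def max_weight_error(weights, centroids):
    # Distance from the largest-magnitude weight to its nearest centroid.
    w = weights[np.argmax(np.abs(weights))]
    return float(np.min(np.abs(centroids - w)))


rng = np.random.default_rng(0)
weights = rng.standard_normal(10_000).astype(np.float32)
k = 3

lin_err = max_weight_error(weights, linear_centroids(weights, k))
den_err = max_weight_error(weights, quantile_centroids(weights, k))
print(f"linear: {lin_err:.3f}  quantile (density-like): {den_err:.3f}")
```

With few clusters, the quantile-based centroids all sit near the bulk of the distribution, so the extreme weight is pulled far toward the center, while the linear centroids include the extremes exactly.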
For general ideas to improve model accuracy, look for tips for your use case(s) under "Define a clustered model".
Deployment
Export model with size compression
Common mistake: both strip_clustering and a standard compression algorithm (e.g. gzip) are necessary to see the compression benefits of clustering.
```python
model = setup_model()
clustered_model = cluster_weights(model, **clustering_params)

clustered_model.compile(
    loss=tf.keras.losses.categorical_crossentropy,
    optimizer='adam',
    metrics=['accuracy']
)

clustered_model.fit(
    x_train,
    y_train
)

final_model = tfmot.clustering.keras.strip_clustering(clustered_model)

print("final model")
final_model.summary()

print("\n")
print("Size of gzipped clustered model without stripping: %.2f bytes"
      % (get_gzipped_model_size(clustered_model)))
print("Size of gzipped clustered model with stripping: %.2f bytes"
      % (get_gzipped_model_size(final_model)))
```
```
1/1 [==============================] - 0s 984us/step - loss: 16.1181 - accuracy: 0.0000e+00
final model
Model: "sequential_5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
dense_5 (Dense)              (None, 20)                420
_________________________________________________________________
flatten_5 (Flatten)          (None, 20)                0
=================================================================
Total params: 420
Trainable params: 420
Non-trainable params: 0
_________________________________________________________________

Size of gzipped clustered model without stripping: 1809.00 bytes
Size of gzipped clustered model with stripping: 1399.00 bytes
```
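Why both stripping and compression are needed can be seen without TensorFlow: snapping weights to a few centroids leaves the raw float32 byte count unchanged, and only a general-purpose compressor turns the repeated values into savings. This NumPy/zlib sketch simulates clustering by snapping random weights to three linearly spaced centroids; it is an illustration only and does not reproduce the exact file sizes above.

```python
import zlib

import numpy as np

rng = np.random.default_rng(0)
weights = rng.standard_normal(10_000).astype(np.float32)

# Simulated clustering: replace every weight with its nearest of 3 centroids.
centroids = np.linspace(weights.min(), weights.max(), 3).astype(np.float32)
indices = np.argmin(np.abs(weights[:, None] - centroids[None, :]), axis=1)
clustered = centroids[indices]

raw_original = weights.tobytes()
raw_clustered = clustered.tobytes()

# Uncompressed, both arrays take 4 bytes per weight.
print("raw bytes:", len(raw_original), len(raw_clustered))

# After compression, the clustered array shrinks far more.
zipped_original = len(zlib.compress(raw_original, 9))
zipped_clustered = len(zlib.compress(raw_clustered, 9))
print("compressed bytes:", zipped_original, zipped_clustered)
```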