Modifies a keras layer or model to be clustered during training.

This function wraps a keras model or layer with clustering functionality, which clusters the layer's weights during training. For example, calling it with number_of_clusters=8 ensures that each weight tensor has no more than 8 unique values.
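
To make the "no more than 8 unique values" property concrete, here is a minimal standard-library sketch. It is not the library's implementation: the helper name snap_to_centroids and the evenly spaced centroids are assumptions made for illustration, whereas the real API adjusts centroids during training.

```python
import random


def snap_to_centroids(weights, centroids):
    """Replace every weight with the closest centroid value (hypothetical helper)."""
    return [min(centroids, key=lambda c: abs(c - w)) for w in weights]


random.seed(0)
number_of_clusters = 8

# A stand-in for one flattened weight tensor.
weights = [random.gauss(0.0, 1.0) for _ in range(1000)]

# Evenly spaced centroids between the smallest and largest weight.
# This choice is an assumption for the sketch, not the library's behavior.
lo, hi = min(weights), max(weights)
step = (hi - lo) / (number_of_clusters - 1)
centroids = [lo + i * step for i in range(number_of_clusters)]

clustered = snap_to_centroids(weights, centroids)
unique_values = set(clustered)

# The clustered tensor keeps its size but draws from at most 8 distinct values.
print(len(clustered), len(unique_values))
```

Because every weight is replaced by one of the centroids, the number of distinct values can never exceed number_of_clusters, which is what makes the clustered tensor compressible.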

Before passing to the clustering API, a model should already be trained and show some acceptable performance on the testing/validation sets.

The function accepts a single keras layer (a subclass of keras.layers.Layer), a list of keras layers, or a keras model (an instance of keras.models.Model) and handles each appropriately.

If it encounters a layer it does not know how to handle, it raises an error. When clustering an entire model, even a single unsupported layer causes an error.

Cluster a model:

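import tensorflow_model_optimization as tfmot

cluster_weights = tfmot.clustering.keras.cluster_weights
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization

clustering_params = {
  'number_of_clusters': 8,
  'cluster_centroids_init': CentroidInitialization.DENSITY_BASED
}

# original_model is an already trained keras model.
clustered_model = cluster_weights(original_model, **clustering_params)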

Cluster a layer:

clustering_params = {
  'number_of_clusters': 8,
  'cluster_centroids_init': CentroidInitialization.DENSITY_BASED
}

# Only the second Dense layer is wrapped for clustering.
model = tf.keras.Sequential([
    layers.Dense(10, activation='relu', input_shape=(100,)),
    cluster_weights(layers.Dense(2, activation='tanh'), **clustering_params)
])

Args:

to_cluster: A single keras layer, a list of keras layers, or a tf.keras.Model instance.
number_of_clusters: The number of cluster centroids to form when clustering a layer or model. For example, if number_of_clusters=8, each weight array will use at most 8 unique values.
cluster_centroids_init: Enum value that determines how the cluster centroids are initialized. It can have the following values:

  1. RANDOM: centroids are sampled from the uniform distribution between the minimum and maximum weight values in a given layer.
  2. DENSITY_BASED: density-based sampling. First, a cumulative distribution function is built for the weights, then the y-axis is divided evenly into number_of_clusters regions. The corresponding x values are then obtained and used to initialize the cluster centroids.
  3. LINEAR: cluster centroids are evenly spaced between the minimum and maximum values of a given weight.
**kwargs: Additional keyword arguments to be passed to the keras layer. Ignored when to_cluster is not a keras layer.

Returns:

Layer or model modified to include clustering-related metadata.

Raises:

ValueError: if the keras layer is unsupported, or the keras model contains an unsupported layer.
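
The three initialization strategies described for cluster_centroids_init can be sketched in plain Python as follows. This is an illustration under stated assumptions, not the library's code: the function names are hypothetical, and the density-based variant assumes that the midpoint of each of the number_of_clusters regions on the CDF axis is mapped back to a weight value, a detail the description above leaves open.

```python
import random


def init_linear(weights, k):
    """LINEAR: k centroids evenly spaced between min and max weight."""
    lo, hi = min(weights), max(weights)
    if k == 1:
        return [(lo + hi) / 2.0]
    step = (hi - lo) / (k - 1)
    return [lo + i * step for i in range(k)]


def init_random(weights, k, rng):
    """RANDOM: k centroids sampled uniformly between min and max weight."""
    lo, hi = min(weights), max(weights)
    return [rng.uniform(lo, hi) for _ in range(k)]


def init_density_based(weights, k):
    """DENSITY_BASED: invert the empirical CDF at k evenly spaced levels."""
    ordered = sorted(weights)  # sorted weights act as the empirical CDF
    n = len(ordered)
    centroids = []
    for i in range(k):
        # Assumption: use the midpoint of the i-th region on the y-axis.
        p = (i + 0.5) / k
        index = min(n - 1, int(p * n))
        centroids.append(ordered[index])
    return centroids


rng = random.Random(0)
weights = [rng.gauss(0.0, 1.0) for _ in range(1000)]

linear = init_linear(weights, 8)
random_init = init_random(weights, 8, rng)
density = init_density_based(weights, 8)
```

In this sketch the linear centroids always span the full weight range, while the density-based centroids follow the quantiles of the weights, so they tend to sit where most of the weights are.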