tfmot.clustering.keras.cluster_weights

Modify a keras layer or model to be clustered during training.

This function wraps a Keras model or layer with clustering functionality that clusters the layer's weights during training. For example, using it with number_of_clusters=8 ensures that each weight tensor has no more than 8 unique values.
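To make the "no more than 8 unique values" guarantee concrete, the pure-Python sketch below clusters a list of stand-in weights with linearly spaced centroids and nearest-centroid assignment. It is a conceptual illustration under stated assumptions, not tfmot's implementation: the helper names (linear_centroids, nearest_indices, cluster) are hypothetical, and the real wrapper performs the clustering inside the layer during training.

```python
import random

def linear_centroids(weights, k):
    """Place k centroids evenly between the smallest and largest weight."""
    lo, hi = min(weights), max(weights)
    step = (hi - lo) / (k - 1) if k > 1 else 0.0
    return [lo + i * step for i in range(k)]

def nearest_indices(weights, centroids):
    """For each weight, return the index of the closest centroid."""
    return [min(range(len(centroids)), key=lambda j: abs(w - centroids[j]))
            for w in weights]

def cluster(weights, k):
    """Hypothetical helper: replace every weight by its nearest centroid."""
    centroids = linear_centroids(weights, k)
    indices = nearest_indices(weights, centroids)
    # The clustered tensor is fully described by k centroids plus one index per weight.
    clustered = [centroids[i] for i in indices]
    return clustered, centroids, indices

# Stand-in for a trained weight tensor (assumption: 1000 uniform values).
rng = random.Random(0)
weights = [rng.uniform(-1.0, 1.0) for _ in range(1000)]
clustered, centroids, indices = cluster(weights, k=8)
print(len(set(weights)), "unique values before,", len(set(clustered)), "after")
```

Because every clustered value is looked up from a table of k centroids, the number of distinct values can never exceed k, whatever the original distribution.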

Before being passed to the clustering API, a model should already be trained and achieve acceptable performance on the test/validation sets.

The function accepts a single Keras layer (a subclass of keras.layers.Layer), a list of Keras layers, or a Keras model (an instance of keras.models.Model), and handles each case appropriately.

If it encounters a layer it does not know how to handle, it raises a ValueError. When clustering an entire model, a single unsupported layer is enough to cause this error.
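The input handling and all-or-nothing error behavior described above can be pictured with the hypothetical sketch below. The classes are plain-Python stand-ins, not Keras classes, and SUPPORTED is an assumed allow-list; tfmot's actual checks differ. One detail it does mirror: in Keras a Model is itself a Layer, so the model case has to be tested before the layer case.

```python
class Layer:
    """Stand-in for keras.layers.Layer (hypothetical, for illustration only)."""
    def __init__(self, name):
        self.name = name

class Dense(Layer):
    pass

class CustomLayer(Layer):
    pass

class Model(Layer):
    """Stand-in for keras.Model; as in Keras, a model is also a layer."""
    def __init__(self, layers, name="model"):
        super().__init__(name)
        self.layers = layers

class ClusterWrapper(Layer):
    """Hypothetical wrapper marking a layer as 'to be clustered'."""
    def __init__(self, layer):
        super().__init__("cluster_" + layer.name)
        self.layer = layer

SUPPORTED = (Dense,)  # assumed allow-list for this sketch only

def wrap_layer(layer):
    if not isinstance(layer, SUPPORTED):
        raise ValueError(f"Layer {layer.name!r} is not supported for clustering.")
    return ClusterWrapper(layer)

def cluster_weights_sketch(to_cluster):
    # Check Model before Layer, because a Model is also a Layer.
    if isinstance(to_cluster, Model):
        # Every layer must be wrappable; one unsupported layer fails the whole model.
        return Model([wrap_layer(layer) for layer in to_cluster.layers],
                     name=to_cluster.name)
    if isinstance(to_cluster, Layer):
        return wrap_layer(to_cluster)
    if isinstance(to_cluster, list):
        return [wrap_layer(layer) for layer in to_cluster]
    raise ValueError(f"Cannot cluster object of type {type(to_cluster).__name__}.")

clustered = cluster_weights_sketch(Model([Dense("d1"), Dense("d2")]))
print([layer.name for layer in clustered.layers])  # ['cluster_d1', 'cluster_d2']

try:
    cluster_weights_sketch(Model([Dense("d1"), CustomLayer("custom")]))
except ValueError as err:
    print("ValueError:", err)
```

Wrapping all layers before building the new model means no partially clustered model is ever returned: either every layer is handled, or the call fails.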

Cluster a model:

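import tensorflow_model_optimization as tfmot

cluster_weights = tfmot.clustering.keras.cluster_weights
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization

clustering_params = {
  'number_of_clusters': 8,
  'cluster_centroids_init': CentroidInitialization.DENSITY_BASED
}

# original_model is assumed to be an already-trained tf.keras.Model.
clustered_model = cluster_weights(original_model, **clustering_params)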

Cluster a layer:

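import tensorflow_model_optimization as tfmot
from tensorflow import keras
from tensorflow.keras import layers

cluster_weights = tfmot.clustering.keras.cluster_weights
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization

clustering_params = {
  'number_of_clusters': 8,
  'cluster_centroids_init': CentroidInitialization.DENSITY_BASED
}

# Only the second Dense layer is clustered; the first keeps ordinary weights.
model = keras.Sequential([
    layers.Dense(10, activation='relu', input_shape=(100,)),
    cluster_weights(layers.Dense(2, activation='tanh'), **clustering_params)
])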
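Both examples use CentroidInitialization.DENSITY_BASED. As we understand the tfmot CentroidInitialization documentation, DENSITY_BASED builds a cumulative distribution function (CDF) of the weights, spaces number_of_clusters levels evenly along its y-axis, and uses the matching weight values as initial centroids, whereas LINEAR spaces centroids evenly between the minimum and maximum weight. The sketch below approximates both with plain Python; the function names, the midpoint CDF levels (i + 0.5) / k, and the stand-in weights are assumptions, and tfmot's exact interpolation may differ.

```python
import random

def linear_init(weights, k):
    """LINEAR-style init: k centroids (k >= 2) evenly spaced between min and max."""
    lo, hi = min(weights), max(weights)
    return [lo + (hi - lo) * i / (k - 1) for i in range(k)]

def density_based_init(weights, k):
    """DENSITY_BASED-style init (approximation): space k levels evenly on the
    y-axis of the empirical CDF and read off the matching weight values."""
    xs = sorted(weights)
    n = len(xs)
    centroids = []
    for i in range(k):
        level = (i + 0.5) / k              # assumed choice of evenly spaced CDF levels
        idx = min(int(level * n), n - 1)   # empirical inverse CDF (quantile lookup)
        centroids.append(xs[idx])
    return centroids

# Stand-in weights: most values near zero, plus two outliers.
rng = random.Random(0)
weights = [rng.gauss(0.0, 0.05) for _ in range(1000)] + [-1.0, 1.0]

lin = linear_init(weights, 8)
dens = density_based_init(weights, 8)
print("linear: ", [round(c, 3) for c in lin])
print("density:", [round(c, 3) for c in dens])
```

With these weights, the two outliers stretch the linear range to [-1, 1], so only two linear centroids land near zero where almost all weights are, while every density-based centroid falls inside the dense region.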

Args:
  to_cluster: A single Keras layer, a list of Keras layers, or a tf.keras.Model instance.
  number_of_clusters: The number of cluster centroids to form when clustering a layer or model. For example, if number_of_clusters=8, only 8 unique values will be used in each weight array.
  cluster_centroids_init: A tfmot.clustering.keras.CentroidInitialization value that determines how the cluster centroids are initialized.
  **kwargs: Additional keyword arguments to pass to the Keras layer. Ignored when to_cluster is not a Keras layer.

Returns:
  The layer or model, modified to include clustering-related metadata.

Raises:
  ValueError: If the Keras layer is unsupported, or if the Keras model contains an unsupported layer.