Modify a keras layer or model to be pruned during training.

    prune_low_magnitude(
        to_prune,
        pruning_schedule=pruning_sched.ConstantSparsity(0.5, 0),
        block_size=(1, 1),
        block_pooling_type='AVG',
        **kwargs
    )

This function wraps a keras model or layer with pruning functionality which sparsifies the layer's weights during training. For example, using this with 50% sparsity will ensure that 50% of the layer's weights are zero.
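To make the 50% figure concrete, the following is a minimal NumPy sketch of magnitude-based pruning on a single weight matrix. It is an illustration only, not the library's implementation: the helper name `magnitude_mask` and the sample weights are hypothetical, and the real wrapper updates its mask during training rather than in one shot.

```python
import numpy as np

def magnitude_mask(weights, sparsity):
    """Return a 0/1 mask that drops the `sparsity` fraction of smallest-|w| entries.

    Hypothetical helper for illustration; not part of the library's API.
    """
    flat = np.abs(weights).ravel()
    num_pruned = int(round(sparsity * flat.size))
    mask = np.ones(flat.size, dtype=weights.dtype)
    if num_pruned > 0:
        # Indices of the num_pruned entries with the smallest magnitude.
        pruned_idx = np.argsort(flat, kind="stable")[:num_pruned]
        mask[pruned_idx] = 0
    return mask.reshape(weights.shape)

weights = np.array([[0.9, -0.1, 0.4, -0.05],
                    [0.02, -0.7, 0.3, 0.6]])
mask = magnitude_mask(weights, sparsity=0.5)
pruned = weights * mask  # half of the eight weights are now zero
```

The four entries with the largest magnitude (0.9, 0.4, -0.7, 0.6) survive; the rest are set to zero.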

The function accepts a single keras layer (subclass of keras.layers.Layer), a list of keras layers, or a keras model (instance of keras.models.Model) and handles each appropriately.

If it encounters a layer it does not know how to handle, it raises an error. When pruning an entire model, a single unsupported layer is enough to trigger the error.

Prune a model:

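# Assumes a TensorFlow/Keras version supported by tensorflow_model_optimization.
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow_model_optimization.sparsity.keras import (
    ConstantSparsity, prune_low_magnitude)

pruning_params = {
    'pruning_schedule': ConstantSparsity(0.5, 0),
    'block_size': (1, 1),
    'block_pooling_type': 'AVG'
}

model = prune_low_magnitude(
    keras.Sequential([
        layers.Dense(10, activation='relu', input_shape=(100,)),
        layers.Dense(2, activation='sigmoid')
    ]), **pruning_params)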

Prune a layer:

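# Assumes a TensorFlow/Keras version supported by tensorflow_model_optimization.
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow_model_optimization.sparsity.keras import (
    PolynomialDecay, prune_low_magnitude)

pruning_params = {
    'pruning_schedule': PolynomialDecay(initial_sparsity=0.2,
        final_sparsity=0.8, begin_step=1000, end_step=2000),
    'block_size': (2, 3),
    'block_pooling_type': 'MAX'
}

model = keras.Sequential([
    layers.Dense(10, activation='relu', input_shape=(100,)),
    prune_low_magnitude(layers.Dense(2, activation='tanh'), **pruning_params)
])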
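For intuition about the PolynomialDecay arguments used above, here is a small pure-Python sketch of how a target sparsity could ramp from 0.2 to 0.8 between steps 1000 and 2000. The function name `polynomial_decay_sparsity` is hypothetical, and the exponent `power=3` is an assumption about the schedule's default. The sketch only computes a target value; it does not model when pruning updates are actually applied.

```python
def polynomial_decay_sparsity(step, initial_sparsity, final_sparsity,
                              begin_step, end_step, power=3):
    """Target sparsity at `step` under a polynomial ramp (illustrative sketch).

    Hypothetical helper; `power=3` is an assumed default, not taken from the text.
    """
    # Clamp the step into [begin_step, end_step] so the ramp is flat outside it.
    clamped = min(max(step, begin_step), end_step)
    progress = (clamped - begin_step) / (end_step - begin_step)
    return final_sparsity + (initial_sparsity - final_sparsity) * (1.0 - progress) ** power

start = polynomial_decay_sparsity(1000, 0.2, 0.8, 1000, 2000)   # 0.2
middle = polynomial_decay_sparsity(1500, 0.2, 0.8, 1000, 2000)  # 0.725
end = polynomial_decay_sparsity(2000, 0.2, 0.8, 1000, 2000)     # 0.8
```

With a cubic ramp, most of the increase happens early: halfway through the window the target is already 0.725 of the way to the final 0.8.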


Args:

  • to_prune: A single keras layer, list of keras layers, or a tf.keras.Model instance.
  • pruning_schedule: A PruningSchedule object that controls pruning rate throughout training.
  • block_size: (optional) The dimensions (height, width) for the block sparse pattern in rank-2 weight tensors.
  • block_pooling_type: (optional) The function to use to pool weights in the block. Must be 'AVG' or 'MAX'.
  • **kwargs: Additional keyword arguments to be passed to the keras layer. Ignored when to_prune is not a keras layer.
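The interaction between block_size and block_pooling_type can be sketched in NumPy as follows. This is an illustration under stated assumptions, not the library's code: the helper `block_mask` is hypothetical, and it requires the weight shape to be an exact multiple of the block size for brevity. Each block is reduced to one score (mean or max of absolute values), the lowest-scoring blocks are dropped, and the block decision is expanded back to individual weights.

```python
import numpy as np

def block_mask(weights, sparsity, block_size=(1, 1), block_pooling_type="AVG"):
    """Build a 0/1 mask that prunes whole blocks of a rank-2 tensor (sketch only)."""
    rows, cols = weights.shape
    bh, bw = block_size
    # Simplifying assumption: the shape divides evenly into blocks.
    if rows % bh or cols % bw:
        raise ValueError("shape must be a multiple of block_size in this sketch")
    # blocks[i, r, j, c] == |weights[i * bh + r, j * bw + c]|
    blocks = np.abs(weights).reshape(rows // bh, bh, cols // bw, bw)
    if block_pooling_type == "AVG":
        pooled = blocks.mean(axis=(1, 3))
    elif block_pooling_type == "MAX":
        pooled = blocks.max(axis=(1, 3))
    else:
        raise ValueError("block_pooling_type must be 'AVG' or 'MAX'")
    num_pruned = int(round(sparsity * pooled.size))
    keep = np.ones(pooled.size)
    # Drop the blocks with the smallest pooled magnitude.
    keep[np.argsort(pooled.ravel(), kind="stable")[:num_pruned]] = 0
    keep = keep.reshape(pooled.shape)
    # Expand each per-block decision back to element resolution.
    return np.kron(keep, np.ones((bh, bw)))

weights = np.array([[0.1, 0.2, 0.1, 0.9, 0.8, 0.7],
                    [0.3, 0.1, 0.2, 0.6, 0.5, 0.9]])
mask = block_mask(weights, sparsity=0.5, block_size=(2, 3), block_pooling_type="MAX")
```

Here the matrix splits into two 2x3 blocks; the left block has the smaller maximum magnitude, so all six of its weights are masked together.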


Returns:

Layer or model modified with pruning wrappers.


Raises:

  • ValueError: if the keras layer is unsupported, or the keras model contains an unsupported layer.