tfmot.sparsity.keras.PruningPolicy

Specifies which layers to prune in the model.

PruningPolicy controls whether the PruneLowMagnitude wrapper is applied on a per-layer basis and checks that the model contains only supported layers. It works together with prune_low_magnitude, giving fine-grained control over which parts of the model are pruned.

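import tensorflow as tf
import tensorflow_model_optimization as tfmot

keras = tf.keras
layers = tf.keras.layers
ConstantSparsity = tfmot.sparsity.keras.ConstantSparsity
PruneForLatencyOnXNNPack = tfmot.sparsity.keras.PruneForLatencyOnXNNPack
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

# Target sparsity of 0.5, starting at step 0.
pruning_params = {
    'pruning_schedule': ConstantSparsity(0.5, 0),
    'block_size': (1, 1),
    'block_pooling_type': 'AVG'
}

model = prune_low_magnitude(
    keras.Sequential([
        layers.Dense(10, activation='relu', input_shape=(100,)),
        layers.Dense(2, activation='sigmoid')
    ]),
    pruning_policy=PruneForLatencyOnXNNPack(),
    **pruning_params)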

You can inherit this class to write your own custom pruning policy.
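
As an illustration of such a subclass, here is a minimal sketch, not an official recipe. The class name LayerTypePolicy and its prunable_types and skip_names arguments are hypothetical. Treating a model with no prunable layer as unsupported is a design choice made for this example only. The import fallback exists only so the sketch can be read and run where tensorflow_model_optimization is not installed.

```python
try:
    import tensorflow_model_optimization as tfmot
    PruningPolicy = tfmot.sparsity.keras.PruningPolicy
except Exception:  # tfmot missing or incompatible: use a stand-in base class
    class PruningPolicy:
        """Stand-in exposing the two documented methods."""

        def allow_pruning(self, layer):
            raise NotImplementedError

        def ensure_model_supports_pruning(self, model):
            raise NotImplementedError


class LayerTypePolicy(PruningPolicy):
    """Hypothetical policy: prune layers of the given types, except named ones."""

    def __init__(self, prunable_types, skip_names=()):
        self._prunable_types = tuple(prunable_types)
        self._skip_names = frozenset(skip_names)

    def allow_pruning(self, layer):
        # True only for layers of an allowed type whose name is not excluded.
        return (isinstance(layer, self._prunable_types)
                and layer.name not in self._skip_names)

    def ensure_model_supports_pruning(self, model):
        # Assumption for this sketch: a model is unsupported when the policy
        # would not prune any of its layers.
        if not any(self.allow_pruning(layer) for layer in model.layers):
            raise ValueError('Model has no layer that this policy would prune.')
```

An instance such as LayerTypePolicy((tf.keras.layers.Dense,), skip_names={'logits'}) could then be passed as the pruning_policy argument of prune_low_magnitude, assuming the model is one that prune_low_magnitude accepts together with a policy.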

The API is experimental and is subject to change.

Methods

allow_pruning

Checks whether the pruning wrapper should be applied to the current layer.

Args
layer Current layer in the model.

Returns
True if the pruning wrapper should be applied to the layer, False otherwise.
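
Because allow_pruning takes a single layer and returns a boolean, it can be called directly to preview what a policy would do before pruning anything. The helper below is a small sketch; the name preview_pruning is hypothetical and not part of the library. It relies only on allow_pruning, the layers attribute of a Keras model, and each layer's name, and it inspects top-level layers only.

```python
def preview_pruning(model, policy):
    """Maps each top-level layer name to whether `policy` allows pruning it."""
    return {
        layer.name: bool(policy.allow_pruning(layer))
        for layer in model.layers
    }
```

Printing the returned dictionary for a model is a quick way to confirm that a custom policy selects the layers you expect.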

ensure_model_supports_pruning

Checks that the model contains only supported layers.

Args
model A tf.keras.Model instance which is going to be pruned.

Raises
ValueError If the Keras model does not support the pruning policy, i.e. the model contains an unsupported layer.
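
Since an unsupported model is reported through ValueError, a caller can run this check up front and decide how to proceed, for example by falling back to pruning without a policy. The following is a sketch only; the helper name policy_error is hypothetical and it assumes nothing beyond the documented ValueError behavior.

```python
def policy_error(model, policy):
    """Returns None if `policy` accepts `model`, else the error message."""
    try:
        policy.ensure_model_supports_pruning(model)
    except ValueError as err:
        # The policy rejected the model; surface the reason to the caller.
        return str(err)
    return None
```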