tfmot.sparsity.keras.PruningPolicy

Specifies which layers to prune in the model.

PruningPolicy controls the application of the PruneLowMagnitude wrapper on a per-layer basis and checks that the model contains only supported layers. It works together with prune_low_magnitude, which accepts a policy through its pruning_policy argument, giving fine-grained control over which parts of the model are pruned.
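The interaction between a policy and prune_low_magnitude can be modeled in plain Python. This is a conceptual sketch, not tfmot code: `Layer`, `Dense`, `Wrapped`, `PruneDenseOnly`, and `apply_policy` are hypothetical stand-ins, and the sketch assumes the model is validated once before each layer is offered to allow_pruning.

```python
# Conceptual sketch only: every class and function here is a hypothetical
# stand-in, not part of tensorflow_model_optimization.


class Layer:
    """Stand-in for a Keras layer."""

    def __init__(self, name):
        self.name = name


class Dense(Layer):
    """Stand-in for a Dense layer."""


class Wrapped:
    """Stand-in for the PruneLowMagnitude wrapper around a layer."""

    def __init__(self, layer):
        self.layer = layer


class PruneDenseOnly:
    """Hypothetical policy that prunes Dense layers and nothing else."""

    def allow_pruning(self, layer):
        return isinstance(layer, Dense)

    def ensure_model_supports_pruning(self, model):
        pass  # this sketch accepts any model


def apply_policy(layers, policy):
    """Validate the model once, then wrap only the layers the policy allows."""
    policy.ensure_model_supports_pruning(layers)
    return [Wrapped(l) if policy.allow_pruning(l) else l for l in layers]


model = [Dense("dense_1"), Layer("dropout"), Dense("dense_2")]
pruned = apply_policy(model, PruneDenseOnly())
print([type(l).__name__ for l in pruned])  # ['Wrapped', 'Layer', 'Wrapped']
```

Layers for which allow_pruning returns False pass through unchanged, which is what makes the control per-layer.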

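import tensorflow_model_optimization as tfmot
from tensorflow import keras
from tensorflow.keras import layers

pruning_params = {
    # Hold sparsity at 50% starting from training step 0.
    'pruning_schedule': tfmot.sparsity.keras.ConstantSparsity(0.5, 0),
    'block_size': (1, 1),
    'block_pooling_type': 'AVG'
}

model = tfmot.sparsity.keras.prune_low_magnitude(
    keras.Sequential([
        layers.Dense(10, activation='relu', input_shape=(100,)),
        layers.Dense(2, activation='sigmoid')
    ]),
    # The policy decides which layers receive the pruning wrapper.
    pruning_policy=tfmot.sparsity.keras.PruneForLatencyOnXNNPack(),
    **pruning_params)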

You can subclass PruningPolicy to write your own custom pruning policy.

The API is experimental and is subject to change.
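A minimal sketch of a custom policy that leaves selected layers (for example an output layer) unpruned. To run without TensorFlow it subclasses a hypothetical pure-Python stand-in, `PolicyBase`, that mirrors the two methods documented on this page; in real code you would subclass tfmot.sparsity.keras.PruningPolicy instead. `Layer` and `SkipNamedLayers` are likewise illustrative names.

```python
from abc import ABC, abstractmethod


class PolicyBase(ABC):
    """Hypothetical stand-in mirroring the two PruningPolicy methods."""

    @abstractmethod
    def allow_pruning(self, layer):
        ...

    @abstractmethod
    def ensure_model_supports_pruning(self, model):
        ...


class Layer:
    """Hypothetical stand-in for a Keras layer; only `name` is used."""

    def __init__(self, name):
        self.name = name


class SkipNamedLayers(PolicyBase):
    """Prune every layer except those whose names are listed."""

    def __init__(self, skip):
        self.skip = set(skip)

    def allow_pruning(self, layer):
        return layer.name not in self.skip

    def ensure_model_supports_pruning(self, model):
        # Fail early if a name to skip does not exist in the model.
        names = {layer.name for layer in model}
        missing = self.skip - names
        if missing:
            raise ValueError(f"Unknown layer names: {sorted(missing)}")


layers = [Layer("hidden"), Layer("output")]
policy = SkipNamedLayers(skip=["output"])
policy.ensure_model_supports_pruning(layers)
print([policy.allow_pruning(l) for l in layers])  # [True, False]
```

Validating the skip list inside ensure_model_supports_pruning turns a typo in a layer name into an immediate error instead of a silently unpruned model.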

Methods

allow_pruning

Checks whether the pruning wrapper should be applied to the current layer.

Args
layer: The current layer in the model.

Returns
True if the pruning wrapper should be applied to the layer, False otherwise.
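Because allow_pruning sees one layer at a time and returns a boolean, the rule can be any per-layer test. The sketch below leaves small layers dense, on the assumption that pruning them saves little. `Layer` and `MinParamsPolicy` are hypothetical stand-ins; the stand-in's count_params() mimics the method of the same name on Keras layers. The parameter counts used are those of the two Dense layers in the example above (100 × 10 + 10 = 1010 and 10 × 2 + 2 = 22).

```python
class Layer:
    """Hypothetical stand-in for a Keras layer exposing count_params()."""

    def __init__(self, name, params):
        self.name = name
        self._params = params

    def count_params(self):
        return self._params


class MinParamsPolicy:
    """Hypothetical policy: prune only layers with enough weights."""

    def __init__(self, min_params):
        self.min_params = min_params

    def allow_pruning(self, layer):
        # True -> wrap with the pruning wrapper; False -> leave the layer as-is.
        return layer.count_params() >= self.min_params


policy = MinParamsPolicy(min_params=1000)
print(policy.allow_pruning(Layer("dense_1", 1010)))  # True
print(policy.allow_pruning(Layer("dense_2", 22)))    # False
```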

ensure_model_supports_pruning

Checks that the model contains only supported layers.

Args
model: The tf.keras.Model instance to be pruned.

Raises
ValueError: If the Keras model does not support the pruning policy, i.e. the model contains an unsupported layer.
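One way to meet this contract is to compare every layer against an allow-list of types and raise ValueError naming the offenders. This is a sketch under assumptions: the layer classes and `SupportedTypesPolicy` are hypothetical pure-Python stand-ins, and the supported set is arbitrary rather than the list any built-in tfmot policy uses.

```python
class Layer:
    """Hypothetical stand-in for a Keras layer."""

    def __init__(self, name):
        self.name = name


class Dense(Layer):
    pass


class Conv2D(Layer):
    pass


class LSTM(Layer):
    pass


class SupportedTypesPolicy:
    """Hypothetical policy that accepts only an allow-list of layer types."""

    SUPPORTED = (Dense, Conv2D)

    def ensure_model_supports_pruning(self, model):
        # Collect every offending layer so the error reports all of them at once.
        unsupported = [l.name for l in model if not isinstance(l, self.SUPPORTED)]
        if unsupported:
            raise ValueError(f"Model contains unsupported layers: {unsupported}")


policy = SupportedTypesPolicy()
policy.ensure_model_supports_pruning([Dense("d"), Conv2D("c")])  # passes silently
try:
    policy.ensure_model_supports_pruning([Dense("d"), LSTM("rnn")])
except ValueError as err:
    print(err)  # Model contains unsupported layers: ['rnn']
```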