Quantization aware training comprehensive guide


Welcome to the comprehensive guide for Keras quantization aware training.

This page documents various use cases and shows how to use the API for each one. Once you know which APIs you need, find the parameters and the low-level details in the API docs.

The following use cases are covered:

  • Deploy a model with 8-bit quantization with these steps.
    • Define a quantization aware model.
    • For Keras HDF5 models only, use special checkpointing and deserialization logic. Training is otherwise standard.
    • Create a quantized model from the quantization aware one.
  • Experiment with quantization.
    • Nothing used for experimentation has a supported path to deployment.
    • Custom Keras layers fall under experimentation.

Setup

You can run this section without reading it. It isn't needed to find the APIs you need or to understand their purpose.

! pip install -q tensorflow
! pip install -q tensorflow-model-optimization
! pip install -q tf-keras

import tensorflow as tf
import numpy as np
import tensorflow_model_optimization as tfmot
import tf_keras as keras

import tempfile

input_shape = [20]
x_train = np.random.randn(1, 20).astype(np.float32)
# Draw a random integer class label in [0, 20) and one-hot encode it.
y_train = keras.utils.to_categorical(np.random.randint(0, 20, size=(1,)), num_classes=20)

def setup_model():
  model = keras.Sequential([
      keras.layers.Dense(20, input_shape=input_shape),
      keras.layers.Flatten()
  ])
  return model

def setup_pretrained_weights():
  model = setup_model()

  model.compile(
      loss=keras.losses.categorical_crossentropy,
      optimizer='adam',
      metrics=['accuracy']
  )

  model.fit(x_train, y_train)

  _, pretrained_weights = tempfile.mkstemp('.tf')

  model.save_weights(pretrained_weights)

  return pretrained_weights

def setup_pretrained_model():
  model = setup_model()
  pretrained_weights = setup_pretrained_weights()
  model.load_weights(pretrained_weights)
  return model

setup_model()
pretrained_weights = setup_pretrained_weights()

Define quantization aware model

Models defined in the following ways have supported paths to deployment on the backends listed in the overview page. By default, 8-bit quantization is used.
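8-bit quantization maps each float value to one of 256 integer levels through a scale and a zero point. The pure-Python sketch below illustrates that affine mapping under stated assumptions: it uses an unsigned range [0, 255] and a float range known in advance, and the helper names (`affine_params`, `quantize`, `dequantize`) are hypothetical, not part of the tfmot API. Real backends may use int8 and differ in rounding and edge-case handling.

```python
def affine_params(x_min, x_max, num_bits=8):
    """Compute scale and zero point for asymmetric (affine) quantization.

    Assumes an unsigned integer range [0, 2**num_bits - 1]. The float range is
    widened to include 0.0 so that zero is exactly representable.
    """
    qmin, qmax = 0, 2 ** num_bits - 1
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)
    scale = (x_max - x_min) / (qmax - qmin)
    if scale == 0.0:
        scale = 1.0  # degenerate all-zero range; any positive scale works
    zero_point = int(round(qmin - x_min / scale))
    zero_point = max(qmin, min(qmax, zero_point))
    return scale, zero_point


def quantize(x, scale, zero_point, num_bits=8):
    """Map a float to its nearest integer level, clamped to the valid range."""
    q = int(round(x / scale)) + zero_point
    return max(0, min(2 ** num_bits - 1, q))


def dequantize(q, scale, zero_point):
    """Map an integer level back to the float value it represents."""
    return scale * (q - zero_point)


scale, zp = affine_params(-1.0, 3.0)
for x in (-1.0, 0.0, 0.5, 3.0):
    q = quantize(x, scale, zp)
    print(x, "->", q, "->", round(dequantize(q, scale, zp), 4))
```

Within the chosen range, the round-trip error is at most half a step (`scale / 2`). Values outside the range are clamped, which is why a well-estimated range matters.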

Quantize whole model

Your use case:

  • Subclassed models are not supported.

Tips for better model accuracy:

  • Try "Quantize some layers" to skip quantizing the layers that reduce accuracy the most.
  • It's generally better to finetune with quantization aware training as opposed to training from scratch.

To make the whole model aware of quantization, apply tfmot.quantization.keras.quantize_model to the model.

base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy

quant_aware_model = tfmot.quantization.keras.quantize_model(base_model)
quant_aware_model.summary()
Model: "sequential_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 quantize_layer (QuantizeLa  (None, 20)                3         
 yer)                                                            
                                                                 
 quant_dense_2 (QuantizeWra  (None, 20)                425       
 pperV2)                                                         
                                                                 
 quant_flatten_2 (QuantizeW  (None, 20)                1         
 rapperV2)                                                       
                                                                 
=================================================================
Total params: 429 (1.68 KB)
Trainable params: 420 (1.64 KB)
Non-trainable params: 9 (36.00 Byte)
_________________________________________________________________
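The parameter counts in this summary can be reconciled by hand. The sketch below is only an accounting exercise, and its breakdown is an assumption (our reading of the summary, not something this guide states): each per-tensor quantizer adds a min and a max scalar, and each quantize wrapper or QuantizeLayer adds one `optimizer_step` counter. It also assumes 4-byte float32 parameters when converting to bytes.

```python
# Dense(20) on a 20-feature input: a 20x20 kernel plus a bias of length 20.
units, in_features = 20, 20
dense_trainable = in_features * units + units

# Assumed breakdown of the non-trainable variables (interpretation, not API docs):
MIN_MAX = 2         # one min and one max scalar per per-tensor quantizer
OPTIMIZER_STEP = 1  # one step counter per quantize wrapper / QuantizeLayer

quantize_layer = MIN_MAX + OPTIMIZER_STEP         # quantizes the model input
quant_dense_extra = 2 * MIN_MAX + OPTIMIZER_STEP  # kernel + activation quantizers
quant_flatten = OPTIMIZER_STEP                    # Flatten has nothing to quantize

quant_dense = dense_trainable + quant_dense_extra
non_trainable = quantize_layer + quant_dense_extra + quant_flatten
total = dense_trainable + non_trainable
non_trainable_bytes = non_trainable * 4  # assuming float32 storage

print(quantize_layer, quant_dense, quant_flatten)
print(total, dense_trainable, non_trainable, non_trainable_bytes)
```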

Quantize some layers

Quantizing a model can have a negative effect on accuracy. You can selectively quantize layers of a model to explore the trade-off between accuracy, speed, and model size.

Your use case:

  • To deploy to a backend that only works well with fully quantized models (e.g. EdgeTPU v1, most DSPs), try "Quantize whole model".

Tips for better model accuracy:

  • It's generally better to finetune with quantization aware training as opposed to training from scratch.
  • Try quantizing the later layers instead of the first layers.
  • Avoid quantizing critical layers (e.g. attention mechanism).

In the example below, quantize only the Dense layers.

# Create a base model
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy

# Helper function uses `quantize_annotate_layer` to annotate that only the 
# Dense layers should be quantized.
def apply_quantization_to_dense(layer):
  if isinstance(layer, keras.layers.Dense):
    return tfmot.quantization.keras.quantize_annotate_layer(layer)
  return layer

# Use `keras.models.clone_model` to apply `apply_quantization_to_dense` 
# to the layers of the model.
annotated_model = keras.models.clone_model(
    base_model,
    clone_function=apply_quantization_to_dense,
)

# Now that the Dense layers are annotated,
# `quantize_apply` actually makes the model quantization aware.
quant_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)
quant_aware_model.summary()
Model: "sequential_3"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 quantize_layer_1 (Quantize  (None, 20)                3         
 Layer)                                                          
                                                                 
 quant_dense_3 (QuantizeWra  (None, 20)                425       
 pperV2)                                                         
                                                                 
 flatten_3 (Flatten)         (None, 20)                0         
                                                                 
=================================================================
Total params: 428 (1.67 KB)
Trainable params: 420 (1.64 KB)
Non-trainable params: 8 (32.00 Byte)
_________________________________________________________________

While this example uses the layer type to decide what to quantize, the easiest way to quantize a particular layer is to set its name property and look for that name in the clone_function.

print(base_model.layers[0].name)
dense_3
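The name-based approach can be sketched as a small factory that closes over the set of layer names to quantize and returns a function suitable for the clone_function argument. To stay runnable without TensorFlow, the sketch uses a hypothetical stand-in layer class and a hypothetical `fake_annotate` callable; `make_clone_function` is also an illustrative helper, not a tfmot API.

```python
from dataclasses import dataclass


@dataclass
class StandInLayer:
    """Hypothetical stand-in for a Keras layer; only `name` matters here."""
    name: str


def make_clone_function(names_to_quantize, annotate):
    """Return a clone_function that annotates only layers whose name matches.

    `annotate` is whatever wraps a layer for quantization; with tfmot this would
    be tfmot.quantization.keras.quantize_annotate_layer.
    """
    names = frozenset(names_to_quantize)

    def clone_function(layer):
        if layer.name in names:
            return annotate(layer)
        return layer  # leave every other layer untouched

    return clone_function


def fake_annotate(layer):
    """Hypothetical annotate: tag the layer so the effect is visible."""
    return ("annotated", layer)


clone_fn = make_clone_function({"dense_3"}, fake_annotate)
layers = [StandInLayer("dense_3"), StandInLayer("flatten_3")]
result = [clone_fn(layer) for layer in layers]
print(result)
```

With real Keras objects, the factory's output would be passed as keras.models.clone_model(base_model, clone_function=make_clone_function({"dense_3"}, tfmot.quantization.keras.quantize_annotate_layer)), followed by quantize_apply as in the example above. Matching on names rather than on types lets you pick out one layer among several of the same kind.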

More readable but potentially lower model accuracy

This approach is not compatible with finetuning with quantization aware training, which is why it may be less accurate than the examples above.

Functional example

# Use `quantize_annotate_layer` to annotate that the `Dense` layer
# should be quantized.
i = keras.Input(shape=(20,))
x = tfmot.quantization.keras.quantize_annotate_layer(keras.layers.Dense(10))(i)
o = keras.layers.Flatten()(x)
annotated_model = keras.Model(inputs=i, outputs=o)

# Use `quantize_apply` to actually make the model quantization aware.
quant_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)

# For deployment purposes, the tool adds `QuantizeLayer` after `InputLayer` so that the
# quantized model can take in float inputs instead of only uint8.
quant_aware_model.summary()
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 20)]              0         
                                                                 
 quantize_layer_2 (Quantize  (None, 20)                3         
 Layer)                                                          
                                                                 
 quant_dense_4 (QuantizeWra  (None, 10)                215       
 pperV2)                                                         
                                                                 
 flatten_4 (Flatten)         (None, 10)                0         
                                                                 
=================================================================
Total params: 218 (872.00 Byte)
Trainable params: 210 (840.00 Byte)
Non-trainable params: 8 (32.00 Byte)
_________________________________________________________________

Sequential example

# Use `quantize_annotate_layer` to annotate that the `Dense` layer
# should be quantized.
annotated_model = keras.Sequential([
  tfmot.quantization.keras.quantize_annotate_layer(keras.layers.Dense(20, input_shape=input_shape)),
  keras.layers.Flatten()
])

# Use `quantize_apply` to actually make the model quantization aware.
quant_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)

quant_aware_model.summary()
Model: "sequential_4"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 quantize_layer_3 (Quantize  (None, 20)                3         
 Layer)                                                          
                                                                 
 quant_dense_5 (QuantizeWra  (None, 20)                425       
 pperV2)                                                         
                                                                 
 flatten_5 (Flatten)         (None, 20)                0         
                                                                 
=================================================================
Total params: 428 (1.67 KB)
Trainable params: 420 (1.64 KB)
Non-trainable params: 8 (32.00 Byte)
_________________________________________________________________

Checkpoint and deserialize

Your use case: this code is only needed for the HDF5 model format (not HDF5 weights or other formats).

# Define the model.
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
quant_aware_model = tfmot.quantization.keras.quantize_model(base_model)

# Save or checkpoint the model.
_, keras_model_file = tempfile.mkstemp('.h5')
quant_aware_model.save(keras_model_file)

# `quantize_scope` is needed for deserializing HDF5 models.
with tfmot.quantization.keras.quantize_scope():
  loaded_model = keras.models.load_model(keras_model_file)

loaded_model.summary()
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tf_keras/src/engine/training.py:3098: UserWarning: You are saving your model as an HDF5 file via `model.save()`. This file format is considered legacy. We recommend using instead the native TF-Keras format, e.g. `model.save('my_model.keras')`.
  saving_api.save_model(
WARNING:tensorflow:No training configuration found in the save file, so the model was *not* compiled. Compile it manually.
Model: "sequential_5"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 quantize_layer_4 (Quantize  (None, 20)                3         
 Layer)                                                          
                                                                 
 quant_dense_6 (QuantizeWra  (None, 20)                425       
 pperV2)                                                         
                                                                 
 quant_flatten_6 (QuantizeW  (None, 20)                1         
 rapperV2)                                                       
                                                                 
=================================================================
Total params: 429 (1.68 KB)
Trainable params: 420 (1.64 KB)
Non-trainable params: 9 (36.00 Byte)
_________________________________________________________________

Create and deploy quantized model

In general, refer to the documentation for the deployment backend that you will use.

This is an example for the TFLite backend.

base_model = setup_pretrained_model()
quant_aware_model = tfmot.quantization.keras.quantize_model(base_model)

# Typically you train the model here.

converter = tf.lite.TFLiteConverter.from_keras_model(quant_aware_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]

quantized_tflite_model = converter.convert()
1/1 [==============================] - 1s 684ms/step - loss: 16.1181 - accuracy: 0.0000e+00

Experiment with quantization

Your use case: using the following APIs means that there is no supported path to deployment. For instance, TFLite conversion and kernel implementations only support 8-bit quantization. The features are also experimental and not subject to backward compatibility guarantees.

Setup: DefaultDenseQuantizeConfig

Experimenting requires using tfmot.quantization.keras.QuantizeConfig, which describes how to quantize the weights, activations, and outputs of a layer.

Below is an example that defines the same QuantizeConfig used for the Dense layer in the API defaults.

During the forward propagation in this example, the LastValueQuantizer returned in get_weights_and_quantizers is called with layer.kernel as the input, producing an output. The output replaces layer.kernel in the original forward propagation of the Dense layer, via the logic defined in set_quantize_weights. The same idea applies to the activations and outputs.
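The idea of replacing layer.kernel with a quantizer's output can be illustrated without TensorFlow. The sketch below is a hedged, pure-Python approximation of last-value fake quantization with symmetric=True and per_axis=False. It takes the range from the current values of the tensor, rounds them onto an 8-bit grid, and returns floats again, so the forward pass runs on quantized-then-dequantized weights. The function name and the exact range handling are assumptions; tfmot's LastValueQuantizer uses TensorFlow fake-quantization ops and may differ in details.

```python
def fake_quant_symmetric(values, num_bits=8, narrow_range=False):
    """Quantize-dequantize a flat list of floats with a symmetric per-tensor range.

    The range [-max_abs, max_abs] comes from the current ("last") values, as an
    approximation of what a last-value quantizer does on each forward pass.
    """
    max_abs = max(abs(v) for v in values)
    if max_abs == 0.0:
        return list(values)
    # Signed grid: [-128, 127] for 8 bits, or [-127, 127] with narrow_range.
    qmax = 2 ** (num_bits - 1) - 1
    qmin = -qmax if narrow_range else -qmax - 1
    scale = max_abs / qmax
    out = []
    for v in values:
        q = max(qmin, min(qmax, round(v / scale)))  # nearest integer level
        out.append(q * scale)                       # back to float
    return out


kernel = [0.31, -0.8, 0.05, 0.799]
quantized_kernel = fake_quant_symmetric(kernel)
# The layer would now compute with `quantized_kernel` in place of `kernel`.
print(quantized_kernel)
```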

LastValueQuantizer = tfmot.quantization.keras.quantizers.LastValueQuantizer
MovingAverageQuantizer = tfmot.quantization.keras.quantizers.MovingAverageQuantizer

class DefaultDenseQuantizeConfig(tfmot.quantization.keras.QuantizeConfig):
    # Configure how to quantize weights.
    def get_weights_and_quantizers(self, layer):
      return [(layer.kernel, LastValueQuantizer(num_bits=8, symmetric=True, narrow_range=False, per_axis=False))]

    # Configure how to quantize activations.
    def get_activations_and_quantizers(self, layer):
      return [(layer.activation, MovingAverageQuantizer(num_bits=8, symmetric=False, narrow_range=False, per_axis=False))]

    def set_quantize_weights(self, layer, quantize_weights):
      # Add one line like this for each item returned by
      # `get_weights_and_quantizers`, in the same order.
      layer.kernel = quantize_weights[0]

    def set_quantize_activations(self, layer, quantize_activations):
      # Add one line like this for each item returned by
      # `get_activations_and_quantizers`, in the same order.
      layer.activation = quantize_activations[0]

    # Configure how to quantize outputs (may be equivalent to activations).
    def get_output_quantizers(self, layer):
      return []

    def get_config(self):
      return {}
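The activation quantizer in this config differs from the weight quantizer. Activations change with every batch, so MovingAverageQuantizer tracks the range with a moving average instead of taking it from the latest values alone, and symmetric=False lets the range be lopsided (for example, mostly positive values). The class below is a hypothetical pure-Python sketch of that idea. The class name and the decay of 0.9 (chosen so the effect is visible in two steps) are assumptions, not tfmot's defaults, and the grid uses an asymmetric scale and zero point.

```python
class MovingAverageRange:
    """Track an activation range with an exponential moving average (sketch)."""

    def __init__(self, decay=0.9):
        self.decay = decay
        self.min = None
        self.max = None

    def update(self, batch):
        lo, hi = min(batch), max(batch)
        if self.min is None:  # the first batch initializes the range
            self.min, self.max = lo, hi
        else:
            self.min = self.decay * self.min + (1 - self.decay) * lo
            self.max = self.decay * self.max + (1 - self.decay) * hi
        return self.min, self.max

    def fake_quant(self, batch, num_bits=8):
        """Quantize-dequantize a batch onto an asymmetric grid over [min, max]."""
        lo, hi = min(self.min, 0.0), max(self.max, 0.0)  # keep 0.0 representable
        qmax = 2 ** num_bits - 1
        scale = (hi - lo) / qmax or 1.0
        zero_point = round(-lo / scale)
        out = []
        for v in batch:
            q = max(0, min(qmax, round(v / scale) + zero_point))
            out.append(scale * (q - zero_point))
        return out


tracker = MovingAverageRange()
tracker.update([0.0, 1.0, 2.0])   # range starts at [0.0, 2.0]
tracker.update([0.0, 0.5, 10.0])  # an outlier batch only nudges the max
print(tracker.min, tracker.max)
print(tracker.fake_quant([0.0, 1.0, 5.0]))  # 5.0 falls outside and is clamped
```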

Quantize custom Keras layer

This example uses the DefaultDenseQuantizeConfig to quantize the CustomLayer.

Applying the configuration is the same across the "Experiment with quantization" use cases.

quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer
quantize_annotate_model = tfmot.quantization.keras.quantize_annotate_model
quantize_scope = tfmot.quantization.keras.quantize_scope

class CustomLayer(keras.layers.Dense):
  pass

model = quantize_annotate_model(keras.Sequential([
   quantize_annotate_layer(CustomLayer(20, input_shape=(20,)), DefaultDenseQuantizeConfig()),
   keras.layers.Flatten()
]))

# `quantize_apply` requires mentioning `DefaultDenseQuantizeConfig` with `quantize_scope`
# as well as the custom Keras layer.
with quantize_scope(
  {'DefaultDenseQuantizeConfig': DefaultDenseQuantizeConfig,
   'CustomLayer': CustomLayer}):
  # Use `quantize_apply` to actually make the model quantization aware.
  quant_aware_model = tfmot.quantization.keras.quantize_apply(model)

quant_aware_model.summary()
Model: "sequential_8"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 quantize_layer_6 (Quantize  (None, 20)                3         
 Layer)                                                          
                                                                 
 quant_custom_layer (Quanti  (None, 20)                425       
 zeWrapperV2)                                                    
                                                                 
 quant_flatten_9 (QuantizeW  (None, 20)                1         
 rapperV2)                                                       
                                                                 
=================================================================
Total params: 429 (1.68 KB)
Trainable params: 420 (1.64 KB)
Non-trainable params: 9 (36.00 Byte)
_________________________________________________________________

Modify quantization parameters

Common mistake: quantizing the bias to fewer than 32 bits usually harms model accuracy too much.

This example modifies the Dense layer to use 4 bits for its weights instead of the default 8 bits. The rest of the model continues to use API defaults.
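Going from 8 to 4 bits shrinks the number of integer levels from 256 to 16, so each quantization step covers a much wider slice of the weight range. The short calculation below makes that concrete. It assumes a symmetric grid with 2**(num_bits - 1) - 1 positive levels and an illustrative weight range of [-1, 1]; `symmetric_step` is a hypothetical helper, not a tfmot function.

```python
def symmetric_step(max_abs, num_bits):
    """Width of one quantization step for a symmetric grid over [-max_abs, max_abs]."""
    qmax = 2 ** (num_bits - 1) - 1  # positive levels: 127 for 8 bits, 7 for 4 bits
    return max_abs / qmax


for bits in (8, 4):
    step = symmetric_step(1.0, bits)
    print(f"{bits}-bit: {2 ** bits} levels, step {step:.4f}, "
          f"worst-case rounding error {step / 2:.4f}")

ratio = symmetric_step(1.0, 4) / symmetric_step(1.0, 8)
print(f"4-bit steps are {ratio:.1f}x wider than 8-bit steps")
```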

quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer
quantize_annotate_model = tfmot.quantization.keras.quantize_annotate_model
quantize_scope = tfmot.quantization.keras.quantize_scope

class ModifiedDenseQuantizeConfig(DefaultDenseQuantizeConfig):
    # Configure weights to quantize with 4 bits instead of 8 bits.
    def get_weights_and_quantizers(self, layer):
      return [(layer.kernel, LastValueQuantizer(num_bits=4, symmetric=True, narrow_range=False, per_axis=False))]

Applying the configuration is the same across the "Experiment with quantization" use cases.

model = quantize_annotate_model(keras.Sequential([
   # Pass in modified `QuantizeConfig` to modify this Dense layer.
   quantize_annotate_layer(keras.layers.Dense(20, input_shape=(20,)), ModifiedDenseQuantizeConfig()),
   keras.layers.Flatten()
]))

# `quantize_apply` requires mentioning `ModifiedDenseQuantizeConfig` with `quantize_scope`:
with quantize_scope(
  {'ModifiedDenseQuantizeConfig': ModifiedDenseQuantizeConfig}):
  # Use `quantize_apply` to actually make the model quantization aware.
  quant_aware_model = tfmot.quantization.keras.quantize_apply(model)

quant_aware_model.summary()
Model: "sequential_9"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 quantize_layer_7 (Quantize  (None, 20)                3         
 Layer)                                                          
                                                                 
 quant_dense_9 (QuantizeWra  (None, 20)                425       
 pperV2)                                                         
                                                                 
 quant_flatten_10 (Quantize  (None, 20)                1         
 WrapperV2)                                                      
                                                                 
=================================================================
Total params: 429 (1.68 KB)
Trainable params: 420 (1.64 KB)
Non-trainable params: 9 (36.00 Byte)
_________________________________________________________________

Modify parts of layer to quantize

This example modifies the Dense layer to skip quantizing the activation. The rest of the model continues to use API defaults.

quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer
quantize_annotate_model = tfmot.quantization.keras.quantize_annotate_model
quantize_scope = tfmot.quantization.keras.quantize_scope

class ModifiedDenseQuantizeConfig(DefaultDenseQuantizeConfig):
    def get_activations_and_quantizers(self, layer):
      # Skip quantizing activations.
      return []

    def set_quantize_activations(self, layer, quantize_activations):
      # Empty since `get_activations_and_quantizers` returns
      # an empty list.
      return

Applying the configuration is the same across the "Experiment with quantization" use cases.

model = quantize_annotate_model(keras.Sequential([
   # Pass in modified `QuantizeConfig` to modify this Dense layer.
   quantize_annotate_layer(keras.layers.Dense(20, input_shape=(20,)), ModifiedDenseQuantizeConfig()),
   keras.layers.Flatten()
]))

# `quantize_apply` requires mentioning `ModifiedDenseQuantizeConfig` with `quantize_scope`:
with quantize_scope(
  {'ModifiedDenseQuantizeConfig': ModifiedDenseQuantizeConfig}):
  # Use `quantize_apply` to actually make the model quantization aware.
  quant_aware_model = tfmot.quantization.keras.quantize_apply(model)

quant_aware_model.summary()
Model: "sequential_10"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 quantize_layer_8 (Quantize  (None, 20)                3         
 Layer)                                                          
                                                                 
 quant_dense_10 (QuantizeWr  (None, 20)                423       
 apperV2)                                                        
                                                                 
 quant_flatten_11 (Quantize  (None, 20)                1         
 WrapperV2)                                                      
                                                                 
=================================================================
Total params: 427 (1.67 KB)
Trainable params: 420 (1.64 KB)
Non-trainable params: 7 (28.00 Byte)
_________________________________________________________________

Use custom quantization algorithm

The tfmot.quantization.keras.quantizers.Quantizer class is a callable that can apply any algorithm to its inputs.

In this example, the inputs are the weights, and the math in the FixedRangeQuantizer __call__ function is applied to them. Instead of the original weight values, the output of the FixedRangeQuantizer is now passed to whatever would have used the weights.
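FixedRangeQuantizer only clips, so its outputs are not restricted to a discrete grid. As a hypothetical variation that is not part of this guide's example, the pure-Python sketch below shows a fixed-range quantizer that also rounds onto a num_bits grid spanning [-1, 1]. Because the range never changes, it needs no state such as min/max variables, which mirrors why build returns an empty dict. The function name and default arguments are assumptions.

```python
def fixed_range_fake_quant(values, num_bits=8, lo=-1.0, hi=1.0):
    """Clip to a fixed [lo, hi] range, then snap to an evenly spaced grid.

    Because the range is fixed, no statistics need to be tracked between calls.
    """
    levels = 2 ** num_bits - 1  # number of steps between lo and hi
    step = (hi - lo) / levels
    out = []
    for v in values:
        clipped = min(hi, max(lo, v))     # the clipping FixedRangeQuantizer does
        k = round((clipped - lo) / step)  # index of the nearest grid point
        out.append(lo + k * step)
    return out


weights = [-3.0, -0.5, 0.1, 0.9, 2.0]
# With 2 bits the grid is {-1, -1/3, 1/3, 1}, which makes the snapping easy to see.
print(fixed_range_fake_quant(weights, num_bits=2))
```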

quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer
quantize_annotate_model = tfmot.quantization.keras.quantize_annotate_model
quantize_scope = tfmot.quantization.keras.quantize_scope

class FixedRangeQuantizer(tfmot.quantization.keras.quantizers.Quantizer):
  """Quantizer which forces outputs to be between -1 and 1."""

  def build(self, tensor_shape, name, layer):
    # Not needed. No new TensorFlow variables needed.
    return {}

  def __call__(self, inputs, training, weights, **kwargs):
    return keras.backend.clip(inputs, -1.0, 1.0)

  def get_config(self):
    # Not needed. No __init__ parameters to serialize.
    return {}


class ModifiedDenseQuantizeConfig(DefaultDenseQuantizeConfig):
    # Configure weights to use the custom `FixedRangeQuantizer`.
    def get_weights_and_quantizers(self, layer):
      # Use custom algorithm defined in `FixedRangeQuantizer` instead of default Quantizer.
      return [(layer.kernel, FixedRangeQuantizer())]

Applying the configuration is the same across the "Experiment with quantization" use cases.

model = quantize_annotate_model(keras.Sequential([
   # Pass in modified `QuantizeConfig` to modify this `Dense` layer.
   quantize_annotate_layer(keras.layers.Dense(20, input_shape=(20,)), ModifiedDenseQuantizeConfig()),
   keras.layers.Flatten()
]))

# `quantize_apply` requires mentioning `ModifiedDenseQuantizeConfig` with `quantize_scope`:
with quantize_scope(
  {'ModifiedDenseQuantizeConfig': ModifiedDenseQuantizeConfig}):
  # Use `quantize_apply` to actually make the model quantization aware.
  quant_aware_model = tfmot.quantization.keras.quantize_apply(model)

quant_aware_model.summary()
Model: "sequential_11"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 quantize_layer_9 (Quantize  (None, 20)                3         
 Layer)                                                          
                                                                 
 quant_dense_11 (QuantizeWr  (None, 20)                423       
 apperV2)                                                        
                                                                 
 quant_flatten_12 (Quantize  (None, 20)                1         
 WrapperV2)                                                      
                                                                 
=================================================================
Total params: 427 (1.67 KB)
Trainable params: 420 (1.64 KB)
Non-trainable params: 7 (28.00 Byte)
_________________________________________________________________