Umfassender Leitfaden für quantisierungsbewusstes Training

Auf TensorFlow.org ansehen In Google Colab ausführen Quelle auf GitHub anzeigen Notizbuch herunterladen

Willkommen zum umfassenden Leitfaden für Keras Quantisierungsbewusstes Training.

Diese Seite dokumentiert verschiedene Anwendungsfälle und zeigt, wie Sie die API für jeden verwenden. Sobald Sie wissen, welche APIs Sie benötigen, finden Sie die Parameter und die Details auf niedriger Ebene in den API-Dokumenten .

  • Weitere Informationen zu den Vorteilen von quantisierungsbewusstem Training und den unterstützten Inhalten finden Sie in der Übersicht .
  • Ein einzelnes End-to-End-Beispiel finden Sie im Beispiel für ein quantisierungsbewusstes Training .

Folgende Anwendungsfälle werden abgedeckt:

  • Stellen Sie mit diesen Schritten ein Modell mit 8-Bit-Quantisierung bereit.
    • Definieren Sie ein quantisierungsbewusstes Modell.
    • Verwenden Sie nur für Keras HDF5-Modelle eine spezielle Prüfpunkt- und Deserialisierungslogik. Training ist ansonsten Standard.
    • Erstellen Sie ein quantisiertes Modell aus dem quantisierungsbewussten Modell.
  • Experimentieren Sie mit der Quantisierung.
    • Alles für Experimente hat keinen unterstützten Pfad zur Bereitstellung.
    • Benutzerdefinierte Keras-Ebenen werden getestet.

Einrichten

Um die benötigten APIs zu finden und zu verstehen, können Sie diesen Abschnitt ausführen, aber das Lesen überspringen.

! pip uninstall -y tensorflow
! pip install -q tf-nightly
! pip install -q tensorflow-model-optimization

import tensorflow as tf
import numpy as np
import tensorflow_model_optimization as tfmot

import tempfile

input_shape = [20]
x_train = np.random.randn(1, 20).astype(np.float32)
y_train = tf.keras.utils.to_categorical(np.random.randn(1), num_classes=20)

def setup_model():
  model = tf.keras.Sequential([
      tf.keras.layers.Dense(20, input_shape=input_shape),
      tf.keras.layers.Flatten()
  ])
  return model

def setup_pretrained_weights():
  model= setup_model()

  model.compile(
      loss=tf.keras.losses.categorical_crossentropy,
      optimizer='adam',
      metrics=['accuracy']
  )

  model.fit(x_train, y_train)

  _, pretrained_weights = tempfile.mkstemp('.tf')

  model.save_weights(pretrained_weights)

  return pretrained_weights

def setup_pretrained_model():
  model = setup_model()
  pretrained_weights = setup_pretrained_weights()
  model.load_weights(pretrained_weights)
  return model

setup_model()
pretrained_weights = setup_pretrained_weights()

Quantisierungsbewusstes Modell definieren

Durch das Definieren von Modellen auf die folgende Weise stehen auf der Übersichtsseite aufgelistete verfügbare Pfade zur Bereitstellung in Back-Ends zur Verfügung. Standardmäßig wird 8-Bit-Quantisierung verwendet.

Gesamtes Modell quantisieren

Ihr Anwendungsfall:

  • Untergeordnete Modelle werden nicht unterstützt.

Tipps für eine bessere Modellgenauigkeit:

  • Versuchen Sie "Einige Ebenen quantisieren", um die Quantisierung der Ebenen zu überspringen, die die Genauigkeit am stärksten beeinträchtigen.
  • Es ist im Allgemeinen besser, eine Feinabstimmung mit einem quantisierungsbewussten Training vorzunehmen, als mit einem Training von Grund auf neu.

Um das gesamte Modell auf die Quantisierung aufmerksam zu machen, wenden Sie tfmot.quantization.keras.quantize_model auf das Modell an.

base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy

quant_aware_model = tfmot.quantization.keras.quantize_model(base_model)
quant_aware_model.summary()
Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
quantize_layer (QuantizeLaye (None, 20)                3         
_________________________________________________________________
quant_dense_2 (QuantizeWrapp (None, 20)                425       
_________________________________________________________________
quant_flatten_2 (QuantizeWra (None, 20)                1         
=================================================================
Total params: 429
Trainable params: 420
Non-trainable params: 9
_________________________________________________________________

Quantisieren Sie einige Ebenen

Die Quantisierung eines Modells kann sich negativ auf die Genauigkeit auswirken. Sie können Layer eines Modells selektiv quantisieren, um den Kompromiss zwischen Genauigkeit, Geschwindigkeit und Modellgröße zu untersuchen.

Ihr Anwendungsfall:

  • Um ein Backend bereitzustellen, das nur mit vollständig quantisierten Modellen gut funktioniert (zB EdgeTPU v1, die meisten DSPs), versuchen Sie "Quantize Whole Model".

Tipps für eine bessere Modellgenauigkeit:

  • Es ist im Allgemeinen besser, eine Feinabstimmung mit einem quantisierungsbewussten Training vorzunehmen, als mit einem Training von Grund auf neu.
  • Versuchen Sie, die späteren Ebenen anstelle der ersten Ebenen zu quantisieren.
  • Vermeiden Sie die Quantisierung kritischer Schichten (zB Aufmerksamkeitsmechanismus).

Quantisieren Sie im Beispiel unten nur die Dense Ebenen.

# Create a base model
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy

# Helper function uses `quantize_annotate_layer` to annotate that only the 
# Dense layers should be quantized.
def apply_quantization_to_dense(layer):
  if isinstance(layer, tf.keras.layers.Dense):
    return tfmot.quantization.keras.quantize_annotate_layer(layer)
  return layer

# Use `tf.keras.models.clone_model` to apply `apply_quantization_to_dense` 
# to the layers of the model.
annotated_model = tf.keras.models.clone_model(
    base_model,
    clone_function=apply_quantization_to_dense,
)

# Now that the Dense layers are annotated,
# `quantize_apply` actually makes the model quantization aware.
quant_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)
quant_aware_model.summary()
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.bias
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.
Model: "sequential_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
quantize_layer_1 (QuantizeLa (None, 20)                3         
_________________________________________________________________
quant_dense_3 (QuantizeWrapp (None, 20)                425       
_________________________________________________________________
flatten_3 (Flatten)          (None, 20)                0         
=================================================================
Total params: 428
Trainable params: 420
Non-trainable params: 8
_________________________________________________________________

Während in diesem Beispiel der Typ der Ebene verwendet wurde, um zu entscheiden, was zu quantisieren ist, besteht der einfachste Weg zum Quantisieren einer bestimmten Ebene darin, ihre name Eigenschaft clone_function und in der clone_function nach diesem Namen zu clone_function .

print(base_model.layers[0].name)
dense_3

Besser lesbar, aber möglicherweise niedrigere Modellgenauigkeit

Dies ist nicht mit der Feinabstimmung mit quantisierungsbewusstem Training vereinbar, weshalb es möglicherweise weniger genau ist als die obigen Beispiele.

Funktionsbeispiel

# Use `quantize_annotate_layer` to annotate that the `Dense` layer
# should be quantized.
i = tf.keras.Input(shape=(20,))
x = tfmot.quantization.keras.quantize_annotate_layer(tf.keras.layers.Dense(10))(i)
o = tf.keras.layers.Flatten()(x)
annotated_model = tf.keras.Model(inputs=i, outputs=o)

# Use `quantize_apply` to actually make the model quantization aware.
quant_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)

# For deployment purposes, the tool adds `QuantizeLayer` after `InputLayer` so that the
# quantized model can take in float inputs instead of only uint8.
quant_aware_model.summary()
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 20)]              0         
_________________________________________________________________
quantize_layer_2 (QuantizeLa (None, 20)                3         
_________________________________________________________________
quant_dense_4 (QuantizeWrapp (None, 10)                215       
_________________________________________________________________
flatten_4 (Flatten)          (None, 10)                0         
=================================================================
Total params: 218
Trainable params: 210
Non-trainable params: 8
_________________________________________________________________

Sequenzielles Beispiel

# Use `quantize_annotate_layer` to annotate that the `Dense` layer
# should be quantized.
annotated_model = tf.keras.Sequential([
  tfmot.quantization.keras.quantize_annotate_layer(tf.keras.layers.Dense(20, input_shape=input_shape)),
  tf.keras.layers.Flatten()
])

# Use `quantize_apply` to actually make the model quantization aware.
quant_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)

quant_aware_model.summary()
Model: "sequential_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
quantize_layer_3 (QuantizeLa (None, 20)                3         
_________________________________________________________________
quant_dense_5 (QuantizeWrapp (None, 20)                425       
_________________________________________________________________
flatten_5 (Flatten)          (None, 20)                0         
=================================================================
Total params: 428
Trainable params: 420
Non-trainable params: 8
_________________________________________________________________

Checkpoint und Deserialisierung

Ihr Anwendungsfall: Dieser Code wird nur für das HDF5-Modellformat benötigt (keine HDF5-Gewichte oder andere Formate).

# Define the model.
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
quant_aware_model = tfmot.quantization.keras.quantize_model(base_model)

# Save or checkpoint the model.
_, keras_model_file = tempfile.mkstemp('.h5')
quant_aware_model.save(keras_model_file)

# `quantize_scope` is needed for deserializing HDF5 models.
with tfmot.quantization.keras.quantize_scope():
  loaded_model = tf.keras.models.load_model(keras_model_file)

loaded_model.summary()
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.bias
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
WARNING:tensorflow:No training configuration found in the save file, so the model was *not* compiled. Compile it manually.
Model: "sequential_5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
quantize_layer_4 (QuantizeLa (None, 20)                3         
_________________________________________________________________
quant_dense_6 (QuantizeWrapp (None, 20)                425       
_________________________________________________________________
quant_flatten_6 (QuantizeWra (None, 20)                1         
=================================================================
Total params: 429
Trainable params: 420
Non-trainable params: 9
_________________________________________________________________

Quantisiertes Modell erstellen und bereitstellen

Verweisen Sie im Allgemeinen auf die Dokumentation für das Bereitstellungs-Back-End, das Sie verwenden werden.

Dies ist ein Beispiel für das TFLite-Backend.

base_model = setup_pretrained_model()
quant_aware_model = tfmot.quantization.keras.quantize_model(base_model)

# Typically you train the model here.

converter = tf.lite.TFLiteConverter.from_keras_model(quant_aware_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]

quantized_tflite_model = converter.convert()
1/1 [==============================] - 0s 229ms/step - loss: 16.1181 - accuracy: 0.0000e+00
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.bias
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
WARNING:absl:Found untraced functions such as dense_7_layer_call_and_return_conditional_losses, dense_7_layer_call_fn, flatten_7_layer_call_and_return_conditional_losses, flatten_7_layer_call_fn, dense_7_layer_call_fn while saving (showing 5 of 10). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/tmph82hcqfi/assets
INFO:tensorflow:Assets written to: /tmp/tmph82hcqfi/assets

Experimentieren Sie mit der Quantisierung

Ihr Anwendungsfall : Die Verwendung der folgenden APIs bedeutet, dass es keinen unterstützten Bereitstellungspfad gibt. Beispielsweise unterstützen die TFLite-Konvertierung und Kernel-Implementierungen nur die 8-Bit-Quantisierung. Die Funktionen sind ebenfalls experimentell und unterliegen nicht der Abwärtskompatibilität.

Setup: DefaultDenseQuantizeConfig

Das Experimentieren erfordert die Verwendung von tfmot.quantization.keras.QuantizeConfig , das beschreibt, wie die Gewichtungen, Aktivierungen und Ausgaben einer Schicht quantisiert werden.

Unten ist ein Beispiel, das dieselbe QuantizeConfig definiert, die für die Dense Ebene in den API-Standards verwendet wird.

Während der Vorwärtspropagation in diesem Beispiel wird der in LastValueQuantizer zurückgegebene get_weights_and_quantizers mit layer.kernel als Eingabe aufgerufen und erzeugt eine Ausgabe. Die Ausgabe ersetzt layer.kernel in der ursprünglichen Vorwärtsausbreitung der Dense Schicht über die in set_quantize_weights definierte set_quantize_weights . Die gleiche Idee gilt für die Aktivierungen und Ausgänge.

LastValueQuantizer = tfmot.quantization.keras.quantizers.LastValueQuantizer
MovingAverageQuantizer = tfmot.quantization.keras.quantizers.MovingAverageQuantizer

class DefaultDenseQuantizeConfig(tfmot.quantization.keras.QuantizeConfig):
    # Configure how to quantize weights.
    def get_weights_and_quantizers(self, layer):
      return [(layer.kernel, LastValueQuantizer(num_bits=8, symmetric=True, narrow_range=False, per_axis=False))]

    # Configure how to quantize activations.
    def get_activations_and_quantizers(self, layer):
      return [(layer.activation, MovingAverageQuantizer(num_bits=8, symmetric=False, narrow_range=False, per_axis=False))]

    def set_quantize_weights(self, layer, quantize_weights):
      # Add this line for each item returned in `get_weights_and_quantizers`
      # , in the same order
      layer.kernel = quantize_weights[0]

    def set_quantize_activations(self, layer, quantize_activations):
      # Add this line for each item returned in `get_activations_and_quantizers`
      # , in the same order.
      layer.activation = quantize_activations[0]

    # Configure how to quantize outputs (may be equivalent to activations).
    def get_output_quantizers(self, layer):
      return []

    def get_config(self):
      return {}

Benutzerdefinierte Keras-Ebene quantisieren

Dieses Beispiel verwendet die DefaultDenseQuantizeConfig die quantisiert CustomLayer .

Das Anwenden der Konfiguration ist in den Anwendungsfällen "Experiment mit Quantisierung" gleich.

quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer
quantize_annotate_model = tfmot.quantization.keras.quantize_annotate_model
quantize_scope = tfmot.quantization.keras.quantize_scope

class CustomLayer(tf.keras.layers.Dense):
  pass

model = quantize_annotate_model(tf.keras.Sequential([
   quantize_annotate_layer(CustomLayer(20, input_shape=(20,)), DefaultDenseQuantizeConfig()),
   tf.keras.layers.Flatten()
]))

# `quantize_apply` requires mentioning `DefaultDenseQuantizeConfig` with `quantize_scope`
# as well as the custom Keras layer.
with quantize_scope(
  {'DefaultDenseQuantizeConfig': DefaultDenseQuantizeConfig,
   'CustomLayer': CustomLayer}):
  # Use `quantize_apply` to actually make the model quantization aware.
  quant_aware_model = tfmot.quantization.keras.quantize_apply(model)

quant_aware_model.summary()
Model: "sequential_8"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
quantize_layer_6 (QuantizeLa (None, 20)                3         
_________________________________________________________________
quant_custom_layer (Quantize (None, 20)                425       
_________________________________________________________________
quant_flatten_9 (QuantizeWra (None, 20)                1         
=================================================================
Total params: 429
Trainable params: 420
Non-trainable params: 9
_________________________________________________________________

Quantisierungsparameter ändern

Häufiger Fehler: Die Quantisierung des Bias auf weniger als 32 Bit beeinträchtigt normalerweise die Modellgenauigkeit zu sehr.

In diesem Beispiel wird die Ebene Dense geändert, dass für ihre Gewichtungen 4 Bit anstelle der standardmäßigen 8 Bit verwendet werden. Der Rest des Modells verwendet weiterhin API-Standardwerte.

quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer
quantize_annotate_model = tfmot.quantization.keras.quantize_annotate_model
quantize_scope = tfmot.quantization.keras.quantize_scope

class ModifiedDenseQuantizeConfig(DefaultDenseQuantizeConfig):
    # Configure weights to quantize with 4-bit instead of 8-bits.
    def get_weights_and_quantizers(self, layer):
      return [(layer.kernel, LastValueQuantizer(num_bits=4, symmetric=True, narrow_range=False, per_axis=False))]

Das Anwenden der Konfiguration ist in den Anwendungsfällen "Experiment mit Quantisierung" gleich.

model = quantize_annotate_model(tf.keras.Sequential([
   # Pass in modified `QuantizeConfig` to modify this Dense layer.
   quantize_annotate_layer(tf.keras.layers.Dense(20, input_shape=(20,)), ModifiedDenseQuantizeConfig()),
   tf.keras.layers.Flatten()
]))

# `quantize_apply` requires mentioning `ModifiedDenseQuantizeConfig` with `quantize_scope`:
with quantize_scope(
  {'ModifiedDenseQuantizeConfig': ModifiedDenseQuantizeConfig}):
  # Use `quantize_apply` to actually make the model quantization aware.
  quant_aware_model = tfmot.quantization.keras.quantize_apply(model)

quant_aware_model.summary()
Model: "sequential_9"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
quantize_layer_7 (QuantizeLa (None, 20)                3         
_________________________________________________________________
quant_dense_9 (QuantizeWrapp (None, 20)                425       
_________________________________________________________________
quant_flatten_10 (QuantizeWr (None, 20)                1         
=================================================================
Total params: 429
Trainable params: 420
Non-trainable params: 9
_________________________________________________________________

Ändern Sie Teile der Ebene, um sie zu quantisieren

In diesem Beispiel wird die Ebene Dense modifiziert, um die Quantisierung der Aktivierung zu überspringen. Der Rest des Modells verwendet weiterhin API-Standardwerte.

quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer
quantize_annotate_model = tfmot.quantization.keras.quantize_annotate_model
quantize_scope = tfmot.quantization.keras.quantize_scope

class ModifiedDenseQuantizeConfig(DefaultDenseQuantizeConfig):
    def get_activations_and_quantizers(self, layer):
      # Skip quantizing activations.
      return []

    def set_quantize_activations(self, layer, quantize_activations):
      # Empty since `get_activaations_and_quantizers` returns
      # an empty list.
      return

Das Anwenden der Konfiguration ist in den Anwendungsfällen "Experiment mit Quantisierung" gleich.

model = quantize_annotate_model(tf.keras.Sequential([
   # Pass in modified `QuantizeConfig` to modify this Dense layer.
   quantize_annotate_layer(tf.keras.layers.Dense(20, input_shape=(20,)), ModifiedDenseQuantizeConfig()),
   tf.keras.layers.Flatten()
]))

# `quantize_apply` requires mentioning `ModifiedDenseQuantizeConfig` with `quantize_scope`:
with quantize_scope(
  {'ModifiedDenseQuantizeConfig': ModifiedDenseQuantizeConfig}):
  # Use `quantize_apply` to actually make the model quantization aware.
  quant_aware_model = tfmot.quantization.keras.quantize_apply(model)

quant_aware_model.summary()
Model: "sequential_10"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
quantize_layer_8 (QuantizeLa (None, 20)                3         
_________________________________________________________________
quant_dense_10 (QuantizeWrap (None, 20)                423       
_________________________________________________________________
quant_flatten_11 (QuantizeWr (None, 20)                1         
=================================================================
Total params: 427
Trainable params: 420
Non-trainable params: 7
_________________________________________________________________

Benutzerdefinierten Quantisierungsalgorithmus verwenden

Die Klasse tfmot.quantization.keras.quantizers.Quantizer ist ein Callable, das jeden beliebigen Algorithmus auf seine Eingaben anwenden kann.

In diesem Beispiel sind die Eingaben die Gewichte, und wir wenden die Mathematik in der Funktion FixedRangeQuantizer __call__ auf die Gewichte an. Anstelle der ursprünglichen Gewichtungswerte wird die Ausgabe des FixedRangeQuantizer jetzt an das weitergegeben, was die Gewichte verwendet hätte.

quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer
quantize_annotate_model = tfmot.quantization.keras.quantize_annotate_model
quantize_scope = tfmot.quantization.keras.quantize_scope

class FixedRangeQuantizer(tfmot.quantization.keras.quantizers.Quantizer):
  """Quantizer which forces outputs to be between -1 and 1."""

  def build(self, tensor_shape, name, layer):
    # Not needed. No new TensorFlow variables needed.
    return {}

  def __call__(self, inputs, training, weights, **kwargs):
    return tf.keras.backend.clip(inputs, -1.0, 1.0)

  def get_config(self):
    # Not needed. No __init__ parameters to serialize.
    return {}


class ModifiedDenseQuantizeConfig(DefaultDenseQuantizeConfig):
    # Configure weights to quantize with 4-bit instead of 8-bits.
    def get_weights_and_quantizers(self, layer):
      # Use custom algorithm defined in `FixedRangeQuantizer` instead of default Quantizer.
      return [(layer.kernel, FixedRangeQuantizer())]

Das Anwenden der Konfiguration ist in den Anwendungsfällen "Experiment mit Quantisierung" gleich.

model = quantize_annotate_model(tf.keras.Sequential([
   # Pass in modified `QuantizeConfig` to modify this `Dense` layer.
   quantize_annotate_layer(tf.keras.layers.Dense(20, input_shape=(20,)), ModifiedDenseQuantizeConfig()),
   tf.keras.layers.Flatten()
]))

# `quantize_apply` requires mentioning `ModifiedDenseQuantizeConfig` with `quantize_scope`:
with quantize_scope(
  {'ModifiedDenseQuantizeConfig': ModifiedDenseQuantizeConfig}):
  # Use `quantize_apply` to actually make the model quantization aware.
  quant_aware_model = tfmot.quantization.keras.quantize_apply(model)

quant_aware_model.summary()
Model: "sequential_11"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
quantize_layer_9 (QuantizeLa (None, 20)                3         
_________________________________________________________________
quant_dense_11 (QuantizeWrap (None, 20)                423       
_________________________________________________________________
quant_flatten_12 (QuantizeWr (None, 20)                1         
=================================================================
Total params: 427
Trainable params: 420
Non-trainable params: 7
_________________________________________________________________