양자화 인식 훈련 종합 가이드

TensorFlow.org에서 보기 Google Colab에서 실행하기 GitHub에서 소스 보기 노트북 다운로드하기

Keras 양자화 인식 훈련에 관한 종합 가이드를 시작합니다.

이 페이지는 다양한 사용 사례를 문서화하고 각각에 대해 API를 사용하는 방법을 보여줍니다. 필요한 API를 알고 나면, API 문서에서 매개변수와 하위 수준의 세부 정보를 찾아보세요.

  • 양자화 인식 훈련의 이점과 지원되는 기능을 보려면 개요를 참조하세요.
  • 단일 엔드 투 엔드 예제는 양자화 인식 훈련 예제를 참조하세요.

다음 사용 사례를 다룹니다.

  • 다음 단계에 따라 8bit 양자화로 모델을 배포합니다.
    • 양자화 인식 모델을 정의합니다.
    • Keras HDF5 모델의 경우에만 특수 체크포인트 및 역직렬화 로직을 사용합니다. 그렇지 않으면 훈련이 표준입니다.
    • 양자화 인식 모델에서 양자화 모델을 만듭니다.
  • 양자화로 실험합니다.
    • 실험용으로 지원되는 배포 경로가 없습니다.
    • 사용자 정의 Keras 레이어는 실험 중입니다.

설정

필요한 API를 찾고 목적을 이해하기 위해 실행할 수 있지만, 이 섹션은 건너뛸 수 있습니다.

! pip uninstall -y tensorflow
! pip install -q tf-nightly
! pip install -q tensorflow-model-optimization

import tensorflow as tf
import numpy as np
import tensorflow_model_optimization as tfmot

import tempfile

input_shape = [20]
x_train = np.random.randn(1, 20).astype(np.float32)
y_train = tf.keras.utils.to_categorical(np.random.randn(1), num_classes=20)

def setup_model():
  model = tf.keras.Sequential([
      tf.keras.layers.Dense(20, input_shape=input_shape),
      tf.keras.layers.Flatten()
  ])
  return model

def setup_pretrained_weights():
  model= setup_model()

  model.compile(
      loss=tf.keras.losses.categorical_crossentropy,
      optimizer='adam',
      metrics=['accuracy']
  )

  model.fit(x_train, y_train)

  _, pretrained_weights = tempfile.mkstemp('.tf')

  model.save_weights(pretrained_weights)

  return pretrained_weights

def setup_pretrained_model():
  model = setup_model()
  pretrained_weights = setup_pretrained_weights()
  model.load_weights(pretrained_weights)
  return model

setup_model()
pretrained_weights = setup_pretrained_weights()

양자화 인식 모델을 정의합니다.

다음과 같은 방법으로 모델을 정의하면 개요 페이지에 나열된 백엔드에 배포할 수 있는 경로가 있습니다. 기본적으로 8bit 양자화가 사용됩니다.

참고: 양자화 인식 모델은 실제로 양자화되지 않습니다. 양자화된 모델을 만드는 것은 별도의 단계입니다.

전체 모델 양자화하기

사용 사례:

  • 하위 클래스화된 모델은 지원되지 않습니다.

모델 정확성의 향상을 위한 팁:

  • 정확성을 가장 많이 떨어뜨리는 레이어 양자화를 건너뛰려면 "일부 레이어 양자화"를 시도하세요.
  • 일반적으로 처음부터 훈련하는 것보다 양자화 인식 훈련으로 미세 조정하는 것이 좋습니다.

전체 모델이 양자화를 인식하도록 하려면, tfmot.quantization.keras.quantize_model을 모델에 적용합니다.

base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy

quant_aware_model = tfmot.quantization.keras.quantize_model(base_model)
quant_aware_model.summary()
Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
quantize_layer (QuantizeLaye (None, 20)                3         
_________________________________________________________________
quant_dense_2 (QuantizeWrapp (None, 20)                425       
_________________________________________________________________
quant_flatten_2 (QuantizeWra (None, 20)                1         
=================================================================
Total params: 429
Trainable params: 420
Non-trainable params: 9
_________________________________________________________________

일부 레이어 양자화하기

모델을 양자화하면 정확성에 부정적인 영향을 미칠 수 있습니다. 모델의 레이어를 선택적으로 양자화하여 정확성, 속도 및 모델 크기 간의 균형을 탐색할 수 있습니다.

사용 사례:

  • 완전히 양자화된 모델(예: EdgeTPU v1, 대부분의 DSP)에서만 잘 동작하는 백엔드에 배포하려면 "전체 모델 양자화하기"를 시도하세요.

모델 정확성의 향상을 위한 팁:

  • 일반적으로 처음부터 훈련하는 것보다 양자화 인식 훈련으로 미세 조정하는 것이 좋습니다.
  • 첫 번째 레이어 대신 이후 레이어를 양자화해보세요.
  • 중요 레이어(예: attention 메커니즘)는 양자화하지 마세요.

아래 예에서는 Dense 레이어만 양자화합니다.

# Create a base model
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy

# Helper function uses `quantize_annotate_layer` to annotate that only the 
# Dense layers should be quantized.
def apply_quantization_to_dense(layer):
  if isinstance(layer, tf.keras.layers.Dense):
    return tfmot.quantization.keras.quantize_annotate_layer(layer)
  return layer

# Use `tf.keras.models.clone_model` to apply `apply_quantization_to_dense` 
# to the layers of the model.
annotated_model = tf.keras.models.clone_model(
    base_model,
    clone_function=apply_quantization_to_dense,
)

# Now that the Dense layers are annotated,
# `quantize_apply` actually makes the model quantization aware.
quant_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)
quant_aware_model.summary()
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.bias
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.
Model: "sequential_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
quantize_layer_1 (QuantizeLa (None, 20)                3         
_________________________________________________________________
quant_dense_3 (QuantizeWrapp (None, 20)                425       
_________________________________________________________________
flatten_3 (Flatten)          (None, 20)                0         
=================================================================
Total params: 428
Trainable params: 420
Non-trainable params: 8
_________________________________________________________________

이 예에서는 레이어 유형을 사용하여 양자화할 레이어를 결정했지만, 특정 레이어를 양자화하는 가장 쉬운 방법은 name 속성을 설정하고 clone_function에서 해당 이름을 찾는 것입니다.

print(base_model.layers[0].name)
dense_3

읽기 더 쉽지만 잠재적으로 모델 정확성이 낮음

양자화 인식 훈련을 통한 미세 조정과 호환되지 않으므로 위의 예보다 정확성이 떨어질 수 있습니다.

함수형 예

# Use `quantize_annotate_layer` to annotate that the `Dense` layer
# should be quantized.
i = tf.keras.Input(shape=(20,))
x = tfmot.quantization.keras.quantize_annotate_layer(tf.keras.layers.Dense(10))(i)
o = tf.keras.layers.Flatten()(x)
annotated_model = tf.keras.Model(inputs=i, outputs=o)

# Use `quantize_apply` to actually make the model quantization aware.
quant_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)

# For deployment purposes, the tool adds `QuantizeLayer` after `InputLayer` so that the
# quantized model can take in float inputs instead of only uint8.
quant_aware_model.summary()
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 20)]              0         
_________________________________________________________________
quantize_layer_2 (QuantizeLa (None, 20)                3         
_________________________________________________________________
quant_dense_4 (QuantizeWrapp (None, 10)                215       
_________________________________________________________________
flatten_4 (Flatten)          (None, 10)                0         
=================================================================
Total params: 218
Trainable params: 210
Non-trainable params: 8
_________________________________________________________________

순차 예

# Use `quantize_annotate_layer` to annotate that the `Dense` layer
# should be quantized.
annotated_model = tf.keras.Sequential([
  tfmot.quantization.keras.quantize_annotate_layer(tf.keras.layers.Dense(20, input_shape=input_shape)),
  tf.keras.layers.Flatten()
])

# Use `quantize_apply` to actually make the model quantization aware.
quant_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)

quant_aware_model.summary()
Model: "sequential_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
quantize_layer_3 (QuantizeLa (None, 20)                3         
_________________________________________________________________
quant_dense_5 (QuantizeWrapp (None, 20)                425       
_________________________________________________________________
flatten_5 (Flatten)          (None, 20)                0         
=================================================================
Total params: 428
Trainable params: 420
Non-trainable params: 8
_________________________________________________________________

체크포인트 및 역직렬화

사용 사례: 이 코드는 HDF5 모델 형식(HDF5 가중치 또는 기타 형식이 아님)에만 필요합니다.

# Define the model.
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
quant_aware_model = tfmot.quantization.keras.quantize_model(base_model)

# Save or checkpoint the model.
_, keras_model_file = tempfile.mkstemp('.h5')
quant_aware_model.save(keras_model_file)

# `quantize_scope` is needed for deserializing HDF5 models.
with tfmot.quantization.keras.quantize_scope():
  loaded_model = tf.keras.models.load_model(keras_model_file)

loaded_model.summary()
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.bias
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.
WARNING:tensorflow:No training configuration found in the save file, so the model was *not* compiled. Compile it manually.
Model: "sequential_5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
quantize_layer_4 (QuantizeLa (None, 20)                3         
_________________________________________________________________
quant_dense_6 (QuantizeWrapp (None, 20)                425       
_________________________________________________________________
quant_flatten_6 (QuantizeWra (None, 20)                1         
=================================================================
Total params: 429
Trainable params: 420
Non-trainable params: 9
_________________________________________________________________

양자화된 모델 생성 및 배포하기

일반적으로, 사용할 배포 백엔드에 대한 설명서를 참조하세요.

다음은 TFLite 백엔드의 예입니다.

base_model = setup_pretrained_model()
quant_aware_model = tfmot.quantization.keras.quantize_model(base_model)

# Typically you train the model here.

converter = tf.lite.TFLiteConverter.from_keras_model(quant_aware_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]

quantized_tflite_model = converter.convert()
1/1 [==============================] - 0s 281ms/step - loss: 1.8586 - accuracy: 0.0000e+00
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.bias
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.
WARNING:absl:Found untraced functions such as dense_7_layer_call_and_return_conditional_losses, dense_7_layer_call_fn, flatten_7_layer_call_and_return_conditional_losses, flatten_7_layer_call_fn, dense_7_layer_call_fn while saving (showing 5 of 10). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/tmpqfedw54a/assets
INFO:tensorflow:Assets written to: /tmp/tmpqfedw54a/assets

양자화 실험하기

사용 사례: 다음 API를 사용하면 지원되는 배포 경로가 없습니다. 이들 기능은 실험적이며 이전 버전과의 호환성이 적용되지 않습니다.

  • tfmot.quantization.keras.QuantizeConfig
  • tfmot.quantization.keras.quantizers.Quantizer
  • tfmot.quantization.keras.quantizers.LastValueQuantizer
  • tfmot.quantization.keras.quantizers.MovingAverageQuantizer

설정: DefaultDenseQuantizeConfig

실험하려면 레이어의 가중치, 활성화 및 출력을 양자화하는 방법을 설명하는 tfmot.quantization.keras.QuantizeConfig를 사용해야 합니다.

아래는 API 기본값에서 Dense 레이어에 사용되는 같은 QuantizeConfig를 정의하는 예입니다.

이 예제에서 순방향 전파 중에 get_weights_and_quantizers에서 반환된 LastValueQuantizer는 입력으로 layer.kernel을 사용하여 호출되어 출력이 생성됩니다. 출력은 set_quantize_weights에 정의된 로직을 통해 Dense 레이어의 원래 순방향 전파에서 layer.kernel을 대체합니다. 같은 아이디어가 활성화 및 출력에 적용됩니다.

LastValueQuantizer = tfmot.quantization.keras.quantizers.LastValueQuantizer
MovingAverageQuantizer = tfmot.quantization.keras.quantizers.MovingAverageQuantizer

class DefaultDenseQuantizeConfig(tfmot.quantization.keras.QuantizeConfig):
    # Configure how to quantize weights.
    def get_weights_and_quantizers(self, layer):
      return [(layer.kernel, LastValueQuantizer(num_bits=8, symmetric=True, narrow_range=False, per_axis=False))]

    # Configure how to quantize activations.
    def get_activations_and_quantizers(self, layer):
      return [(layer.activation, MovingAverageQuantizer(num_bits=8, symmetric=False, narrow_range=False, per_axis=False))]

    def set_quantize_weights(self, layer, quantize_weights):
      # Add this line for each item returned in `get_weights_and_quantizers`
      # , in the same order
      layer.kernel = quantize_weights[0]

    def set_quantize_activations(self, layer, quantize_activations):
      # Add this line for each item returned in `get_activations_and_quantizers`
      # , in the same order.
      layer.activation = quantize_activations[0]

    # Configure how to quantize outputs (may be equivalent to activations).
    def get_output_quantizers(self, layer):
      return []

    def get_config(self):
      return {}

사용자 정의 Keras 레이어 양자화하기

이 예제에서는 DefaultDenseQuantizeConfig를 사용하여 CustomLayer를 양자화합니다.

구성 적용은 "양자화 실험하기" 사용 사례에서와 같습니다.

  • CustomLayertfmot.quantization.keras.quantize_annotate_layer를 적용하고 QuantizeConfig를 전달합니다.
  • tfmot.quantization.keras.quantize_annotate_model을 사용하여 API 기본값으로 나머지 모델을 계속 양자화합니다.
quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer
quantize_annotate_model = tfmot.quantization.keras.quantize_annotate_model
quantize_scope = tfmot.quantization.keras.quantize_scope

class CustomLayer(tf.keras.layers.Dense):
  pass

model = quantize_annotate_model(tf.keras.Sequential([
   quantize_annotate_layer(CustomLayer(20, input_shape=(20,)), DefaultDenseQuantizeConfig()),
   tf.keras.layers.Flatten()
]))

# `quantize_apply` requires mentioning `DefaultDenseQuantizeConfig` with `quantize_scope`
# as well as the custom Keras layer.
with quantize_scope(
  {'DefaultDenseQuantizeConfig': DefaultDenseQuantizeConfig,
   'CustomLayer': CustomLayer}):
  # Use `quantize_apply` to actually make the model quantization aware.
  quant_aware_model = tfmot.quantization.keras.quantize_apply(model)

quant_aware_model.summary()
Model: "sequential_8"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
quantize_layer_6 (QuantizeLa (None, 20)                3         
_________________________________________________________________
quant_custom_layer (Quantize (None, 20)                425       
_________________________________________________________________
quant_flatten_9 (QuantizeWra (None, 20)                1         
=================================================================
Total params: 429
Trainable params: 420
Non-trainable params: 9
_________________________________________________________________

양자화 매개변수 수정하기

일반적인 실수: 바이어스를 32bit 미만으로 양자화하면 일반적으로 모델 정확성이 너무 많이 손상됩니다.

이 예제에서는 기본 8bit 대신 가중치에 4bit를 사용하도록 Dense 레이어를 수정합니다. 나머지 모델은 계속해서 API 기본값을 사용합니다.

quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer
quantize_annotate_model = tfmot.quantization.keras.quantize_annotate_model
quantize_scope = tfmot.quantization.keras.quantize_scope

class ModifiedDenseQuantizeConfig(DefaultDenseQuantizeConfig):
    # Configure weights to quantize with 4-bit instead of 8-bits.
    def get_weights_and_quantizers(self, layer):
      return [(layer.kernel, LastValueQuantizer(num_bits=4, symmetric=True, narrow_range=False, per_axis=False))]

구성 적용은 "양자화 실험하기" 사용 사례에서와 같습니다.

  • Apply tfmot.quantization.keras.quantize_annotate_layer to the Dense layer and pass in the QuantizeConfig.
  • tfmot.quantization.keras.quantize_annotate_model을 사용하여 API 기본값으로 나머지 모델을 계속 양자화합니다.
model = quantize_annotate_model(tf.keras.Sequential([
   # Pass in modified `QuantizeConfig` to modify this Dense layer.
   quantize_annotate_layer(tf.keras.layers.Dense(20, input_shape=(20,)), ModifiedDenseQuantizeConfig()),
   tf.keras.layers.Flatten()
]))

# `quantize_apply` requires mentioning `ModifiedDenseQuantizeConfig` with `quantize_scope`:
with quantize_scope(
  {'ModifiedDenseQuantizeConfig': ModifiedDenseQuantizeConfig}):
  # Use `quantize_apply` to actually make the model quantization aware.
  quant_aware_model = tfmot.quantization.keras.quantize_apply(model)

quant_aware_model.summary()
Model: "sequential_9"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
quantize_layer_7 (QuantizeLa (None, 20)                3         
_________________________________________________________________
quant_dense_9 (QuantizeWrapp (None, 20)                425       
_________________________________________________________________
quant_flatten_10 (QuantizeWr (None, 20)                1         
=================================================================
Total params: 429
Trainable params: 420
Non-trainable params: 9
_________________________________________________________________

양자화할 일부 레이어 수정하기

이 예제에서는 활성화 양자화를 건너뛰도록 Dense 레이어를 수정합니다. 나머지 모델은 계속해서 API 기본값을 사용합니다.

quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer
quantize_annotate_model = tfmot.quantization.keras.quantize_annotate_model
quantize_scope = tfmot.quantization.keras.quantize_scope

class ModifiedDenseQuantizeConfig(DefaultDenseQuantizeConfig):
    def get_activations_and_quantizers(self, layer):
      # Skip quantizing activations.
      return []

    def set_quantize_activations(self, layer, quantize_activations):
      # Empty since `get_activaations_and_quantizers` returns
      # an empty list.
      return

구성 적용은 "양자화 실험하기" 사용 사례에서와 같습니다.

  • Dense 레이어에 tfmot.quantization.keras.quantize_annotate_layer를 적용하고 QuantizeConfig를 전달합니다.
  • tfmot.quantization.keras.quantize_annotate_model을 사용하여 API 기본값으로 나머지 모델을 계속 양자화합니다.
model = quantize_annotate_model(tf.keras.Sequential([
   # Pass in modified `QuantizeConfig` to modify this Dense layer.
   quantize_annotate_layer(tf.keras.layers.Dense(20, input_shape=(20,)), ModifiedDenseQuantizeConfig()),
   tf.keras.layers.Flatten()
]))

# `quantize_apply` requires mentioning `ModifiedDenseQuantizeConfig` with `quantize_scope`:
with quantize_scope(
  {'ModifiedDenseQuantizeConfig': ModifiedDenseQuantizeConfig}):
  # Use `quantize_apply` to actually make the model quantization aware.
  quant_aware_model = tfmot.quantization.keras.quantize_apply(model)

quant_aware_model.summary()
Model: "sequential_10"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
quantize_layer_8 (QuantizeLa (None, 20)                3         
_________________________________________________________________
quant_dense_10 (QuantizeWrap (None, 20)                423       
_________________________________________________________________
quant_flatten_11 (QuantizeWr (None, 20)                1         
=================================================================
Total params: 427
Trainable params: 420
Non-trainable params: 7
_________________________________________________________________

사용자 정의 양자화 알고리즘 사용하기

tfmot.quantization.keras.quantizers.Quantizer 클래스는 입력에 모든 알고리즘을 적용할 수 있는 callable입니다.

이 예에서 입력은 가중치이며 FixedRangeQuantizer call 함수의 수학을 가중치에 적용합니다. 원래 가중치 값 대신 FixedRangeQuantizer의 출력이 이제 가중치를 사용하는 모든 항목으로 전달됩니다.

quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer
quantize_annotate_model = tfmot.quantization.keras.quantize_annotate_model
quantize_scope = tfmot.quantization.keras.quantize_scope

class FixedRangeQuantizer(tfmot.quantization.keras.quantizers.Quantizer):
  """Quantizer which forces outputs to be between -1 and 1."""

  def build(self, tensor_shape, name, layer):
    # Not needed. No new TensorFlow variables needed.
    return {}

  def __call__(self, inputs, training, weights, **kwargs):
    return tf.keras.backend.clip(inputs, -1.0, 1.0)

  def get_config(self):
    # Not needed. No __init__ parameters to serialize.
    return {}


class ModifiedDenseQuantizeConfig(DefaultDenseQuantizeConfig):
    # Configure weights to quantize with 4-bit instead of 8-bits.
    def get_weights_and_quantizers(self, layer):
      # Use custom algorithm defined in `FixedRangeQuantizer` instead of default Quantizer.
      return [(layer.kernel, FixedRangeQuantizer())]

Applying the configuration is the same across the "Experiment with quantization" use cases.

  • Dense 레이어에 tfmot.quantization.keras.quantize_annotate_layer를 적용하고 QuantizeConfig를 전달합니다.
  • Use tfmot.quantization.keras.quantize_annotate_model to continue to quantize the rest of the model with the API defaults.
model = quantize_annotate_model(tf.keras.Sequential([
   # Pass in modified `QuantizeConfig` to modify this `Dense` layer.
   quantize_annotate_layer(tf.keras.layers.Dense(20, input_shape=(20,)), ModifiedDenseQuantizeConfig()),
   tf.keras.layers.Flatten()
]))

# `quantize_apply` requires mentioning `ModifiedDenseQuantizeConfig` with `quantize_scope`:
with quantize_scope(
  {'ModifiedDenseQuantizeConfig': ModifiedDenseQuantizeConfig}):
  # Use `quantize_apply` to actually make the model quantization aware.
  quant_aware_model = tfmot.quantization.keras.quantize_apply(model)

quant_aware_model.summary()
Model: "sequential_11"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
quantize_layer_9 (QuantizeLa (None, 20)                3         
_________________________________________________________________
quant_dense_11 (QuantizeWrap (None, 20)                423       
_________________________________________________________________
quant_flatten_12 (QuantizeWr (None, 20)                1         
=================================================================
Total params: 427
Trainable params: 420
Non-trainable params: 7
_________________________________________________________________