![]() | ![]() | ![]() | ![]() |
Keras量子化対応トレーニングの包括的なガイドへようこそ。
このページでは、さまざまなユースケースについて説明し、それぞれのAPIの使用方法を示します。必要なAPIがわかったら、 APIドキュメントでパラメータと低レベルの詳細を見つけます。
- 量子化対応トレーニングの利点とサポートされているものを確認したい場合は、概要を参照してください。
- 単一のエンドツーエンドの例については、量子化対応のトレーニング例を参照してください。
次のユースケースについて説明します。
- これらの手順で、8ビット量子化を使用してモデルをデプロイします。
- 量子化対応モデルを定義します。
- Keras HDF5モデルの場合のみ、特別なチェックポイントと逆シリアル化ロジックを使用してください。それ以外の場合、トレーニングは標準です。
- 量子化対応モデルから量子化モデルを作成します。
- 量子化を試してください。
- 実験用のものには、展開へのサポートされたパスがありません。
- カスタムKerasレイヤーは実験に含まれます。
セットアップ
必要なAPIを見つけて目的を理解するために、実行できますが、このセクションを読むことはスキップしてください。
! pip uninstall -y tensorflow
! pip install -q tf-nightly
! pip install -q tensorflow-model-optimization
import tensorflow as tf
import numpy as np
import tensorflow_model_optimization as tfmot
import tempfile
input_shape = [20]
x_train = np.random.randn(1, 20).astype(np.float32)
y_train = tf.keras.utils.to_categorical(np.random.randn(1), num_classes=20)
def setup_model():
model = tf.keras.Sequential([
tf.keras.layers.Dense(20, input_shape=input_shape),
tf.keras.layers.Flatten()
])
return model
def setup_pretrained_weights():
model= setup_model()
model.compile(
loss=tf.keras.losses.categorical_crossentropy,
optimizer='adam',
metrics=['accuracy']
)
model.fit(x_train, y_train)
_, pretrained_weights = tempfile.mkstemp('.tf')
model.save_weights(pretrained_weights)
return pretrained_weights
def setup_pretrained_model():
model = setup_model()
pretrained_weights = setup_pretrained_weights()
model.load_weights(pretrained_weights)
return model
setup_model()
pretrained_weights = setup_pretrained_weights()
量子化対応モデルを定義する
次の方法でモデルを定義することにより、概要ページにリストされているバックエンドへのデプロイメントへの利用可能なパスがあります。デフォルトでは、8ビットの量子化が使用されます。
モデル全体を定量化する
あなたのユースケース:
- サブクラス化されたモデルはサポートされていません。
モデルの精度を高めるためのヒント:
- 「いくつかのレイヤーを量子化する」を試して、精度を最も低下させるレイヤーの量子化をスキップしてください。
- 一般に、最初からトレーニングするのではなく、量子化を意識したトレーニングで微調整することをお勧めします。
モデル全体に量子化を認識させるには、 tfmot.quantization.keras.quantize_model
をモデルに適用します。
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
quant_aware_model = tfmot.quantization.keras.quantize_model(base_model)
quant_aware_model.summary()
Model: "sequential_2" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= quantize_layer (QuantizeLaye (None, 20) 3 _________________________________________________________________ quant_dense_2 (QuantizeWrapp (None, 20) 425 _________________________________________________________________ quant_flatten_2 (QuantizeWra (None, 20) 1 ================================================================= Total params: 429 Trainable params: 420 Non-trainable params: 9 _________________________________________________________________
いくつかのレイヤーを定量化する
モデルの量子化は、精度に悪影響を与える可能性があります。モデルのレイヤーを選択的に量子化して、精度、速度、モデルサイズの間のトレードオフを調べることができます。
あなたのユースケース:
- 完全に量子化されたモデル(EdgeTPU v1、ほとんどのDSPなど)でのみ適切に機能するバックエンドにデプロイするには、「モデル全体の量子化」を試してください。
モデルの精度を高めるためのヒント:
- 一般に、最初からトレーニングするのではなく、量子化を意識したトレーニングで微調整することをお勧めします。
- 最初のレイヤーではなく、後のレイヤーを量子化してみてください。
- 重要なレイヤー(注意メカニズムなど)の量子化は避けてください。
以下の例では、 Dense
レイヤーのみを量子化します。
# Create a base model
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
# Helper function uses `quantize_annotate_layer` to annotate that only the
# Dense layers should be quantized.
def apply_quantization_to_dense(layer):
if isinstance(layer, tf.keras.layers.Dense):
return tfmot.quantization.keras.quantize_annotate_layer(layer)
return layer
# Use `tf.keras.models.clone_model` to apply `apply_quantization_to_dense`
# to the layers of the model.
annotated_model = tf.keras.models.clone_model(
base_model,
clone_function=apply_quantization_to_dense,
)
# Now that the Dense layers are annotated,
# `quantize_apply` actually makes the model quantization aware.
quant_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)
quant_aware_model.summary()
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1 WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2 WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.kernel WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.bias WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.kernel WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.bias WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details. Model: "sequential_3" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= quantize_layer_1 (QuantizeLa (None, 20) 3 _________________________________________________________________ quant_dense_3 (QuantizeWrapp (None, 20) 425 _________________________________________________________________ flatten_3 (Flatten) (None, 20) 0 ================================================================= Total params: 428 Trainable params: 420 Non-trainable params: 8 _________________________________________________________________
この例では、レイヤーのタイプを使用して量子化する対象を決定しましたが、特定のレイヤーを量子化する最も簡単な方法は、そのname
プロパティを設定し、 clone_function
でその名前を探すことclone_function
。
print(base_model.layers[0].name)
dense_3
読みやすくなりますが、モデルの精度が低下する可能性があります
これは、量子化対応トレーニングによる微調整と互換性がないため、上記の例よりも精度が低くなる可能性があります。
機能例
# Use `quantize_annotate_layer` to annotate that the `Dense` layer
# should be quantized.
i = tf.keras.Input(shape=(20,))
x = tfmot.quantization.keras.quantize_annotate_layer(tf.keras.layers.Dense(10))(i)
o = tf.keras.layers.Flatten()(x)
annotated_model = tf.keras.Model(inputs=i, outputs=o)
# Use `quantize_apply` to actually make the model quantization aware.
quant_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)
# For deployment purposes, the tool adds `QuantizeLayer` after `InputLayer` so that the
# quantized model can take in float inputs instead of only uint8.
quant_aware_model.summary()
Model: "model" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input_1 (InputLayer) [(None, 20)] 0 _________________________________________________________________ quantize_layer_2 (QuantizeLa (None, 20) 3 _________________________________________________________________ quant_dense_4 (QuantizeWrapp (None, 10) 215 _________________________________________________________________ flatten_4 (Flatten) (None, 10) 0 ================================================================= Total params: 218 Trainable params: 210 Non-trainable params: 8 _________________________________________________________________
順次例
# Use `quantize_annotate_layer` to annotate that the `Dense` layer
# should be quantized.
annotated_model = tf.keras.Sequential([
tfmot.quantization.keras.quantize_annotate_layer(tf.keras.layers.Dense(20, input_shape=input_shape)),
tf.keras.layers.Flatten()
])
# Use `quantize_apply` to actually make the model quantization aware.
quant_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)
quant_aware_model.summary()
Model: "sequential_4" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= quantize_layer_3 (QuantizeLa (None, 20) 3 _________________________________________________________________ quant_dense_5 (QuantizeWrapp (None, 20) 425 _________________________________________________________________ flatten_5 (Flatten) (None, 20) 0 ================================================================= Total params: 428 Trainable params: 420 Non-trainable params: 8 _________________________________________________________________
チェックポイントと逆シリアル化
ユースケース:このコードは、HDF5モデル形式にのみ必要です(HDF5ウェイトやその他の形式には必要ありません)。
# Define the model.
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
quant_aware_model = tfmot.quantization.keras.quantize_model(base_model)
# Save or checkpoint the model.
_, keras_model_file = tempfile.mkstemp('.h5')
quant_aware_model.save(keras_model_file)
# `quantize_scope` is needed for deserializing HDF5 models.
with tfmot.quantization.keras.quantize_scope():
loaded_model = tf.keras.models.load_model(keras_model_file)
loaded_model.summary()
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1 WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2 WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.kernel WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.bias WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.kernel WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.bias WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details. WARNING:tensorflow:No training configuration found in the save file, so the model was *not* compiled. Compile it manually. Model: "sequential_5" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= quantize_layer_4 (QuantizeLa (None, 20) 3 _________________________________________________________________ quant_dense_6 (QuantizeWrapp (None, 20) 425 _________________________________________________________________ quant_flatten_6 (QuantizeWra (None, 20) 1 ================================================================= Total params: 429 Trainable params: 420 Non-trainable params: 9 _________________________________________________________________
量子化モデルを作成して展開する
一般に、使用するデプロイメントバックエンドのドキュメントを参照してください。
これはTFLiteバックエンドの例です。
base_model = setup_pretrained_model()
quant_aware_model = tfmot.quantization.keras.quantize_model(base_model)
# Typically you train the model here.
converter = tf.lite.TFLiteConverter.from_keras_model(quant_aware_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
quantized_tflite_model = converter.convert()
1/1 [==============================] - 0s 267ms/step - loss: 1.3382 - accuracy: 0.0000e+00 WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1 WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2 WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.kernel WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.bias WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.kernel WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.bias WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:109: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version. Instructions for updating: This property should not be used in TensorFlow 2.0, as updates are applied automatically. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:109: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version. Instructions for updating: This property should not be used in TensorFlow 2.0, as updates are applied automatically. INFO:tensorflow:Assets written to: /tmp/tmpx2wcgxt9/assets
量子化を試す
ユースケース:次のAPIを使用するということは、デプロイへのサポートされているパスがないことを意味します。これらの機能も実験的なものであり、下位互換性の影響を受けません。
-
tfmot.quantization.keras.QuantizeConfig
-
tfmot.quantization.keras.quantizers.Quantizer
-
tfmot.quantization.keras.quantizers.LastValueQuantizer
-
tfmot.quantization.keras.quantizers.MovingAverageQuantizer
セットアップ:DefaultDenseQuantizeConfig
実験するには、 tfmot.quantization.keras.QuantizeConfig
を使用する必要があります。これは、レイヤーの重み、アクティブ化、および出力を量子化する方法を説明しています。
以下は、APIのデフォルトでDense
レイヤーに使用されるのと同じQuantizeConfig
を定義する例です。
この例の順伝播中に、 LastValueQuantizer
返されるget_weights_and_quantizers
が、 layer.kernel
を入力としてlayer.kernel
れ、出力が生成されます。出力は、 layer.kernel
で定義されたロジックを介して、 Dense
レイヤーの元の順方向伝播のset_quantize_weights
ます。同じ考え方がアクティベーションと出力にも当てはまります。
LastValueQuantizer = tfmot.quantization.keras.quantizers.LastValueQuantizer
MovingAverageQuantizer = tfmot.quantization.keras.quantizers.MovingAverageQuantizer
class DefaultDenseQuantizeConfig(tfmot.quantization.keras.QuantizeConfig):
# Configure how to quantize weights.
def get_weights_and_quantizers(self, layer):
return [(layer.kernel, LastValueQuantizer(num_bits=8, symmetric=True, narrow_range=False, per_axis=False))]
# Configure how to quantize activations.
def get_activations_and_quantizers(self, layer):
return [(layer.activation, MovingAverageQuantizer(num_bits=8, symmetric=False, narrow_range=False, per_axis=False))]
def set_quantize_weights(self, layer, quantize_weights):
# Add this line for each item returned in `get_weights_and_quantizers`
# , in the same order
layer.kernel = quantize_weights[0]
def set_quantize_activations(self, layer, quantize_activations):
# Add this line for each item returned in `get_activations_and_quantizers`
# , in the same order.
layer.activation = quantize_activations[0]
# Configure how to quantize outputs (may be equivalent to activations).
def get_output_quantizers(self, layer):
return []
def get_config(self):
return {}
カスタムKerasレイヤーを定量化する
この例では、 DefaultDenseQuantizeConfig
を使用してCustomLayer
をクオンタイズしCustomLayer
。
構成の適用は、「量子化を使用した実験」のユースケース全体で同じです。
- 適用
tfmot.quantization.keras.quantize_annotate_layer
にCustomLayer
と渡しQuantizeConfig
。 -
tfmot.quantization.keras.quantize_annotate_model
を使用して、APIのデフォルトでモデルの残りの部分を引き続き量子化します。
quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer
quantize_annotate_model = tfmot.quantization.keras.quantize_annotate_model
quantize_scope = tfmot.quantization.keras.quantize_scope
class CustomLayer(tf.keras.layers.Dense):
pass
model = quantize_annotate_model(tf.keras.Sequential([
quantize_annotate_layer(CustomLayer(20, input_shape=(20,)), DefaultDenseQuantizeConfig()),
tf.keras.layers.Flatten()
]))
# `quantize_apply` requires mentioning `DefaultDenseQuantizeConfig` with `quantize_scope`
# as well as the custom Keras layer.
with quantize_scope(
{'DefaultDenseQuantizeConfig': DefaultDenseQuantizeConfig,
'CustomLayer': CustomLayer}):
# Use `quantize_apply` to actually make the model quantization aware.
quant_aware_model = tfmot.quantization.keras.quantize_apply(model)
quant_aware_model.summary()
Model: "sequential_8" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= quantize_layer_6 (QuantizeLa (None, 20) 3 _________________________________________________________________ quant_custom_layer (Quantize (None, 20) 425 _________________________________________________________________ quant_flatten_9 (QuantizeWra (None, 20) 1 ================================================================= Total params: 429 Trainable params: 420 Non-trainable params: 9 _________________________________________________________________
量子化パラメータを変更する
よくある間違い:バイアスを32ビット未満に量子化すると、通常、モデルの精度が大幅に低下します。
この例では、デフォルトの8ビットではなく4ビットを重みに使用するようにDense
レイヤーを変更します。モデルの残りの部分は、引き続きAPIのデフォルトを使用します。
quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer
quantize_annotate_model = tfmot.quantization.keras.quantize_annotate_model
quantize_scope = tfmot.quantization.keras.quantize_scope
class ModifiedDenseQuantizeConfig(DefaultDenseQuantizeConfig):
# Configure weights to quantize with 4-bit instead of 8-bits.
def get_weights_and_quantizers(self, layer):
return [(layer.kernel, LastValueQuantizer(num_bits=4, symmetric=True, narrow_range=False, per_axis=False))]
構成の適用は、「量子化を使用した実験」のユースケース全体で同じです。
-
tfmot.quantization.keras.quantize_annotate_layer
をDense
レイヤーに適用し、QuantizeConfig
を渡します。 -
tfmot.quantization.keras.quantize_annotate_model
を使用して、APIのデフォルトでモデルの残りの部分を引き続き量子化します。
model = quantize_annotate_model(tf.keras.Sequential([
# Pass in modified `QuantizeConfig` to modify this Dense layer.
quantize_annotate_layer(tf.keras.layers.Dense(20, input_shape=(20,)), ModifiedDenseQuantizeConfig()),
tf.keras.layers.Flatten()
]))
# `quantize_apply` requires mentioning `ModifiedDenseQuantizeConfig` with `quantize_scope`:
with quantize_scope(
{'ModifiedDenseQuantizeConfig': ModifiedDenseQuantizeConfig}):
# Use `quantize_apply` to actually make the model quantization aware.
quant_aware_model = tfmot.quantization.keras.quantize_apply(model)
quant_aware_model.summary()
Model: "sequential_9" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= quantize_layer_7 (QuantizeLa (None, 20) 3 _________________________________________________________________ quant_dense_9 (QuantizeWrapp (None, 20) 425 _________________________________________________________________ quant_flatten_10 (QuantizeWr (None, 20) 1 ================================================================= Total params: 429 Trainable params: 420 Non-trainable params: 9 _________________________________________________________________
レイヤーの一部を変更してクオンタイズします
この例では、 Dense
レイヤーを変更して、アクティベーションの量子化をスキップします。モデルの残りの部分は、引き続きAPIのデフォルトを使用します。
quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer
quantize_annotate_model = tfmot.quantization.keras.quantize_annotate_model
quantize_scope = tfmot.quantization.keras.quantize_scope
class ModifiedDenseQuantizeConfig(DefaultDenseQuantizeConfig):
def get_activations_and_quantizers(self, layer):
# Skip quantizing activations.
return []
def set_quantize_activations(self, layer, quantize_activations):
# Empty since `get_activaations_and_quantizers` returns
# an empty list.
return
構成の適用は、「量子化を使用した実験」のユースケース全体で同じです。
-
tfmot.quantization.keras.quantize_annotate_layer
をDense
レイヤーに適用し、QuantizeConfig
を渡します。 -
tfmot.quantization.keras.quantize_annotate_model
を使用して、APIのデフォルトでモデルの残りの部分を引き続き量子化します。
model = quantize_annotate_model(tf.keras.Sequential([
# Pass in modified `QuantizeConfig` to modify this Dense layer.
quantize_annotate_layer(tf.keras.layers.Dense(20, input_shape=(20,)), ModifiedDenseQuantizeConfig()),
tf.keras.layers.Flatten()
]))
# `quantize_apply` requires mentioning `ModifiedDenseQuantizeConfig` with `quantize_scope`:
with quantize_scope(
{'ModifiedDenseQuantizeConfig': ModifiedDenseQuantizeConfig}):
# Use `quantize_apply` to actually make the model quantization aware.
quant_aware_model = tfmot.quantization.keras.quantize_apply(model)
quant_aware_model.summary()
Model: "sequential_10" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= quantize_layer_8 (QuantizeLa (None, 20) 3 _________________________________________________________________ quant_dense_10 (QuantizeWrap (None, 20) 423 _________________________________________________________________ quant_flatten_11 (QuantizeWr (None, 20) 1 ================================================================= Total params: 427 Trainable params: 420 Non-trainable params: 7 _________________________________________________________________
カスタム量子化アルゴリズムを使用する
tfmot.quantization.keras.quantizers.Quantizer
クラスは、任意のアルゴリズムを入力に適用できる呼び出し可能オブジェクトです。
この例では、入力は重みであり、 FixedRangeQuantizer
関数の計算を重みに適用します。元の重み値の代わりに、 FixedRangeQuantizer
の出力が重みを使用したものに渡されるようになりました。
quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer
quantize_annotate_model = tfmot.quantization.keras.quantize_annotate_model
quantize_scope = tfmot.quantization.keras.quantize_scope
class FixedRangeQuantizer(tfmot.quantization.keras.quantizers.Quantizer):
"""Quantizer which forces outputs to be between -1 and 1."""
def build(self, tensor_shape, name, layer):
# Not needed. No new TensorFlow variables needed.
return {}
def __call__(self, inputs, training, weights, **kwargs):
return tf.keras.backend.clip(inputs, -1.0, 1.0)
def get_config(self):
# Not needed. No __init__ parameters to serialize.
return {}
class ModifiedDenseQuantizeConfig(DefaultDenseQuantizeConfig):
# Configure weights to quantize with 4-bit instead of 8-bits.
def get_weights_and_quantizers(self, layer):
# Use custom algorithm defined in `FixedRangeQuantizer` instead of default Quantizer.
return [(layer.kernel, FixedRangeQuantizer())]
構成の適用は、「量子化を使用した実験」のユースケース全体で同じです。
-
tfmot.quantization.keras.quantize_annotate_layer
をDense
レイヤーに適用し、QuantizeConfig
を渡します。 -
tfmot.quantization.keras.quantize_annotate_model
を使用して、APIのデフォルトでモデルの残りの部分を引き続き量子化します。
model = quantize_annotate_model(tf.keras.Sequential([
# Pass in modified `QuantizeConfig` to modify this `Dense` layer.
quantize_annotate_layer(tf.keras.layers.Dense(20, input_shape=(20,)), ModifiedDenseQuantizeConfig()),
tf.keras.layers.Flatten()
]))
# `quantize_apply` requires mentioning `ModifiedDenseQuantizeConfig` with `quantize_scope`:
with quantize_scope(
{'ModifiedDenseQuantizeConfig': ModifiedDenseQuantizeConfig}):
# Use `quantize_apply` to actually make the model quantization aware.
quant_aware_model = tfmot.quantization.keras.quantize_apply(model)
quant_aware_model.summary()
Model: "sequential_11" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= quantize_layer_9 (QuantizeLa (None, 20) 3 _________________________________________________________________ quant_dense_11 (QuantizeWrap (None, 20) 423 _________________________________________________________________ quant_flatten_12 (QuantizeWr (None, 20) 1 ================================================================= Total params: 427 Trainable params: 420 Non-trainable params: 7 _________________________________________________________________