![]() | ![]() | ![]() | ![]() |
Обзор
Это конец конца пример , показывающий использование кластера сохраняющих квантования в курсе обучения (CQAT) API, часть совместного трубопровода оптимизации TensorFlow модели оптимизации инструментария.
Другие страницы
Для введения в трубопровод и другие доступных методы, см совместной страница обзора оптимизации .
СОДЕРЖАНИЕ
В этом руководстве вы:
- Поезд
tf.keras
модель для MNIST набора данных с нуля. - Настройте модель с помощью кластеризации и посмотрите точность.
- Примените QAT и наблюдайте потерю кластеров.
- Примените CQAT и обратите внимание, что примененная ранее кластеризация была сохранена.
- Создайте модель TFLite и посмотрите, как на нее влияет применение CQAT.
- Сравните достигнутую точность модели CQAT с моделью, квантованной с использованием посттренировочного квантования.
Настраивать
Вы можете запустить этот Jupyter ноутбук в вашем virtualenv или colab . Для получения дополнительной информации о настройке зависимостей, пожалуйста , обратитесь к руководству по установке .
pip install -q tensorflow-model-optimization
import tensorflow as tf
import numpy as np
import tempfile
import zipfile
import os
Обучить модель tf.keras для MNIST без кластеризации
# Load MNIST dataset
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
# Normalize the input image so that each pixel value is between 0 to 1.
train_images = train_images / 255.0
test_images = test_images / 255.0
model = tf.keras.Sequential([
tf.keras.layers.InputLayer(input_shape=(28, 28)),
tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
tf.keras.layers.Conv2D(filters=12, kernel_size=(3, 3),
activation=tf.nn.relu),
tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(10)
])
# Train the digit classification model
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
model.fit(
train_images,
train_labels,
validation_split=0.1,
epochs=10
)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz 11493376/11490434 [==============================] - 0s 0us/step Epoch 1/10 1688/1688 [==============================] - 8s 5ms/step - loss: 0.2827 - accuracy: 0.9211 - val_loss: 0.1057 - val_accuracy: 0.9728 Epoch 2/10 1688/1688 [==============================] - 8s 5ms/step - loss: 0.1070 - accuracy: 0.9696 - val_loss: 0.0768 - val_accuracy: 0.9793 Epoch 3/10 1688/1688 [==============================] - 7s 4ms/step - loss: 0.0805 - accuracy: 0.9765 - val_loss: 0.0663 - val_accuracy: 0.9823 Epoch 4/10 1688/1688 [==============================] - 7s 4ms/step - loss: 0.0671 - accuracy: 0.9804 - val_loss: 0.0614 - val_accuracy: 0.9840 Epoch 5/10 1688/1688 [==============================] - 7s 4ms/step - loss: 0.0592 - accuracy: 0.9827 - val_loss: 0.0594 - val_accuracy: 0.9837 Epoch 6/10 1688/1688 [==============================] - 7s 4ms/step - loss: 0.0537 - accuracy: 0.9842 - val_loss: 0.0610 - val_accuracy: 0.9838 Epoch 7/10 1688/1688 [==============================] - 7s 4ms/step - loss: 0.0487 - accuracy: 0.9851 - val_loss: 0.0644 - val_accuracy: 0.9827 Epoch 8/10 1688/1688 [==============================] - 7s 4ms/step - loss: 0.0437 - accuracy: 0.9862 - val_loss: 0.0624 - val_accuracy: 0.9847 Epoch 9/10 1688/1688 [==============================] - 7s 4ms/step - loss: 0.0399 - accuracy: 0.9882 - val_loss: 0.0595 - val_accuracy: 0.9847 Epoch 10/10 1688/1688 [==============================] - 7s 4ms/step - loss: 0.0362 - accuracy: 0.9885 - val_loss: 0.0613 - val_accuracy: 0.9835 <tensorflow.python.keras.callbacks.History at 0x7f24a4cc7990>
Оцените базовую модель и сохраните ее для дальнейшего использования.
_, baseline_model_accuracy = model.evaluate(
test_images, test_labels, verbose=0)
print('Baseline test accuracy:', baseline_model_accuracy)
_, keras_file = tempfile.mkstemp('.h5')
print('Saving model to: ', keras_file)
tf.keras.models.save_model(model, keras_file, include_optimizer=False)
Baseline test accuracy: 0.9814000129699707 Saving model to: /tmp/tmppd3opqzk.h5
Кластеризация и точная настройка модели с 8 кластерами
Нанести cluster_weights()
API группироваться все заранее подготовленные моделями для демонстрации и наблюдать за его эффективность в уменьшении размера модели при применении почтового индекса, сохраняя при этом точности. О том , как лучше использовать API для достижения наилучшего коэффициента сжатия при сохранении вашей целевой точности, обратитесь к кластеризации всеобъемлющего руководству .
Определите модель и примените API кластеризации
Перед использованием API кластеризации модель необходимо предварительно обучить.
import tensorflow_model_optimization as tfmot
cluster_weights = tfmot.clustering.keras.cluster_weights
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization
clustering_params = {
'number_of_clusters': 8,
'cluster_centroids_init': CentroidInitialization.KMEANS_PLUS_PLUS
}
clustered_model = cluster_weights(model, **clustering_params)
# Use smaller learning rate for fine-tuning
opt = tf.keras.optimizers.Adam(learning_rate=1e-5)
clustered_model.compile(
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=opt,
metrics=['accuracy'])
clustered_model.summary()
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/array_ops.py:5049: calling gather (from tensorflow.python.ops.array_ops) with validate_indices is deprecated and will be removed in a future version. Instructions for updating: The `validate_indices` argument has no effect. Indices are always validated on CPU and never validated on GPU. Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= cluster_reshape (ClusterWeig (None, 28, 28, 1) 0 _________________________________________________________________ cluster_conv2d (ClusterWeigh (None, 26, 26, 12) 236 _________________________________________________________________ cluster_max_pooling2d (Clust (None, 13, 13, 12) 0 _________________________________________________________________ cluster_flatten (ClusterWeig (None, 2028) 0 _________________________________________________________________ cluster_dense (ClusterWeight (None, 10) 40578 ================================================================= Total params: 40,814 Trainable params: 20,426 Non-trainable params: 20,388 _________________________________________________________________
Выполните точную настройку модели и оцените точность по сравнению с базовой линией.
Точная настройка модели с кластеризацией для 3 эпох.
# Fine-tune model
clustered_model.fit(
train_images,
train_labels,
epochs=3,
validation_split=0.1)
Epoch 1/3 1688/1688 [==============================] - 8s 5ms/step - loss: 0.0363 - accuracy: 0.9887 - val_loss: 0.0610 - val_accuracy: 0.9830 Epoch 2/3 1688/1688 [==============================] - 8s 5ms/step - loss: 0.0333 - accuracy: 0.9898 - val_loss: 0.0599 - val_accuracy: 0.9830 Epoch 3/3 1688/1688 [==============================] - 8s 5ms/step - loss: 0.0319 - accuracy: 0.9905 - val_loss: 0.0580 - val_accuracy: 0.9837 <tensorflow.python.keras.callbacks.History at 0x7f240d271990>
Определите вспомогательные функции, чтобы вычислить и распечатать количество кластеров в каждом ядре модели.
def print_model_weight_clusters(model):
for layer in model.layers:
if isinstance(layer, tf.keras.layers.Wrapper):
weights = layer.trainable_weights
else:
weights = layer.weights
for weight in weights:
# ignore auxiliary quantization weights
if "quantize_layer" in weight.name:
continue
if "kernel" in weight.name:
unique_count = len(np.unique(weight))
print(
f"{layer.name}/{weight.name}: {unique_count} clusters "
)
Убедитесь, что ядра модели были правильно сгруппированы. Сначала нам нужно удалить оболочку кластеризации.
stripped_clustered_model = tfmot.clustering.keras.strip_clustering(clustered_model)
print_model_weight_clusters(stripped_clustered_model)
conv2d/kernel:0: 8 clusters dense/kernel:0: 8 clusters
В этом примере потеря точности теста после кластеризации минимальна по сравнению с базовой линией.
_, clustered_model_accuracy = clustered_model.evaluate(
test_images, test_labels, verbose=0)
print('Baseline test accuracy:', baseline_model_accuracy)
print('Clustered test accuracy:', clustered_model_accuracy)
Baseline test accuracy: 0.9814000129699707 Clustered test accuracy: 0.9800999760627747
Примените QAT и CQAT и проверьте эффект на кластеры модели в обоих случаях.
Затем мы применяем QAT и QAT с сохранением кластера (CQAT) к кластерной модели и наблюдаем, что CQAT сохраняет кластеры веса в вашей кластерной модели. Обратите внимание , что мы раздели кластеризацию оберток от модели с tfmot.clustering.keras.strip_clustering
перед применением CQAT API.
# QAT
qat_model = tfmot.quantization.keras.quantize_model(stripped_clustered_model)
qat_model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
print('Train qat model:')
qat_model.fit(train_images, train_labels, batch_size=128, epochs=1, validation_split=0.1)
# CQAT
quant_aware_annotate_model = tfmot.quantization.keras.quantize_annotate_model(
stripped_clustered_model)
cqat_model = tfmot.quantization.keras.quantize_apply(
quant_aware_annotate_model,
tfmot.experimental.combine.Default8BitClusterPreserveQuantizeScheme())
cqat_model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
print('Train cqat model:')
cqat_model.fit(train_images, train_labels, batch_size=128, epochs=1, validation_split=0.1)
Train qat model: 422/422 [==============================] - 4s 8ms/step - loss: 0.0328 - accuracy: 0.9901 - val_loss: 0.0578 - val_accuracy: 0.9840 WARNING:root:Input layer does not contain zero weights, so apply CQAT instead. WARNING:root:Input layer does not contain zero weights, so apply CQAT instead. Train cqat model: WARNING:tensorflow:Gradients do not exist for variables ['conv2d/kernel:0', 'dense/kernel:0'] when minimizing the loss. WARNING:tensorflow:Gradients do not exist for variables ['conv2d/kernel:0', 'dense/kernel:0'] when minimizing the loss. WARNING:tensorflow:Gradients do not exist for variables ['conv2d/kernel:0', 'dense/kernel:0'] when minimizing the loss. WARNING:tensorflow:Gradients do not exist for variables ['conv2d/kernel:0', 'dense/kernel:0'] when minimizing the loss. 422/422 [==============================] - 4s 8ms/step - loss: 0.0310 - accuracy: 0.9908 - val_loss: 0.0592 - val_accuracy: 0.9838 <tensorflow.python.keras.callbacks.History at 0x7f240c60e6d0>
print("QAT Model clusters:")
print_model_weight_clusters(qat_model)
print("CQAT Model clusters:")
print_model_weight_clusters(cqat_model)
QAT Model clusters: quant_conv2d/conv2d/kernel:0: 108 clusters quant_dense/dense/kernel:0: 19931 clusters CQAT Model clusters: quant_conv2d/conv2d/kernel:0: 8 clusters quant_dense/dense/kernel:0: 8 clusters
Ознакомьтесь с преимуществами сжатия модели CQAT
Определите вспомогательную функцию для получения заархивированного файла модели.
def get_gzipped_model_size(file):
# It returns the size of the gzipped model in kilobytes.
_, zipped_file = tempfile.mkstemp('.zip')
with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
f.write(file)
return os.path.getsize(zipped_file)/1000
Обратите внимание, что это небольшая модель. Применение кластеризации и CQAT к более крупной производственной модели приведет к более значительному сжатию.
# QAT model
converter = tf.lite.TFLiteConverter.from_keras_model(qat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
qat_tflite_model = converter.convert()
qat_model_file = 'qat_model.tflite'
# Save the model.
with open(qat_model_file, 'wb') as f:
f.write(qat_tflite_model)
# CQAT model
converter = tf.lite.TFLiteConverter.from_keras_model(cqat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
cqat_tflite_model = converter.convert()
cqat_model_file = 'cqat_model.tflite'
# Save the model.
with open(cqat_model_file, 'wb') as f:
f.write(cqat_tflite_model)
print("QAT model size: ", get_gzipped_model_size(qat_model_file), ' KB')
print("CQAT model size: ", get_gzipped_model_size(cqat_model_file), ' KB')
WARNING:absl:Found untraced functions such as reshape_layer_call_and_return_conditional_losses, reshape_layer_call_fn, conv2d_layer_call_and_return_conditional_losses, conv2d_layer_call_fn, flatten_layer_call_and_return_conditional_losses while saving (showing 5 of 20). These functions will not be directly callable after loading. INFO:tensorflow:Assets written to: /tmp/tmpy_b1e0tx/assets INFO:tensorflow:Assets written to: /tmp/tmpy_b1e0tx/assets WARNING:absl:Found untraced functions such as reshape_layer_call_and_return_conditional_losses, reshape_layer_call_fn, conv2d_layer_call_and_return_conditional_losses, conv2d_layer_call_fn, flatten_layer_call_and_return_conditional_losses while saving (showing 5 of 20). These functions will not be directly callable after loading. INFO:tensorflow:Assets written to: /tmp/tmp282pfq0n/assets INFO:tensorflow:Assets written to: /tmp/tmp282pfq0n/assets QAT model size: 16.685 KB CQAT model size: 10.121 KB
Посмотрите на постоянство точности при переходе от TF к TFLite
Определите вспомогательную функцию для оценки модели TFLite в тестовом наборе данных.
def eval_model(interpreter):
input_index = interpreter.get_input_details()[0]["index"]
output_index = interpreter.get_output_details()[0]["index"]
# Run predictions on every image in the "test" dataset.
prediction_digits = []
for i, test_image in enumerate(test_images):
if i % 1000 == 0:
print(f"Evaluated on {i} results so far.")
# Pre-processing: add batch dimension and convert to float32 to match with
# the model's input data format.
test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
interpreter.set_tensor(input_index, test_image)
# Run inference.
interpreter.invoke()
# Post-processing: remove batch dimension and find the digit with highest
# probability.
output = interpreter.tensor(output_index)
digit = np.argmax(output()[0])
prediction_digits.append(digit)
print('\n')
# Compare prediction results with ground truth labels to calculate accuracy.
prediction_digits = np.array(prediction_digits)
accuracy = (prediction_digits == test_labels).mean()
return accuracy
Вы оцениваете модель, которая была кластеризована и квантована, а затем видите, что точность TensorFlow сохраняется в бэкэнде TFLite.
interpreter = tf.lite.Interpreter(cqat_model_file)
interpreter.allocate_tensors()
cqat_test_accuracy = eval_model(interpreter)
print('Clustered and quantized TFLite test_accuracy:', cqat_test_accuracy)
print('Clustered TF test accuracy:', clustered_model_accuracy)
Evaluated on 0 results so far. Evaluated on 1000 results so far. Evaluated on 2000 results so far. Evaluated on 3000 results so far. Evaluated on 4000 results so far. Evaluated on 5000 results so far. Evaluated on 6000 results so far. Evaluated on 7000 results so far. Evaluated on 8000 results so far. Evaluated on 9000 results so far. Clustered and quantized TFLite test_accuracy: 0.9795 Clustered TF test accuracy: 0.9800999760627747
Примените квантование после обучения и сравните с моделью CQAT
Затем мы используем квантование после обучения (без точной настройки) для кластерной модели и проверяем ее точность по модели CQAT. Это демонстрирует, почему вам нужно использовать CQAT для повышения точности квантованной модели.
Сначала определите генератор для набора данных калибровки из первых 1000 обучающих изображений.
def mnist_representative_data_gen():
for image in train_images[:1000]:
image = np.expand_dims(image, axis=0).astype(np.float32)
yield [image]
Проведите квантование модели и сравните точность с ранее полученной моделью CQAT. Обратите внимание, что модель, квантованная с помощью точной настройки, обеспечивает более высокую точность.
converter = tf.lite.TFLiteConverter.from_keras_model(stripped_clustered_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = mnist_representative_data_gen
post_training_tflite_model = converter.convert()
post_training_model_file = 'post_training_model.tflite'
# Save the model.
with open(post_training_model_file, 'wb') as f:
f.write(post_training_tflite_model)
# Compare accuracy
interpreter = tf.lite.Interpreter(post_training_model_file)
interpreter.allocate_tensors()
post_training_test_accuracy = eval_model(interpreter)
print('CQAT TFLite test_accuracy:', cqat_test_accuracy)
print('Post-training (no fine-tuning) TF test accuracy:', post_training_test_accuracy)
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model. WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model. INFO:tensorflow:Assets written to: /tmp/tmp6vw7f4l2/assets INFO:tensorflow:Assets written to: /tmp/tmp6vw7f4l2/assets Evaluated on 0 results so far. Evaluated on 1000 results so far. Evaluated on 2000 results so far. Evaluated on 3000 results so far. Evaluated on 4000 results so far. Evaluated on 5000 results so far. Evaluated on 6000 results so far. Evaluated on 7000 results so far. Evaluated on 8000 results so far. Evaluated on 9000 results so far. CQAT TFLite test_accuracy: 0.9795 Post-training (no fine-tuning) TF test accuracy: 0.9804
Вывод
В этом уроке вы узнали , как создать модель, кластер его с помощью cluster_weights()
API, и применять кластерное консервирование квантования осведомленного обучение (CQAT) для сохранения кластеров при использовании ИАКА. Окончательная модель CQAT сравнивалась с моделью QAT, чтобы показать, что кластеры сохраняются в первой и теряются во второй. Затем модели были преобразованы в TFLite, чтобы продемонстрировать преимущества сжатия цепной кластеризации и методов оптимизации модели CQAT, а модель TFLite была оценена, чтобы гарантировать, что точность сохраняется в бэкэнде TFLite. Наконец, модель CQAT сравнивалась с квантованной кластерной моделью, полученной с использованием API квантования после обучения, чтобы продемонстрировать преимущество CQAT в восстановлении потери точности из-за нормального квантования.