![]() | ![]() | ![]() | ![]() |
Descripción general
Este es un fin al ejemplo extremo que muestra el uso de la agrupación preservar la API de cuantificación de la formación consciente (CQAT), que forma parte de la tubería de colaboración optimización del modelo de optimización de TensorFlow Toolkit.
Otras paginas
Para una introducción a la tubería y otras técnicas disponibles, consulte la página de información general de colaboración optimización .
Contenido
En el tutorial, podrá:
- Entrenar a un
tf.keras
modelo para el conjunto de datos MNIST desde cero. - Ajuste el modelo con agrupamiento y vea la precisión.
- Aplicar QAT y observar la pérdida de clústeres.
- Aplique CQAT y observe que la agrupación aplicada anteriormente se ha conservado.
- Genere un modelo TFLite y observe los efectos de aplicar CQAT en él.
- Compare la precisión del modelo CQAT lograda con un modelo cuantificado utilizando la cuantificación posterior al entrenamiento.
Configuración
Puede ejecutar este Notebook Jupyter en su local de virtualenv o colab . Para los detalles de la creación de dependencias, consulte la guía de instalación .
pip install -q tensorflow-model-optimization
import tensorflow as tf
import numpy as np
import tempfile
import zipfile
import os
Entrene un modelo tf.keras para MNIST sin agrupamiento
# Load MNIST dataset
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
# Normalize the input image so that each pixel value is between 0 to 1.
train_images = train_images / 255.0
test_images = test_images / 255.0
model = tf.keras.Sequential([
tf.keras.layers.InputLayer(input_shape=(28, 28)),
tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
tf.keras.layers.Conv2D(filters=12, kernel_size=(3, 3),
activation=tf.nn.relu),
tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(10)
])
# Train the digit classification model
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
model.fit(
train_images,
train_labels,
validation_split=0.1,
epochs=10
)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz 11493376/11490434 [==============================] - 0s 0us/step Epoch 1/10 1688/1688 [==============================] - 8s 5ms/step - loss: 0.2827 - accuracy: 0.9211 - val_loss: 0.1057 - val_accuracy: 0.9728 Epoch 2/10 1688/1688 [==============================] - 8s 5ms/step - loss: 0.1070 - accuracy: 0.9696 - val_loss: 0.0768 - val_accuracy: 0.9793 Epoch 3/10 1688/1688 [==============================] - 7s 4ms/step - loss: 0.0805 - accuracy: 0.9765 - val_loss: 0.0663 - val_accuracy: 0.9823 Epoch 4/10 1688/1688 [==============================] - 7s 4ms/step - loss: 0.0671 - accuracy: 0.9804 - val_loss: 0.0614 - val_accuracy: 0.9840 Epoch 5/10 1688/1688 [==============================] - 7s 4ms/step - loss: 0.0592 - accuracy: 0.9827 - val_loss: 0.0594 - val_accuracy: 0.9837 Epoch 6/10 1688/1688 [==============================] - 7s 4ms/step - loss: 0.0537 - accuracy: 0.9842 - val_loss: 0.0610 - val_accuracy: 0.9838 Epoch 7/10 1688/1688 [==============================] - 7s 4ms/step - loss: 0.0487 - accuracy: 0.9851 - val_loss: 0.0644 - val_accuracy: 0.9827 Epoch 8/10 1688/1688 [==============================] - 7s 4ms/step - loss: 0.0437 - accuracy: 0.9862 - val_loss: 0.0624 - val_accuracy: 0.9847 Epoch 9/10 1688/1688 [==============================] - 7s 4ms/step - loss: 0.0399 - accuracy: 0.9882 - val_loss: 0.0595 - val_accuracy: 0.9847 Epoch 10/10 1688/1688 [==============================] - 7s 4ms/step - loss: 0.0362 - accuracy: 0.9885 - val_loss: 0.0613 - val_accuracy: 0.9835 <tensorflow.python.keras.callbacks.History at 0x7f24a4cc7990>
Evalúe el modelo de línea de base y guárdelo para su uso posterior
_, baseline_model_accuracy = model.evaluate(
test_images, test_labels, verbose=0)
print('Baseline test accuracy:', baseline_model_accuracy)
_, keras_file = tempfile.mkstemp('.h5')
print('Saving model to: ', keras_file)
tf.keras.models.save_model(model, keras_file, include_optimizer=False)
Baseline test accuracy: 0.9814000129699707 Saving model to: /tmp/tmppd3opqzk.h5
Agrupar y ajustar el modelo con 8 grupos
Aplicar la cluster_weights()
API a agruparse todo el modelo pre-formados para demostrar y observar su eficacia en la reducción del tamaño del modelo cuando la aplicación de cremallera, manteniendo al mismo tiempo la precisión. Para la mejor manera de utilizar la API para obtener la mejor tasa de compresión, manteniendo la precisión de destino, consulte la guía completa de la agrupación .
Definir el modelo y aplicar la API de agrupación
El modelo debe entrenarse previamente antes de usar la API de agrupación en clústeres.
import tensorflow_model_optimization as tfmot
cluster_weights = tfmot.clustering.keras.cluster_weights
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization
clustering_params = {
'number_of_clusters': 8,
'cluster_centroids_init': CentroidInitialization.KMEANS_PLUS_PLUS
}
clustered_model = cluster_weights(model, **clustering_params)
# Use smaller learning rate for fine-tuning
opt = tf.keras.optimizers.Adam(learning_rate=1e-5)
clustered_model.compile(
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=opt,
metrics=['accuracy'])
clustered_model.summary()
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/array_ops.py:5049: calling gather (from tensorflow.python.ops.array_ops) with validate_indices is deprecated and will be removed in a future version. Instructions for updating: The `validate_indices` argument has no effect. Indices are always validated on CPU and never validated on GPU. Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= cluster_reshape (ClusterWeig (None, 28, 28, 1) 0 _________________________________________________________________ cluster_conv2d (ClusterWeigh (None, 26, 26, 12) 236 _________________________________________________________________ cluster_max_pooling2d (Clust (None, 13, 13, 12) 0 _________________________________________________________________ cluster_flatten (ClusterWeig (None, 2028) 0 _________________________________________________________________ cluster_dense (ClusterWeight (None, 10) 40578 ================================================================= Total params: 40,814 Trainable params: 20,426 Non-trainable params: 20,388 _________________________________________________________________
Ajuste el modelo y evalúe la precisión con respecto a la línea de base
Ajuste el modelo con agrupamiento durante 3 épocas.
# Fine-tune model
clustered_model.fit(
train_images,
train_labels,
epochs=3,
validation_split=0.1)
Epoch 1/3 1688/1688 [==============================] - 8s 5ms/step - loss: 0.0363 - accuracy: 0.9887 - val_loss: 0.0610 - val_accuracy: 0.9830 Epoch 2/3 1688/1688 [==============================] - 8s 5ms/step - loss: 0.0333 - accuracy: 0.9898 - val_loss: 0.0599 - val_accuracy: 0.9830 Epoch 3/3 1688/1688 [==============================] - 8s 5ms/step - loss: 0.0319 - accuracy: 0.9905 - val_loss: 0.0580 - val_accuracy: 0.9837 <tensorflow.python.keras.callbacks.History at 0x7f240d271990>
Defina funciones auxiliares para calcular e imprimir el número de agrupaciones en cada núcleo del modelo.
def print_model_weight_clusters(model):
for layer in model.layers:
if isinstance(layer, tf.keras.layers.Wrapper):
weights = layer.trainable_weights
else:
weights = layer.weights
for weight in weights:
# ignore auxiliary quantization weights
if "quantize_layer" in weight.name:
continue
if "kernel" in weight.name:
unique_count = len(np.unique(weight))
print(
f"{layer.name}/{weight.name}: {unique_count} clusters "
)
Compruebe que los núcleos del modelo estén agrupados correctamente. Primero tenemos que quitar la envoltura de agrupamiento.
stripped_clustered_model = tfmot.clustering.keras.strip_clustering(clustered_model)
print_model_weight_clusters(stripped_clustered_model)
conv2d/kernel:0: 8 clusters dense/kernel:0: 8 clusters
Para este ejemplo, hay una pérdida mínima en la precisión de la prueba después de la agrupación, en comparación con la línea de base.
_, clustered_model_accuracy = clustered_model.evaluate(
test_images, test_labels, verbose=0)
print('Baseline test accuracy:', baseline_model_accuracy)
print('Clustered test accuracy:', clustered_model_accuracy)
Baseline test accuracy: 0.9814000129699707 Clustered test accuracy: 0.9800999760627747
Aplicar QAT y CQAT y comprobar el efecto en los grupos de modelos en ambos casos
A continuación, aplicamos QAT y QAT de conservación de clústeres (CQAT) en el modelo agrupado y observamos que CQAT conserva los clústeres de ponderación en su modelo agrupado. Nótese que se desnudaron envoltorios de agrupamiento de su modelo con tfmot.clustering.keras.strip_clustering
antes de aplicar API CQAT.
# QAT
qat_model = tfmot.quantization.keras.quantize_model(stripped_clustered_model)
qat_model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
print('Train qat model:')
qat_model.fit(train_images, train_labels, batch_size=128, epochs=1, validation_split=0.1)
# CQAT
quant_aware_annotate_model = tfmot.quantization.keras.quantize_annotate_model(
stripped_clustered_model)
cqat_model = tfmot.quantization.keras.quantize_apply(
quant_aware_annotate_model,
tfmot.experimental.combine.Default8BitClusterPreserveQuantizeScheme())
cqat_model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
print('Train cqat model:')
cqat_model.fit(train_images, train_labels, batch_size=128, epochs=1, validation_split=0.1)
Train qat model: 422/422 [==============================] - 4s 8ms/step - loss: 0.0328 - accuracy: 0.9901 - val_loss: 0.0578 - val_accuracy: 0.9840 WARNING:root:Input layer does not contain zero weights, so apply CQAT instead. WARNING:root:Input layer does not contain zero weights, so apply CQAT instead. Train cqat model: WARNING:tensorflow:Gradients do not exist for variables ['conv2d/kernel:0', 'dense/kernel:0'] when minimizing the loss. WARNING:tensorflow:Gradients do not exist for variables ['conv2d/kernel:0', 'dense/kernel:0'] when minimizing the loss. WARNING:tensorflow:Gradients do not exist for variables ['conv2d/kernel:0', 'dense/kernel:0'] when minimizing the loss. WARNING:tensorflow:Gradients do not exist for variables ['conv2d/kernel:0', 'dense/kernel:0'] when minimizing the loss. 422/422 [==============================] - 4s 8ms/step - loss: 0.0310 - accuracy: 0.9908 - val_loss: 0.0592 - val_accuracy: 0.9838 <tensorflow.python.keras.callbacks.History at 0x7f240c60e6d0>
print("QAT Model clusters:")
print_model_weight_clusters(qat_model)
print("CQAT Model clusters:")
print_model_weight_clusters(cqat_model)
QAT Model clusters: quant_conv2d/conv2d/kernel:0: 108 clusters quant_dense/dense/kernel:0: 19931 clusters CQAT Model clusters: quant_conv2d/conv2d/kernel:0: 8 clusters quant_dense/dense/kernel:0: 8 clusters
Vea los beneficios de compresión del modelo CQAT
Defina la función auxiliar para obtener el archivo de modelo comprimido.
def get_gzipped_model_size(file):
# It returns the size of the gzipped model in kilobytes.
_, zipped_file = tempfile.mkstemp('.zip')
with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
f.write(file)
return os.path.getsize(zipped_file)/1000
Tenga en cuenta que este es un modelo pequeño. La aplicación de agrupación en clústeres y CQAT a un modelo de producción más grande produciría una compresión más significativa.
# QAT model
converter = tf.lite.TFLiteConverter.from_keras_model(qat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
qat_tflite_model = converter.convert()
qat_model_file = 'qat_model.tflite'
# Save the model.
with open(qat_model_file, 'wb') as f:
f.write(qat_tflite_model)
# CQAT model
converter = tf.lite.TFLiteConverter.from_keras_model(cqat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
cqat_tflite_model = converter.convert()
cqat_model_file = 'cqat_model.tflite'
# Save the model.
with open(cqat_model_file, 'wb') as f:
f.write(cqat_tflite_model)
print("QAT model size: ", get_gzipped_model_size(qat_model_file), ' KB')
print("CQAT model size: ", get_gzipped_model_size(cqat_model_file), ' KB')
WARNING:absl:Found untraced functions such as reshape_layer_call_and_return_conditional_losses, reshape_layer_call_fn, conv2d_layer_call_and_return_conditional_losses, conv2d_layer_call_fn, flatten_layer_call_and_return_conditional_losses while saving (showing 5 of 20). These functions will not be directly callable after loading. INFO:tensorflow:Assets written to: /tmp/tmpy_b1e0tx/assets INFO:tensorflow:Assets written to: /tmp/tmpy_b1e0tx/assets WARNING:absl:Found untraced functions such as reshape_layer_call_and_return_conditional_losses, reshape_layer_call_fn, conv2d_layer_call_and_return_conditional_losses, conv2d_layer_call_fn, flatten_layer_call_and_return_conditional_losses while saving (showing 5 of 20). These functions will not be directly callable after loading. INFO:tensorflow:Assets written to: /tmp/tmp282pfq0n/assets INFO:tensorflow:Assets written to: /tmp/tmp282pfq0n/assets QAT model size: 16.685 KB CQAT model size: 10.121 KB
Vea la persistencia de la precisión de TF a TFLite
Defina una función auxiliar para evaluar el modelo TFLite en el conjunto de datos de prueba.
def eval_model(interpreter):
input_index = interpreter.get_input_details()[0]["index"]
output_index = interpreter.get_output_details()[0]["index"]
# Run predictions on every image in the "test" dataset.
prediction_digits = []
for i, test_image in enumerate(test_images):
if i % 1000 == 0:
print(f"Evaluated on {i} results so far.")
# Pre-processing: add batch dimension and convert to float32 to match with
# the model's input data format.
test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
interpreter.set_tensor(input_index, test_image)
# Run inference.
interpreter.invoke()
# Post-processing: remove batch dimension and find the digit with highest
# probability.
output = interpreter.tensor(output_index)
digit = np.argmax(output()[0])
prediction_digits.append(digit)
print('\n')
# Compare prediction results with ground truth labels to calculate accuracy.
prediction_digits = np.array(prediction_digits)
accuracy = (prediction_digits == test_labels).mean()
return accuracy
Evalúas el modelo, que se ha agrupado y cuantificado, y luego ves que la precisión de TensorFlow persiste en el backend de TFLite.
interpreter = tf.lite.Interpreter(cqat_model_file)
interpreter.allocate_tensors()
cqat_test_accuracy = eval_model(interpreter)
print('Clustered and quantized TFLite test_accuracy:', cqat_test_accuracy)
print('Clustered TF test accuracy:', clustered_model_accuracy)
Evaluated on 0 results so far. Evaluated on 1000 results so far. Evaluated on 2000 results so far. Evaluated on 3000 results so far. Evaluated on 4000 results so far. Evaluated on 5000 results so far. Evaluated on 6000 results so far. Evaluated on 7000 results so far. Evaluated on 8000 results so far. Evaluated on 9000 results so far. Clustered and quantized TFLite test_accuracy: 0.9795 Clustered TF test accuracy: 0.9800999760627747
Aplicar la cuantificación posterior al entrenamiento y comparar con el modelo CQAT
A continuación, usamos la cuantificación posterior al entrenamiento (sin ajuste fino) en el modelo agrupado y verificamos su precisión con el modelo CQAT. Esto demuestra por qué necesitaría utilizar CQAT para mejorar la precisión del modelo cuantificado.
Primero, defina un generador para el conjunto de datos de calibración a partir de las primeras 1000 imágenes de entrenamiento.
def mnist_representative_data_gen():
for image in train_images[:1000]:
image = np.expand_dims(image, axis=0).astype(np.float32)
yield [image]
Cuantifique el modelo y compare la precisión con el modelo CQAT adquirido anteriormente. Tenga en cuenta que el modelo cuantificado con ajuste fino logra una mayor precisión.
converter = tf.lite.TFLiteConverter.from_keras_model(stripped_clustered_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = mnist_representative_data_gen
post_training_tflite_model = converter.convert()
post_training_model_file = 'post_training_model.tflite'
# Save the model.
with open(post_training_model_file, 'wb') as f:
f.write(post_training_tflite_model)
# Compare accuracy
interpreter = tf.lite.Interpreter(post_training_model_file)
interpreter.allocate_tensors()
post_training_test_accuracy = eval_model(interpreter)
print('CQAT TFLite test_accuracy:', cqat_test_accuracy)
print('Post-training (no fine-tuning) TF test accuracy:', post_training_test_accuracy)
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model. WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model. INFO:tensorflow:Assets written to: /tmp/tmp6vw7f4l2/assets INFO:tensorflow:Assets written to: /tmp/tmp6vw7f4l2/assets Evaluated on 0 results so far. Evaluated on 1000 results so far. Evaluated on 2000 results so far. Evaluated on 3000 results so far. Evaluated on 4000 results so far. Evaluated on 5000 results so far. Evaluated on 6000 results so far. Evaluated on 7000 results so far. Evaluated on 8000 results so far. Evaluated on 9000 results so far. CQAT TFLite test_accuracy: 0.9795 Post-training (no fine-tuning) TF test accuracy: 0.9804
Conclusión
En este tutorial, aprendió a crear un modelo, IT Cluster usando los cluster_weights()
de la API, y aplicar el cluster preservar la cuantificación de formación conscientes (CQAT) para preservar racimos durante el uso de QAT. El modelo CQAT final se comparó con el modelo QAT para mostrar que los grupos se conservan en el primero y se pierden en el segundo. A continuación, los modelos se convirtieron a TFLite para mostrar los beneficios de compresión del encadenamiento de clústeres y las técnicas de optimización del modelo CQAT y se evaluó el modelo TFLite para garantizar que la precisión persista en el backend de TFLite. Finalmente, el modelo CQAT se comparó con un modelo agrupado cuantificado logrado utilizando la API de cuantificación posterior al entrenamiento para demostrar la ventaja de CQAT en la recuperación de la pérdida de precisión de la cuantificación normal.