Ayuda a proteger la Gran Barrera de Coral con TensorFlow en Kaggle Únete Challenge

Ejemplo de Keras del entrenamiento consciente de la cuantificación de la preservación del clúster (CQAT)

Ver en TensorFlow.org Ejecutar en Google Colab Ver en GitHub Descargar cuaderno

Descripción general

Este es un fin al ejemplo extremo que muestra el uso de la agrupación preservar la API de cuantificación de la formación consciente (CQAT), que forma parte de la tubería de colaboración optimización del modelo de optimización de TensorFlow Toolkit.

Otras paginas

Para una introducción a la tubería y otras técnicas disponibles, consulte la página de información general de colaboración optimización .

Contenido

En el tutorial, podrá:

  1. Entrenar a un tf.keras modelo para el conjunto de datos MNIST desde cero.
  2. Ajuste el modelo con agrupamiento y vea la precisión.
  3. Aplicar QAT y observar la pérdida de clústeres.
  4. Aplique CQAT y observe que la agrupación aplicada anteriormente se ha conservado.
  5. Genere un modelo TFLite y observe los efectos de aplicar CQAT en él.
  6. Compare la precisión del modelo CQAT lograda con un modelo cuantificado utilizando la cuantificación posterior al entrenamiento.

Configuración

Puede ejecutar este Notebook Jupyter en su local de virtualenv o colab . Para los detalles de la creación de dependencias, consulte la guía de instalación .

 pip install -q tensorflow-model-optimization
import tensorflow as tf

import numpy as np
import tempfile
import zipfile
import os

Entrene un modelo tf.keras para MNIST sin agrupamiento

# Load MNIST dataset
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input image so that each pixel value is between 0 to 1.
train_images = train_images / 255.0
test_images  = test_images / 255.0

model = tf.keras.Sequential([
  tf.keras.layers.InputLayer(input_shape=(28, 28)),
  tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
  tf.keras.layers.Conv2D(filters=12, kernel_size=(3, 3),
                         activation=tf.nn.relu),
  tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(10)
])

# Train the digit classification model
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(
    train_images,
    train_labels,
    validation_split=0.1,
    epochs=10
)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step
Epoch 1/10
1688/1688 [==============================] - 8s 5ms/step - loss: 0.2827 - accuracy: 0.9211 - val_loss: 0.1057 - val_accuracy: 0.9728
Epoch 2/10
1688/1688 [==============================] - 8s 5ms/step - loss: 0.1070 - accuracy: 0.9696 - val_loss: 0.0768 - val_accuracy: 0.9793
Epoch 3/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0805 - accuracy: 0.9765 - val_loss: 0.0663 - val_accuracy: 0.9823
Epoch 4/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0671 - accuracy: 0.9804 - val_loss: 0.0614 - val_accuracy: 0.9840
Epoch 5/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0592 - accuracy: 0.9827 - val_loss: 0.0594 - val_accuracy: 0.9837
Epoch 6/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0537 - accuracy: 0.9842 - val_loss: 0.0610 - val_accuracy: 0.9838
Epoch 7/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0487 - accuracy: 0.9851 - val_loss: 0.0644 - val_accuracy: 0.9827
Epoch 8/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0437 - accuracy: 0.9862 - val_loss: 0.0624 - val_accuracy: 0.9847
Epoch 9/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0399 - accuracy: 0.9882 - val_loss: 0.0595 - val_accuracy: 0.9847
Epoch 10/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0362 - accuracy: 0.9885 - val_loss: 0.0613 - val_accuracy: 0.9835
<tensorflow.python.keras.callbacks.History at 0x7f24a4cc7990>

Evalúe el modelo de línea de base y guárdelo para su uso posterior

_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)

_, keras_file = tempfile.mkstemp('.h5')
print('Saving model to: ', keras_file)
tf.keras.models.save_model(model, keras_file, include_optimizer=False)
Baseline test accuracy: 0.9814000129699707
Saving model to:  /tmp/tmppd3opqzk.h5

Agrupar y ajustar el modelo con 8 grupos

Aplicar la cluster_weights() API a agruparse todo el modelo pre-formados para demostrar y observar su eficacia en la reducción del tamaño del modelo cuando la aplicación de cremallera, manteniendo al mismo tiempo la precisión. Para la mejor manera de utilizar la API para obtener la mejor tasa de compresión, manteniendo la precisión de destino, consulte la guía completa de la agrupación .

Definir el modelo y aplicar la API de agrupación

El modelo debe entrenarse previamente antes de usar la API de agrupación en clústeres.

import tensorflow_model_optimization as tfmot

cluster_weights = tfmot.clustering.keras.cluster_weights
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization

clustering_params = {
  'number_of_clusters': 8,
  'cluster_centroids_init': CentroidInitialization.KMEANS_PLUS_PLUS
}

clustered_model = cluster_weights(model, **clustering_params)

# Use smaller learning rate for fine-tuning
opt = tf.keras.optimizers.Adam(learning_rate=1e-5)

clustered_model.compile(
  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
  optimizer=opt,
  metrics=['accuracy'])

clustered_model.summary()
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/array_ops.py:5049: calling gather (from tensorflow.python.ops.array_ops) with validate_indices is deprecated and will be removed in a future version.
Instructions for updating:
The `validate_indices` argument has no effect. Indices are always validated on CPU and never validated on GPU.
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
cluster_reshape (ClusterWeig (None, 28, 28, 1)         0         
_________________________________________________________________
cluster_conv2d (ClusterWeigh (None, 26, 26, 12)        236       
_________________________________________________________________
cluster_max_pooling2d (Clust (None, 13, 13, 12)        0         
_________________________________________________________________
cluster_flatten (ClusterWeig (None, 2028)              0         
_________________________________________________________________
cluster_dense (ClusterWeight (None, 10)                40578     
=================================================================
Total params: 40,814
Trainable params: 20,426
Non-trainable params: 20,388
_________________________________________________________________

Ajuste el modelo y evalúe la precisión con respecto a la línea de base

Ajuste el modelo con agrupamiento durante 3 épocas.

# Fine-tune model
clustered_model.fit(
  train_images,
  train_labels,
  epochs=3,
  validation_split=0.1)
Epoch 1/3
1688/1688 [==============================] - 8s 5ms/step - loss: 0.0363 - accuracy: 0.9887 - val_loss: 0.0610 - val_accuracy: 0.9830
Epoch 2/3
1688/1688 [==============================] - 8s 5ms/step - loss: 0.0333 - accuracy: 0.9898 - val_loss: 0.0599 - val_accuracy: 0.9830
Epoch 3/3
1688/1688 [==============================] - 8s 5ms/step - loss: 0.0319 - accuracy: 0.9905 - val_loss: 0.0580 - val_accuracy: 0.9837
<tensorflow.python.keras.callbacks.History at 0x7f240d271990>

Defina funciones auxiliares para calcular e imprimir el número de agrupaciones en cada núcleo del modelo.

def print_model_weight_clusters(model):

    for layer in model.layers:
        if isinstance(layer, tf.keras.layers.Wrapper):
            weights = layer.trainable_weights
        else:
            weights = layer.weights
        for weight in weights:
            # ignore auxiliary quantization weights
            if "quantize_layer" in weight.name:
                continue
            if "kernel" in weight.name:
                unique_count = len(np.unique(weight))
                print(
                    f"{layer.name}/{weight.name}: {unique_count} clusters "
                )

Compruebe que los núcleos del modelo estén agrupados correctamente. Primero tenemos que quitar la envoltura de agrupamiento.

stripped_clustered_model = tfmot.clustering.keras.strip_clustering(clustered_model)

print_model_weight_clusters(stripped_clustered_model)
conv2d/kernel:0: 8 clusters 
dense/kernel:0: 8 clusters

Para este ejemplo, hay una pérdida mínima en la precisión de la prueba después de la agrupación, en comparación con la línea de base.

_, clustered_model_accuracy = clustered_model.evaluate(
  test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)
print('Clustered test accuracy:', clustered_model_accuracy)
Baseline test accuracy: 0.9814000129699707
Clustered test accuracy: 0.9800999760627747

Aplicar QAT y CQAT y comprobar el efecto en los grupos de modelos en ambos casos

A continuación, aplicamos QAT y QAT de conservación de clústeres (CQAT) en el modelo agrupado y observamos que CQAT conserva los clústeres de ponderación en su modelo agrupado. Nótese que se desnudaron envoltorios de agrupamiento de su modelo con tfmot.clustering.keras.strip_clustering antes de aplicar API CQAT.

# QAT
qat_model = tfmot.quantization.keras.quantize_model(stripped_clustered_model)

qat_model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
print('Train qat model:')
qat_model.fit(train_images, train_labels, batch_size=128, epochs=1, validation_split=0.1)

# CQAT
quant_aware_annotate_model = tfmot.quantization.keras.quantize_annotate_model(
              stripped_clustered_model)
cqat_model = tfmot.quantization.keras.quantize_apply(
              quant_aware_annotate_model,
              tfmot.experimental.combine.Default8BitClusterPreserveQuantizeScheme())

cqat_model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
print('Train cqat model:')
cqat_model.fit(train_images, train_labels, batch_size=128, epochs=1, validation_split=0.1)
Train qat model:
422/422 [==============================] - 4s 8ms/step - loss: 0.0328 - accuracy: 0.9901 - val_loss: 0.0578 - val_accuracy: 0.9840
WARNING:root:Input layer does not contain zero weights, so apply CQAT instead.
WARNING:root:Input layer does not contain zero weights, so apply CQAT instead.
Train cqat model:
WARNING:tensorflow:Gradients do not exist for variables ['conv2d/kernel:0', 'dense/kernel:0'] when minimizing the loss.
WARNING:tensorflow:Gradients do not exist for variables ['conv2d/kernel:0', 'dense/kernel:0'] when minimizing the loss.
WARNING:tensorflow:Gradients do not exist for variables ['conv2d/kernel:0', 'dense/kernel:0'] when minimizing the loss.
WARNING:tensorflow:Gradients do not exist for variables ['conv2d/kernel:0', 'dense/kernel:0'] when minimizing the loss.
422/422 [==============================] - 4s 8ms/step - loss: 0.0310 - accuracy: 0.9908 - val_loss: 0.0592 - val_accuracy: 0.9838
<tensorflow.python.keras.callbacks.History at 0x7f240c60e6d0>
print("QAT Model clusters:")
print_model_weight_clusters(qat_model)
print("CQAT Model clusters:")
print_model_weight_clusters(cqat_model)
QAT Model clusters:
quant_conv2d/conv2d/kernel:0: 108 clusters 
quant_dense/dense/kernel:0: 19931 clusters 
CQAT Model clusters:
quant_conv2d/conv2d/kernel:0: 8 clusters 
quant_dense/dense/kernel:0: 8 clusters

Vea los beneficios de compresión del modelo CQAT

Defina la función auxiliar para obtener el archivo de modelo comprimido.

def get_gzipped_model_size(file):
  # It returns the size of the gzipped model in kilobytes.

  _, zipped_file = tempfile.mkstemp('.zip')
  with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
    f.write(file)

  return os.path.getsize(zipped_file)/1000

Tenga en cuenta que este es un modelo pequeño. La aplicación de agrupación en clústeres y CQAT a un modelo de producción más grande produciría una compresión más significativa.

# QAT model
converter = tf.lite.TFLiteConverter.from_keras_model(qat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
qat_tflite_model = converter.convert()
qat_model_file = 'qat_model.tflite'
# Save the model.
with open(qat_model_file, 'wb') as f:
    f.write(qat_tflite_model)

# CQAT model
converter = tf.lite.TFLiteConverter.from_keras_model(cqat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
cqat_tflite_model = converter.convert()
cqat_model_file = 'cqat_model.tflite'
# Save the model.
with open(cqat_model_file, 'wb') as f:
    f.write(cqat_tflite_model)

print("QAT model size: ", get_gzipped_model_size(qat_model_file), ' KB')
print("CQAT model size: ", get_gzipped_model_size(cqat_model_file), ' KB')
WARNING:absl:Found untraced functions such as reshape_layer_call_and_return_conditional_losses, reshape_layer_call_fn, conv2d_layer_call_and_return_conditional_losses, conv2d_layer_call_fn, flatten_layer_call_and_return_conditional_losses while saving (showing 5 of 20). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/tmpy_b1e0tx/assets
INFO:tensorflow:Assets written to: /tmp/tmpy_b1e0tx/assets
WARNING:absl:Found untraced functions such as reshape_layer_call_and_return_conditional_losses, reshape_layer_call_fn, conv2d_layer_call_and_return_conditional_losses, conv2d_layer_call_fn, flatten_layer_call_and_return_conditional_losses while saving (showing 5 of 20). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/tmp282pfq0n/assets
INFO:tensorflow:Assets written to: /tmp/tmp282pfq0n/assets
QAT model size:  16.685  KB
CQAT model size:  10.121  KB

Vea la persistencia de la precisión de TF a TFLite

Defina una función auxiliar para evaluar el modelo TFLite en el conjunto de datos de prueba.

def eval_model(interpreter):
  input_index = interpreter.get_input_details()[0]["index"]
  output_index = interpreter.get_output_details()[0]["index"]

  # Run predictions on every image in the "test" dataset.
  prediction_digits = []
  for i, test_image in enumerate(test_images):
    if i % 1000 == 0:
      print(f"Evaluated on {i} results so far.")
    # Pre-processing: add batch dimension and convert to float32 to match with
    # the model's input data format.
    test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
    interpreter.set_tensor(input_index, test_image)

    # Run inference.
    interpreter.invoke()

    # Post-processing: remove batch dimension and find the digit with highest
    # probability.
    output = interpreter.tensor(output_index)
    digit = np.argmax(output()[0])
    prediction_digits.append(digit)

  print('\n')
  # Compare prediction results with ground truth labels to calculate accuracy.
  prediction_digits = np.array(prediction_digits)
  accuracy = (prediction_digits == test_labels).mean()
  return accuracy

Evalúas el modelo, que se ha agrupado y cuantificado, y luego ves que la precisión de TensorFlow persiste en el backend de TFLite.

interpreter = tf.lite.Interpreter(cqat_model_file)
interpreter.allocate_tensors()

cqat_test_accuracy = eval_model(interpreter)

print('Clustered and quantized TFLite test_accuracy:', cqat_test_accuracy)
print('Clustered TF test accuracy:', clustered_model_accuracy)
Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.


Clustered and quantized TFLite test_accuracy: 0.9795
Clustered TF test accuracy: 0.9800999760627747

Aplicar la cuantificación posterior al entrenamiento y comparar con el modelo CQAT

A continuación, usamos la cuantificación posterior al entrenamiento (sin ajuste fino) en el modelo agrupado y verificamos su precisión con el modelo CQAT. Esto demuestra por qué necesitaría utilizar CQAT para mejorar la precisión del modelo cuantificado.

Primero, defina un generador para el conjunto de datos de calibración a partir de las primeras 1000 imágenes de entrenamiento.

def mnist_representative_data_gen():
  for image in train_images[:1000]:  
    image = np.expand_dims(image, axis=0).astype(np.float32)
    yield [image]

Cuantifique el modelo y compare la precisión con el modelo CQAT adquirido anteriormente. Tenga en cuenta que el modelo cuantificado con ajuste fino logra una mayor precisión.

converter = tf.lite.TFLiteConverter.from_keras_model(stripped_clustered_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = mnist_representative_data_gen
post_training_tflite_model = converter.convert()
post_training_model_file = 'post_training_model.tflite'
# Save the model.
with open(post_training_model_file, 'wb') as f:
    f.write(post_training_tflite_model)

# Compare accuracy
interpreter = tf.lite.Interpreter(post_training_model_file)
interpreter.allocate_tensors()

post_training_test_accuracy = eval_model(interpreter)

print('CQAT TFLite test_accuracy:', cqat_test_accuracy)
print('Post-training (no fine-tuning) TF test accuracy:', post_training_test_accuracy)
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
INFO:tensorflow:Assets written to: /tmp/tmp6vw7f4l2/assets
INFO:tensorflow:Assets written to: /tmp/tmp6vw7f4l2/assets
Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.


CQAT TFLite test_accuracy: 0.9795
Post-training (no fine-tuning) TF test accuracy: 0.9804

Conclusión

En este tutorial, aprendió a crear un modelo, IT Cluster usando los cluster_weights() de la API, y aplicar el cluster preservar la cuantificación de formación conscientes (CQAT) para preservar racimos durante el uso de QAT. El modelo CQAT final se comparó con el modelo QAT para mostrar que los grupos se conservan en el primero y se pierden en el segundo. A continuación, los modelos se convirtieron a TFLite para mostrar los beneficios de compresión del encadenamiento de clústeres y las técnicas de optimización del modelo CQAT y se evaluó el modelo TFLite para garantizar que la precisión persista en el backend de TFLite. Finalmente, el modelo CQAT se comparó con un modelo agrupado cuantificado logrado utilizando la API de cuantificación posterior al entrenamiento para demostrar la ventaja de CQAT en la recuperación de la pérdida de precisión de la cuantificación normal.