Cluster preserving quantization aware training (CQAT) Keras example

Overview

This is an end-to-end example showing the usage of the cluster preserving quantization aware training (CQAT) API, part of the TensorFlow Model Optimization Toolkit's collaborative optimization pipeline.

Other pages

For an introduction to the pipeline and other available techniques, see the collaborative optimization overview page.

Contents

In this tutorial, you will:

  1. Train a tf.keras model for the MNIST dataset from scratch.
  2. Fine-tune the model with clustering and see the accuracy.
  3. Apply QAT and observe the loss of clusters.
  4. Apply CQAT and observe that the clustering applied earlier has been preserved.
  5. Generate a TFLite model and observe the effects of applying CQAT on it.
  6. Compare the accuracy of the CQAT model with a model quantized using post-training quantization.

Setup

You can run this Jupyter Notebook in your local virtualenv or colab. For details of setting up dependencies, please refer to the installation guide.

 pip install -q tensorflow-model-optimization
import tensorflow as tf

import numpy as np
import tempfile
import zipfile
import os

Train a tf.keras model for MNIST without clustering

# Load MNIST dataset
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input image so that each pixel value is between 0 to 1.
train_images = train_images / 255.0
test_images  = test_images / 255.0

model = tf.keras.Sequential([
  tf.keras.layers.InputLayer(input_shape=(28, 28)),
  tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
  tf.keras.layers.Conv2D(filters=12, kernel_size=(3, 3),
                         activation=tf.nn.relu),
  tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(10)
])

# Train the digit classification model
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(
    train_images,
    train_labels,
    validation_split=0.1,
    epochs=10
)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step
Epoch 1/10
1688/1688 [==============================] - 8s 5ms/step - loss: 0.2827 - accuracy: 0.9211 - val_loss: 0.1057 - val_accuracy: 0.9728
Epoch 2/10
1688/1688 [==============================] - 8s 5ms/step - loss: 0.1070 - accuracy: 0.9696 - val_loss: 0.0768 - val_accuracy: 0.9793
Epoch 3/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0805 - accuracy: 0.9765 - val_loss: 0.0663 - val_accuracy: 0.9823
Epoch 4/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0671 - accuracy: 0.9804 - val_loss: 0.0614 - val_accuracy: 0.9840
Epoch 5/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0592 - accuracy: 0.9827 - val_loss: 0.0594 - val_accuracy: 0.9837
Epoch 6/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0537 - accuracy: 0.9842 - val_loss: 0.0610 - val_accuracy: 0.9838
Epoch 7/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0487 - accuracy: 0.9851 - val_loss: 0.0644 - val_accuracy: 0.9827
Epoch 8/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0437 - accuracy: 0.9862 - val_loss: 0.0624 - val_accuracy: 0.9847
Epoch 9/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0399 - accuracy: 0.9882 - val_loss: 0.0595 - val_accuracy: 0.9847
Epoch 10/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0362 - accuracy: 0.9885 - val_loss: 0.0613 - val_accuracy: 0.9835
<tensorflow.python.keras.callbacks.History at 0x7f24a4cc7990>

Evaluate the baseline model and save it for later use

_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)

_, keras_file = tempfile.mkstemp('.h5')
print('Saving model to: ', keras_file)
tf.keras.models.save_model(model, keras_file, include_optimizer=False)
Baseline test accuracy: 0.9814000129699707
Saving model to:  /tmp/tmppd3opqzk.h5

Cluster and fine-tune the model with 8 clusters

Apply the cluster_weights() API to cluster the whole pre-trained model, and observe how effectively it reduces the model size once the file is zipped while maintaining accuracy. For how best to use the API to achieve the best compression rate while maintaining your target accuracy, refer to the clustering comprehensive guide.

Define the model and apply the clustering API

The model needs to be pre-trained before using the clustering API.
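
As a rough intuition for why clustering helps zip compression, the sketch below snaps a made-up list of weights to 8 shared values and compares the DEFLATE-compressed sizes. It is a toy illustration only: the random weights, the helper names `cluster_to_centroids` and `compressed_size`, and the evenly spaced centroids are assumptions of this sketch, not what cluster_weights() does internally (the tutorial uses KMEANS_PLUS_PLUS initialization followed by fine-tuning).

```python
import random
import struct
import zlib

random.seed(0)

# Hypothetical stand-in for a layer kernel: 5000 random float weights.
weights = [random.gauss(0.0, 0.1) for _ in range(5000)]


def cluster_to_centroids(values, number_of_clusters):
    """Snap each value to the nearest of evenly spaced centroids (no fine-tuning)."""
    lo, hi = min(values), max(values)
    step = (hi - lo) / (number_of_clusters - 1)
    centroids = [lo + i * step for i in range(number_of_clusters)]
    return [min(centroids, key=lambda c: abs(c - v)) for v in values]


def compressed_size(values):
    """Size in bytes of the float32-packed values after DEFLATE compression."""
    raw = struct.pack(f"{len(values)}f", *values)
    return len(zlib.compress(raw))


clustered = cluster_to_centroids(weights, number_of_clusters=8)

print("unique values before:", len(set(weights)))
print("unique values after: ", len(set(clustered)))
print("compressed bytes before:", compressed_size(weights))
print("compressed bytes after: ", compressed_size(clustered))
```

With only 8 distinct values, the packed bytes repeat heavily, which is what a general-purpose compressor exploits.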

import tensorflow_model_optimization as tfmot

cluster_weights = tfmot.clustering.keras.cluster_weights
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization

clustering_params = {
  'number_of_clusters': 8,
  'cluster_centroids_init': CentroidInitialization.KMEANS_PLUS_PLUS
}

clustered_model = cluster_weights(model, **clustering_params)

# Use smaller learning rate for fine-tuning
opt = tf.keras.optimizers.Adam(learning_rate=1e-5)

clustered_model.compile(
  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
  optimizer=opt,
  metrics=['accuracy'])

clustered_model.summary()
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/array_ops.py:5049: calling gather (from tensorflow.python.ops.array_ops) with validate_indices is deprecated and will be removed in a future version.
Instructions for updating:
The `validate_indices` argument has no effect. Indices are always validated on CPU and never validated on GPU.
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
cluster_reshape (ClusterWeig (None, 28, 28, 1)         0         
_________________________________________________________________
cluster_conv2d (ClusterWeigh (None, 26, 26, 12)        236       
_________________________________________________________________
cluster_max_pooling2d (Clust (None, 13, 13, 12)        0         
_________________________________________________________________
cluster_flatten (ClusterWeig (None, 2028)              0         
_________________________________________________________________
cluster_dense (ClusterWeight (None, 10)                40578     
=================================================================
Total params: 40,814
Trainable params: 20,426
Non-trainable params: 20,388
_________________________________________________________________

Fine-tune the model and evaluate the accuracy against baseline

Fine-tune the model with clustering for 3 epochs.

# Fine-tune model
clustered_model.fit(
  train_images,
  train_labels,
  epochs=3,
  validation_split=0.1)
Epoch 1/3
1688/1688 [==============================] - 8s 5ms/step - loss: 0.0363 - accuracy: 0.9887 - val_loss: 0.0610 - val_accuracy: 0.9830
Epoch 2/3
1688/1688 [==============================] - 8s 5ms/step - loss: 0.0333 - accuracy: 0.9898 - val_loss: 0.0599 - val_accuracy: 0.9830
Epoch 3/3
1688/1688 [==============================] - 8s 5ms/step - loss: 0.0319 - accuracy: 0.9905 - val_loss: 0.0580 - val_accuracy: 0.9837
<tensorflow.python.keras.callbacks.History at 0x7f240d271990>

Define a helper function to calculate and print the number of clusters (unique values) in each kernel of the model.

def print_model_weight_clusters(model):

    for layer in model.layers:
        if isinstance(layer, tf.keras.layers.Wrapper):
            weights = layer.trainable_weights
        else:
            weights = layer.weights
        for weight in weights:
            # ignore auxiliary quantization weights
            if "quantize_layer" in weight.name:
                continue
            if "kernel" in weight.name:
                unique_count = len(np.unique(weight))
                print(
                    f"{layer.name}/{weight.name}: {unique_count} clusters "
                )

Check that the model kernels were correctly clustered. We need to strip the clustering wrapper first.

stripped_clustered_model = tfmot.clustering.keras.strip_clustering(clustered_model)

print_model_weight_clusters(stripped_clustered_model)
conv2d/kernel:0: 8 clusters 
dense/kernel:0: 8 clusters

For this example, there is minimal loss in test accuracy after clustering, compared to the baseline.

_, clustered_model_accuracy = clustered_model.evaluate(
  test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)
print('Clustered test accuracy:', clustered_model_accuracy)
Baseline test accuracy: 0.9814000129699707
Clustered test accuracy: 0.9800999760627747

Apply QAT and CQAT and check the effect on model clusters in both cases

Next, we apply both QAT and cluster preserving QAT (CQAT) to the clustered model and observe that CQAT preserves the weight clusters in the clustered model. Note that we stripped the clustering wrappers from the model with tfmot.clustering.keras.strip_clustering before applying the CQAT API.

# QAT
qat_model = tfmot.quantization.keras.quantize_model(stripped_clustered_model)

qat_model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
print('Train qat model:')
qat_model.fit(train_images, train_labels, batch_size=128, epochs=1, validation_split=0.1)

# CQAT
quant_aware_annotate_model = tfmot.quantization.keras.quantize_annotate_model(
              stripped_clustered_model)
cqat_model = tfmot.quantization.keras.quantize_apply(
              quant_aware_annotate_model,
              tfmot.experimental.combine.Default8BitClusterPreserveQuantizeScheme())

cqat_model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
print('Train cqat model:')
cqat_model.fit(train_images, train_labels, batch_size=128, epochs=1, validation_split=0.1)
Train qat model:
422/422 [==============================] - 4s 8ms/step - loss: 0.0328 - accuracy: 0.9901 - val_loss: 0.0578 - val_accuracy: 0.9840
WARNING:root:Input layer does not contain zero weights, so apply CQAT instead.
WARNING:root:Input layer does not contain zero weights, so apply CQAT instead.
Train cqat model:
WARNING:tensorflow:Gradients do not exist for variables ['conv2d/kernel:0', 'dense/kernel:0'] when minimizing the loss.
WARNING:tensorflow:Gradients do not exist for variables ['conv2d/kernel:0', 'dense/kernel:0'] when minimizing the loss.
WARNING:tensorflow:Gradients do not exist for variables ['conv2d/kernel:0', 'dense/kernel:0'] when minimizing the loss.
WARNING:tensorflow:Gradients do not exist for variables ['conv2d/kernel:0', 'dense/kernel:0'] when minimizing the loss.
422/422 [==============================] - 4s 8ms/step - loss: 0.0310 - accuracy: 0.9908 - val_loss: 0.0592 - val_accuracy: 0.9838
<tensorflow.python.keras.callbacks.History at 0x7f240c60e6d0>
print("QAT Model clusters:")
print_model_weight_clusters(qat_model)
print("CQAT Model clusters:")
print_model_weight_clusters(cqat_model)
QAT Model clusters:
quant_conv2d/conv2d/kernel:0: 108 clusters 
quant_dense/dense/kernel:0: 19931 clusters 
CQAT Model clusters:
quant_conv2d/conv2d/kernel:0: 8 clusters 
quant_dense/dense/kernel:0: 8 clusters
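
The cluster counts above can be understood with a toy sketch. Under plain fine-tuning every weight is updated independently, so weights that shared a value drift apart; if instead only the shared centroids are updated and each weight keeps its cluster index, the number of unique values cannot grow. The centroids, indices, gradients, and learning rate below are made up for illustration, and the code is not the actual implementation of Default8BitClusterPreserveQuantizeScheme.

```python
import random

random.seed(0)

# Hypothetical clustered kernel: 8 centroids plus a fixed cluster index per weight.
centroids = [-0.35, -0.25, -0.15, -0.05, 0.05, 0.15, 0.25, 0.35]
indices = [random.randrange(len(centroids)) for _ in range(1000)]
weights = [centroids[i] for i in indices]

# Made-up per-weight gradients standing in for one training step.
grads = [random.gauss(0.0, 1.0) for _ in weights]
lr = 0.01

# Plain fine-tuning: every weight moves on its own, so shared values drift apart.
independent = [w - lr * g for w, g in zip(weights, grads)]

# Cluster-preserving idea: accumulate gradients per cluster, update centroids only.
centroid_grads = [0.0] * len(centroids)
for i, g in zip(indices, grads):
    centroid_grads[i] += g
new_centroids = [c - lr * g for c, g in zip(centroids, centroid_grads)]
preserved = [new_centroids[i] for i in indices]

print("unique values, independent update:", len(set(independent)))
print("unique values, centroid update:   ", len(set(preserved)))
```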

See compression benefits of the CQAT model

Define a helper function to get the size of the zipped model file.

def get_gzipped_model_size(file):
  # It returns the size of the gzipped model in kilobytes.

  _, zipped_file = tempfile.mkstemp('.zip')
  with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
    f.write(file)

  return os.path.getsize(zipped_file)/1000

Note that this is a small model. Applying clustering and CQAT to a bigger production model would yield more significant compression.

# QAT model
converter = tf.lite.TFLiteConverter.from_keras_model(qat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
qat_tflite_model = converter.convert()
qat_model_file = 'qat_model.tflite'
# Save the model.
with open(qat_model_file, 'wb') as f:
    f.write(qat_tflite_model)

# CQAT model
converter = tf.lite.TFLiteConverter.from_keras_model(cqat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
cqat_tflite_model = converter.convert()
cqat_model_file = 'cqat_model.tflite'
# Save the model.
with open(cqat_model_file, 'wb') as f:
    f.write(cqat_tflite_model)

print("QAT model size: ", get_gzipped_model_size(qat_model_file), ' KB')
print("CQAT model size: ", get_gzipped_model_size(cqat_model_file), ' KB')
WARNING:absl:Found untraced functions such as reshape_layer_call_and_return_conditional_losses, reshape_layer_call_fn, conv2d_layer_call_and_return_conditional_losses, conv2d_layer_call_fn, flatten_layer_call_and_return_conditional_losses while saving (showing 5 of 20). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/tmpy_b1e0tx/assets
INFO:tensorflow:Assets written to: /tmp/tmpy_b1e0tx/assets
WARNING:absl:Found untraced functions such as reshape_layer_call_and_return_conditional_losses, reshape_layer_call_fn, conv2d_layer_call_and_return_conditional_losses, conv2d_layer_call_fn, flatten_layer_call_and_return_conditional_losses while saving (showing 5 of 20). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/tmp282pfq0n/assets
INFO:tensorflow:Assets written to: /tmp/tmp282pfq0n/assets
QAT model size:  16.685  KB
CQAT model size:  10.121  KB
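
One way to read the size difference above: when a kernel holds only 8 distinct float values, its 8-bit quantized form also holds at most 8 distinct integer codes, which compresses well. The sketch below shows this with a simplified symmetric int8 quantizer that uses a single scale for the whole tensor. That single scale, the helper name `quantize_int8_symmetric`, and the sample data are assumptions of this sketch; the TFLite converter's actual scheme (for example, per-channel scales) may differ.

```python
import random

random.seed(0)


def quantize_int8_symmetric(values):
    """Map floats to int8 codes using one symmetric scale for the whole tensor."""
    scale = max(abs(v) for v in values) / 127.0
    return [max(-127, min(127, round(v / scale))) for v in values]


# Hypothetical kernels: one with 8 shared values, one with unconstrained values.
centroids = [-0.35, -0.25, -0.15, -0.05, 0.05, 0.15, 0.25, 0.35]
clustered = [random.choice(centroids) for _ in range(2000)]
unclustered = [random.uniform(-0.35, 0.35) for _ in range(2000)]

q_clustered = quantize_int8_symmetric(clustered)
q_unclustered = quantize_int8_symmetric(unclustered)

print("distinct int8 codes, clustered:  ", len(set(q_clustered)))
print("distinct int8 codes, unclustered:", len(set(q_unclustered)))
```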

See the persistence of accuracy from TF to TFLite

Define a helper function to evaluate the TFLite model on the test dataset.

def eval_model(interpreter):
  input_index = interpreter.get_input_details()[0]["index"]
  output_index = interpreter.get_output_details()[0]["index"]

  # Run predictions on every image in the "test" dataset.
  prediction_digits = []
  for i, test_image in enumerate(test_images):
    if i % 1000 == 0:
      print(f"Evaluated on {i} results so far.")
    # Pre-processing: add batch dimension and convert to float32 to match with
    # the model's input data format.
    test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
    interpreter.set_tensor(input_index, test_image)

    # Run inference.
    interpreter.invoke()

    # Post-processing: remove batch dimension and find the digit with highest
    # probability.
    output = interpreter.tensor(output_index)
    digit = np.argmax(output()[0])
    prediction_digits.append(digit)

  print('\n')
  # Compare prediction results with ground truth labels to calculate accuracy.
  prediction_digits = np.array(prediction_digits)
  accuracy = (prediction_digits == test_labels).mean()
  return accuracy

Evaluate the model, which has been clustered and quantized, and check that the accuracy obtained in TensorFlow persists in the TFLite backend.

interpreter = tf.lite.Interpreter(cqat_model_file)
interpreter.allocate_tensors()

cqat_test_accuracy = eval_model(interpreter)

print('Clustered and quantized TFLite test_accuracy:', cqat_test_accuracy)
print('Clustered TF test accuracy:', clustered_model_accuracy)
Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.


Clustered and quantized TFLite test_accuracy: 0.9795
Clustered TF test accuracy: 0.9800999760627747

Apply post-training quantization and compare to the CQAT model

Next, we use post-training quantization (no fine-tuning) on the clustered model and check its accuracy against the CQAT model. This demonstrates why you would need to use CQAT to improve the quantized model's accuracy.

First, define a generator for the calibration dataset from the first 1000 training images.

def mnist_representative_data_gen():
  for image in train_images[:1000]:  
    image = np.expand_dims(image, axis=0).astype(np.float32)
    yield [image]
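
A calibration dataset like the one above lets the converter observe the range of values flowing through the model. As a simplified sketch of that idea, the code below derives an asymmetric int8 scale and zero point from the minimum and maximum of a few sample values. The helper names `calibrate_int8` and `quantize` and the tiny sample batch are hypothetical, and the converter's real calibration procedure is more involved than a plain min/max.

```python
def calibrate_int8(samples):
    """Derive an asymmetric int8 scale and zero point from observed values."""
    # The representable range must contain 0.0 so that zero maps to an exact code.
    rmin = min(0.0, min(min(s) for s in samples))
    rmax = max(0.0, max(max(s) for s in samples))
    scale = (rmax - rmin) / 255.0
    zero_point = round(-128 - rmin / scale)
    return scale, zero_point


def quantize(value, scale, zero_point):
    """Quantize one float to an int8 code, clamping to the valid range."""
    return max(-128, min(127, round(value / scale) + zero_point))


# Hypothetical calibration batch: flattened images with pixels in [0, 1].
calibration_samples = [[0.0, 0.2, 1.0], [0.1, 0.5, 0.9]]
scale, zero_point = calibrate_int8(calibration_samples)

print("scale:", scale, "zero point:", zero_point)
print("code for 0.0:", quantize(0.0, scale, zero_point))
print("code for 1.0:", quantize(1.0, scale, zero_point))
```

Values outside the calibrated range are clamped, which is why the calibration samples should be representative of real inputs.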

Quantize the model and compare its accuracy with the previously obtained CQAT model. The model quantized with fine-tuning is expected to reach higher accuracy, although on a model this small the two results can land within a fraction of a percentage point of each other.

converter = tf.lite.TFLiteConverter.from_keras_model(stripped_clustered_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = mnist_representative_data_gen
post_training_tflite_model = converter.convert()
post_training_model_file = 'post_training_model.tflite'
# Save the model.
with open(post_training_model_file, 'wb') as f:
    f.write(post_training_tflite_model)

# Compare accuracy
interpreter = tf.lite.Interpreter(post_training_model_file)
interpreter.allocate_tensors()

post_training_test_accuracy = eval_model(interpreter)

print('CQAT TFLite test_accuracy:', cqat_test_accuracy)
print('Post-training (no fine-tuning) TF test accuracy:', post_training_test_accuracy)
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
INFO:tensorflow:Assets written to: /tmp/tmp6vw7f4l2/assets
INFO:tensorflow:Assets written to: /tmp/tmp6vw7f4l2/assets
Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.


CQAT TFLite test_accuracy: 0.9795
Post-training (no fine-tuning) TF test accuracy: 0.9804

Conclusion

In this tutorial, you learned how to create a model, cluster it using the cluster_weights() API, and apply cluster preserving quantization aware training (CQAT) to preserve clusters while using QAT. The final CQAT model was compared to the QAT one to show that the clusters are preserved in the former and lost in the latter. Next, the models were converted to TFLite to show the compression benefits of chaining the clustering and CQAT model optimization techniques, and the TFLite model was evaluated to ensure that the accuracy persists in the TFLite backend. Finally, the CQAT model was compared to a quantized clustered model obtained with the post-training quantization API, to show the advantage of CQAT in recovering the accuracy loss from normal quantization.