روز جامعه ML 9 نوامبر است! برای به روز رسانی از TensorFlow، JAX به ما بپیوندید، و بیشتر بیشتر بدانید

خوشه حفظ آموزش آگاهی کمی (CQAT) مثال Keras

مشاهده در TensorFlow.org در Google Colab اجرا کنید مشاهده در GitHub دانلود دفترچه یادداشت

بررسی اجمالی

این پایان دادن به عنوان مثال پایان نشان دادن استفاده از خوشه حفظ تدریج آموزش آگاه (CQAT) API، بخشی از خط لوله بهینه سازی مشارکتی TensorFlow مدل بهینه سازی جعبه ابزار است.

صفحات دیگر

برای معرفی به خط لوله و دیگر تکنیک های موجود، مشاهده صفحه مرور کلی بهینه سازی مشترک .

فهرست

در این آموزش ، شما:

  1. آموزش tf.keras مدل برای مجموعه داده MNIST از ابتدا.
  2. مدل را با خوشه بندی دقیق تنظیم کنید و دقت آن را ببینید.
  3. QAT را اعمال کنید و از بین رفتن خوشه ها را مشاهده کنید.
  4. CQAT را اعمال کنید و مشاهده کنید که خوشه بندی اعمال شده در اوایل حفظ شده است.
  5. یک مدل TFLite ایجاد کنید و تأثیرات اعمال CQAT بر روی آن را مشاهده کنید.
  6. دقت مدل CQAT بدست آمده را با مدلی که با استفاده از کمی سازی بعد از آموزش مقداردهی شده مقایسه کنید.

برپایی

شما می توانید این Jupyter نوت بوک در محلی خود را اجرا از virtualenv یا COLAB . برای جزئیات راه اندازی وابستگی، لطفا به مراجعه راهنمای نصب و راه اندازی .

 pip install -q tensorflow-model-optimization
import tensorflow as tf

import numpy as np
import tempfile
import zipfile
import os

بدون خوشه بندی یک مدل tf.keras را برای MNIST آموزش دهید

# Load MNIST dataset
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input image so that each pixel value is between 0 to 1.
train_images = train_images / 255.0
test_images  = test_images / 255.0

model = tf.keras.Sequential([
  tf.keras.layers.InputLayer(input_shape=(28, 28)),
  tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
  tf.keras.layers.Conv2D(filters=12, kernel_size=(3, 3),
                         activation=tf.nn.relu),
  tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(10)
])

# Train the digit classification model
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(
    train_images,
    train_labels,
    validation_split=0.1,
    epochs=10
)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step
Epoch 1/10
1688/1688 [==============================] - 8s 5ms/step - loss: 0.2827 - accuracy: 0.9211 - val_loss: 0.1057 - val_accuracy: 0.9728
Epoch 2/10
1688/1688 [==============================] - 8s 5ms/step - loss: 0.1070 - accuracy: 0.9696 - val_loss: 0.0768 - val_accuracy: 0.9793
Epoch 3/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0805 - accuracy: 0.9765 - val_loss: 0.0663 - val_accuracy: 0.9823
Epoch 4/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0671 - accuracy: 0.9804 - val_loss: 0.0614 - val_accuracy: 0.9840
Epoch 5/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0592 - accuracy: 0.9827 - val_loss: 0.0594 - val_accuracy: 0.9837
Epoch 6/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0537 - accuracy: 0.9842 - val_loss: 0.0610 - val_accuracy: 0.9838
Epoch 7/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0487 - accuracy: 0.9851 - val_loss: 0.0644 - val_accuracy: 0.9827
Epoch 8/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0437 - accuracy: 0.9862 - val_loss: 0.0624 - val_accuracy: 0.9847
Epoch 9/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0399 - accuracy: 0.9882 - val_loss: 0.0595 - val_accuracy: 0.9847
Epoch 10/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0362 - accuracy: 0.9885 - val_loss: 0.0613 - val_accuracy: 0.9835
<tensorflow.python.keras.callbacks.History at 0x7f24a4cc7990>

مدل پایه را ارزیابی کنید و آن را برای استفاده بعدی ذخیره کنید

_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)

_, keras_file = tempfile.mkstemp('.h5')
print('Saving model to: ', keras_file)
tf.keras.models.save_model(model, keras_file, include_optimizer=False)
Baseline test accuracy: 0.9814000129699707
Saving model to:  /tmp/tmppd3opqzk.h5

مدل را با 8 خوشه خوشه بندی و تنظیم کنید

درخواست cluster_weights() API به خوشه تمام مدل های از پیش آموزش دیده برای نشان دادن و مشاهده اثر آن در کاهش اندازه مدل در هنگام استفاده از فایل های فشرده، در حالی که حفظ دقت و صحت. برای اینکه چگونه به بهترین استفاده از API برای رسیدن به بهترین میزان فشرده سازی در حالی که حفظ دقت و صحت هدف خود را، به مراجعه راهنمای جامع خوشه .

مدل را تعریف کرده و API خوشه بندی را اعمال کنید

قبل از استفاده از خوشه بندی API ، این مدل باید قبل از آموزش دیده باشد.

import tensorflow_model_optimization as tfmot

cluster_weights = tfmot.clustering.keras.cluster_weights
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization

clustering_params = {
  'number_of_clusters': 8,
  'cluster_centroids_init': CentroidInitialization.KMEANS_PLUS_PLUS
}

clustered_model = cluster_weights(model, **clustering_params)

# Use smaller learning rate for fine-tuning
opt = tf.keras.optimizers.Adam(learning_rate=1e-5)

clustered_model.compile(
  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
  optimizer=opt,
  metrics=['accuracy'])

clustered_model.summary()
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/array_ops.py:5049: calling gather (from tensorflow.python.ops.array_ops) with validate_indices is deprecated and will be removed in a future version.
Instructions for updating:
The `validate_indices` argument has no effect. Indices are always validated on CPU and never validated on GPU.
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
cluster_reshape (ClusterWeig (None, 28, 28, 1)         0         
_________________________________________________________________
cluster_conv2d (ClusterWeigh (None, 26, 26, 12)        236       
_________________________________________________________________
cluster_max_pooling2d (Clust (None, 13, 13, 12)        0         
_________________________________________________________________
cluster_flatten (ClusterWeig (None, 2028)              0         
_________________________________________________________________
cluster_dense (ClusterWeight (None, 10)                40578     
=================================================================
Total params: 40,814
Trainable params: 20,426
Non-trainable params: 20,388
_________________________________________________________________

مدل را دقیق تنظیم کرده و دقت را در برابر خط پایه ارزیابی کنید

مدل را با خوشه بندی برای 3 دوره تنظیم کنید.

# Fine-tune model
clustered_model.fit(
  train_images,
  train_labels,
  epochs=3,
  validation_split=0.1)
Epoch 1/3
1688/1688 [==============================] - 8s 5ms/step - loss: 0.0363 - accuracy: 0.9887 - val_loss: 0.0610 - val_accuracy: 0.9830
Epoch 2/3
1688/1688 [==============================] - 8s 5ms/step - loss: 0.0333 - accuracy: 0.9898 - val_loss: 0.0599 - val_accuracy: 0.9830
Epoch 3/3
1688/1688 [==============================] - 8s 5ms/step - loss: 0.0319 - accuracy: 0.9905 - val_loss: 0.0580 - val_accuracy: 0.9837
<tensorflow.python.keras.callbacks.History at 0x7f240d271990>

توابع کمکی را برای محاسبه و چاپ تعداد خوشه بندی در هر هسته از مدل تعریف کنید.

def print_model_weight_clusters(model):

    for layer in model.layers:
        if isinstance(layer, tf.keras.layers.Wrapper):
            weights = layer.trainable_weights
        else:
            weights = layer.weights
        for weight in weights:
            # ignore auxiliary quantization weights
            if "quantize_layer" in weight.name:
                continue
            if "kernel" in weight.name:
                unique_count = len(np.unique(weight))
                print(
                    f"{layer.name}/{weight.name}: {unique_count} clusters "
                )

بررسی کنید که هسته های مدل به درستی خوشه بندی شده اند. ابتدا باید لفاف خوشه بندی را نوار بزنیم.

stripped_clustered_model = tfmot.clustering.keras.strip_clustering(clustered_model)

print_model_weight_clusters(stripped_clustered_model)
conv2d/kernel:0: 8 clusters 
dense/kernel:0: 8 clusters

برای این مثال ، در مقایسه با خط مبنا ، پس از خوشه بندی ، کمترین میزان دقت در آزمون وجود دارد.

_, clustered_model_accuracy = clustered_model.evaluate(
  test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)
print('Clustered test accuracy:', clustered_model_accuracy)
Baseline test accuracy: 0.9814000129699707
Clustered test accuracy: 0.9800999760627747

در هر دو مورد QAT و CQAT را اعمال کنید و تأثیر آن را بر خوشه های مدل بررسی کنید

بعد ، ما هم QAT و هم خوشه QAT (CQAT) را روی مدل خوشه ای اعمال می کنیم و مشاهده می کنیم که CQAT خوشه های وزنی را در مدل خوشه ای شما حفظ می کند. توجه داشته باشید که ما دستگاه فوم پیچ دور خوشه از مدل خود را با تکمیل نشده tfmot.clustering.keras.strip_clustering قبل از اعمال CQAT API.

# QAT
qat_model = tfmot.quantization.keras.quantize_model(stripped_clustered_model)

qat_model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
print('Train qat model:')
qat_model.fit(train_images, train_labels, batch_size=128, epochs=1, validation_split=0.1)

# CQAT
quant_aware_annotate_model = tfmot.quantization.keras.quantize_annotate_model(
              stripped_clustered_model)
cqat_model = tfmot.quantization.keras.quantize_apply(
              quant_aware_annotate_model,
              tfmot.experimental.combine.Default8BitClusterPreserveQuantizeScheme())

cqat_model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
print('Train cqat model:')
cqat_model.fit(train_images, train_labels, batch_size=128, epochs=1, validation_split=0.1)
Train qat model:
422/422 [==============================] - 4s 8ms/step - loss: 0.0328 - accuracy: 0.9901 - val_loss: 0.0578 - val_accuracy: 0.9840
WARNING:root:Input layer does not contain zero weights, so apply CQAT instead.
WARNING:root:Input layer does not contain zero weights, so apply CQAT instead.
Train cqat model:
WARNING:tensorflow:Gradients do not exist for variables ['conv2d/kernel:0', 'dense/kernel:0'] when minimizing the loss.
WARNING:tensorflow:Gradients do not exist for variables ['conv2d/kernel:0', 'dense/kernel:0'] when minimizing the loss.
WARNING:tensorflow:Gradients do not exist for variables ['conv2d/kernel:0', 'dense/kernel:0'] when minimizing the loss.
WARNING:tensorflow:Gradients do not exist for variables ['conv2d/kernel:0', 'dense/kernel:0'] when minimizing the loss.
422/422 [==============================] - 4s 8ms/step - loss: 0.0310 - accuracy: 0.9908 - val_loss: 0.0592 - val_accuracy: 0.9838
<tensorflow.python.keras.callbacks.History at 0x7f240c60e6d0>
print("QAT Model clusters:")
print_model_weight_clusters(qat_model)
print("CQAT Model clusters:")
print_model_weight_clusters(cqat_model)
QAT Model clusters:
quant_conv2d/conv2d/kernel:0: 108 clusters 
quant_dense/dense/kernel:0: 19931 clusters 
CQAT Model clusters:
quant_conv2d/conv2d/kernel:0: 8 clusters 
quant_dense/dense/kernel:0: 8 clusters

مزایای فشرده سازی مدل CQAT را ببینید

برای دریافت فایل مدل زیپ شده ، تابع کمکی را تعریف کنید.

def get_gzipped_model_size(file):
  # It returns the size of the gzipped model in kilobytes.

  _, zipped_file = tempfile.mkstemp('.zip')
  with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
    f.write(file)

  return os.path.getsize(zipped_file)/1000

توجه داشته باشید که این یک مدل کوچک است. استفاده از خوشه بندی و CQAT به یک مدل تولید بزرگتر ، فشرده سازی قابل توجه تری را به همراه خواهد داشت.

# QAT model
converter = tf.lite.TFLiteConverter.from_keras_model(qat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
qat_tflite_model = converter.convert()
qat_model_file = 'qat_model.tflite'
# Save the model.
with open(qat_model_file, 'wb') as f:
    f.write(qat_tflite_model)

# CQAT model
converter = tf.lite.TFLiteConverter.from_keras_model(cqat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
cqat_tflite_model = converter.convert()
cqat_model_file = 'cqat_model.tflite'
# Save the model.
with open(cqat_model_file, 'wb') as f:
    f.write(cqat_tflite_model)

print("QAT model size: ", get_gzipped_model_size(qat_model_file), ' KB')
print("CQAT model size: ", get_gzipped_model_size(cqat_model_file), ' KB')
WARNING:absl:Found untraced functions such as reshape_layer_call_and_return_conditional_losses, reshape_layer_call_fn, conv2d_layer_call_and_return_conditional_losses, conv2d_layer_call_fn, flatten_layer_call_and_return_conditional_losses while saving (showing 5 of 20). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/tmpy_b1e0tx/assets
INFO:tensorflow:Assets written to: /tmp/tmpy_b1e0tx/assets
WARNING:absl:Found untraced functions such as reshape_layer_call_and_return_conditional_losses, reshape_layer_call_fn, conv2d_layer_call_and_return_conditional_losses, conv2d_layer_call_fn, flatten_layer_call_and_return_conditional_losses while saving (showing 5 of 20). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/tmp282pfq0n/assets
INFO:tensorflow:Assets written to: /tmp/tmp282pfq0n/assets
QAT model size:  16.685  KB
CQAT model size:  10.121  KB

ماندگاری دقت را از TF به TFLite مشاهده کنید

برای ارزیابی مدل TFLite در مجموعه داده های آزمون ، یک تابع کمکی تعریف کنید.

def eval_model(interpreter):
  input_index = interpreter.get_input_details()[0]["index"]
  output_index = interpreter.get_output_details()[0]["index"]

  # Run predictions on every image in the "test" dataset.
  prediction_digits = []
  for i, test_image in enumerate(test_images):
    if i % 1000 == 0:
      print(f"Evaluated on {i} results so far.")
    # Pre-processing: add batch dimension and convert to float32 to match with
    # the model's input data format.
    test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
    interpreter.set_tensor(input_index, test_image)

    # Run inference.
    interpreter.invoke()

    # Post-processing: remove batch dimension and find the digit with highest
    # probability.
    output = interpreter.tensor(output_index)
    digit = np.argmax(output()[0])
    prediction_digits.append(digit)

  print('\n')
  # Compare prediction results with ground truth labels to calculate accuracy.
  prediction_digits = np.array(prediction_digits)
  accuracy = (prediction_digits == test_labels).mean()
  return accuracy

شما مدلی را که به صورت خوشه ای و کوانتیزه شده ارزیابی می کنید و سپس دقت TensorFlow را در قسمت TFLite مشاهده می کنید.

interpreter = tf.lite.Interpreter(cqat_model_file)
interpreter.allocate_tensors()

cqat_test_accuracy = eval_model(interpreter)

print('Clustered and quantized TFLite test_accuracy:', cqat_test_accuracy)
print('Clustered TF test accuracy:', clustered_model_accuracy)
Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.


Clustered and quantized TFLite test_accuracy: 0.9795
Clustered TF test accuracy: 0.9800999760627747

کمی سازی بعد از آموزش را اعمال کنید و با مدل CQAT مقایسه کنید

در مرحله بعدی ، ما از کمیت سازی بعد از آموزش (بدون تنظیم دقیق) روی مدل خوشه ای استفاده می کنیم و صحت آن را در برابر مدل CQAT بررسی می کنیم. این نشان می دهد که چرا شما برای بهبود دقت مدل کمی نیاز به استفاده از CQAT دارید.

ابتدا از 1000 تصویر آموزشی اول یک ژنراتور برای مجموعه داده های کالیبراسیون تعریف کنید.

def mnist_representative_data_gen():
  for image in train_images[:1000]:  
    image = np.expand_dims(image, axis=0).astype(np.float32)
    yield [image]

مدل را اندازه گیری کرده و دقت را با مدل CQAT که قبلاً بدست آورده اید مقایسه کنید. توجه داشته باشید که مدل کوانتیزه شده با تنظیم دقیق دقت بالاتری را به دست می آورد.

converter = tf.lite.TFLiteConverter.from_keras_model(stripped_clustered_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = mnist_representative_data_gen
post_training_tflite_model = converter.convert()
post_training_model_file = 'post_training_model.tflite'
# Save the model.
with open(post_training_model_file, 'wb') as f:
    f.write(post_training_tflite_model)

# Compare accuracy
interpreter = tf.lite.Interpreter(post_training_model_file)
interpreter.allocate_tensors()

post_training_test_accuracy = eval_model(interpreter)

print('CQAT TFLite test_accuracy:', cqat_test_accuracy)
print('Post-training (no fine-tuning) TF test accuracy:', post_training_test_accuracy)
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
INFO:tensorflow:Assets written to: /tmp/tmp6vw7f4l2/assets
INFO:tensorflow:Assets written to: /tmp/tmp6vw7f4l2/assets
Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.


CQAT TFLite test_accuracy: 0.9795
Post-training (no fine-tuning) TF test accuracy: 0.9804

نتیجه

در این آموزش، شما یاد گرفتم که چگونه برای ایجاد یک مدل، خوشه با استفاده از cluster_weights() API، و اعمال آموزش آگاه خوشه حفظ تدریج (CQAT) برای حفظ خوشه حالی که با استفاده قطر. مدل نهایی CQAT با مدل QAT مقایسه شد تا نشان دهد که خوشه ها در اولی حفظ شده و در دومی از بین رفته اند. در مرحله بعد ، مدل ها به TFLite تبدیل شدند تا مزایای فشرده سازی خوشه بندی زنجیره ای و تکنیک های بهینه سازی مدل CQAT را نشان دهند و مدل TFLite برای اطمینان از تداوم دقت در باطن TFLite مورد ارزیابی قرار گرفت. سرانجام ، مدل CQAT با یک مدل خوشه ای کوانتیزه به دست آمده با استفاده از API کوانتیزه پس از آموزش برای نشان دادن مزیت CQAT در بازیابی از دست دادن دقت از کوانتیزاسیون طبیعی مقایسه شد.