Google I/O สำเร็จแล้ว! ติดตามเซสชัน TensorFlow ดูเซสชัน

คลัสเตอร์การคงไว้ซึ่งการฝึกอบรมการรับรู้เชิงปริมาณ (CQAT) Keras ตัวอย่าง

ดูบน TensorFlow.org ทำงานใน Google Colab ดูบน GitHub ดาวน์โหลดโน๊ตบุ๊ค

ภาพรวม

นี่คือจุดสิ้นสุดตัวอย่างสิ้นการแสดงการใช้งานของคลัสเตอร์รักษา quantization ฝึกอบรมตระหนักถึง (CQAT) API ซึ่งเป็นส่วนหนึ่งของท่อเพิ่มประสิทธิภาพการทำงานร่วมกัน TensorFlow รุ่นการเพิ่มประสิทธิภาพเครื่องมือฯ

หน้าอื่นๆ

สำหรับคำแนะนำเกี่ยวกับท่อและเทคนิคอื่น ๆ ที่มีให้ดูที่ การเพิ่มประสิทธิภาพหน้าภาพรวมการทำงานร่วมกัน

สารบัญ

ในบทช่วยสอน คุณจะ:

  1. รถไฟ tf.keras แบบจำลองสำหรับชุดข้อมูล MNIST จากรอยขีดข่วน
  2. ปรับแต่งโมเดลด้วยคลัสเตอร์และดูความแม่นยำ
  3. ใช้ QAT และสังเกตการสูญเสียคลัสเตอร์
  4. ใช้ CQAT และสังเกตว่าคลัสเตอร์ที่ใช้ก่อนหน้านี้ได้รับการเก็บรักษาไว้
  5. สร้างโมเดล TFLite และสังเกตผลกระทบของการใช้ CQAT กับโมเดล
  6. เปรียบเทียบความแม่นยำของแบบจำลอง CQAT ที่ทำได้กับแบบจำลองเชิงปริมาณโดยใช้การหาปริมาณหลังการฝึก

ติดตั้ง

คุณสามารถทำงานนี้โน๊ตบุ๊ค Jupyter ในท้องถิ่นของคุณ virtualenv หรือ Colab สำหรับรายละเอียดของการตั้งค่าการอ้างอิงโปรดดูที่ คู่มือการติดตั้ง

 pip install -q tensorflow-model-optimization
import tensorflow as tf

import numpy as np
import tempfile
import zipfile
import os

ฝึกโมเดล tf.keras สำหรับ MNIST โดยไม่ต้องทำคลัสเตอร์

# Load MNIST dataset
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input image so that each pixel value is between 0 to 1.
train_images = train_images / 255.0
test_images  = test_images / 255.0

model = tf.keras.Sequential([
  tf.keras.layers.InputLayer(input_shape=(28, 28)),
  tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
  tf.keras.layers.Conv2D(filters=12, kernel_size=(3, 3),
                         activation=tf.nn.relu),
  tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(10)
])

# Train the digit classification model
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(
    train_images,
    train_labels,
    validation_split=0.1,
    epochs=10
)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step
Epoch 1/10
1688/1688 [==============================] - 8s 5ms/step - loss: 0.2827 - accuracy: 0.9211 - val_loss: 0.1057 - val_accuracy: 0.9728
Epoch 2/10
1688/1688 [==============================] - 8s 5ms/step - loss: 0.1070 - accuracy: 0.9696 - val_loss: 0.0768 - val_accuracy: 0.9793
Epoch 3/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0805 - accuracy: 0.9765 - val_loss: 0.0663 - val_accuracy: 0.9823
Epoch 4/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0671 - accuracy: 0.9804 - val_loss: 0.0614 - val_accuracy: 0.9840
Epoch 5/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0592 - accuracy: 0.9827 - val_loss: 0.0594 - val_accuracy: 0.9837
Epoch 6/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0537 - accuracy: 0.9842 - val_loss: 0.0610 - val_accuracy: 0.9838
Epoch 7/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0487 - accuracy: 0.9851 - val_loss: 0.0644 - val_accuracy: 0.9827
Epoch 8/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0437 - accuracy: 0.9862 - val_loss: 0.0624 - val_accuracy: 0.9847
Epoch 9/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0399 - accuracy: 0.9882 - val_loss: 0.0595 - val_accuracy: 0.9847
Epoch 10/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0362 - accuracy: 0.9885 - val_loss: 0.0613 - val_accuracy: 0.9835
<tensorflow.python.keras.callbacks.History at 0x7f24a4cc7990>

ประเมินแบบจำลองพื้นฐานและบันทึกไว้เพื่อใช้ในภายหลัง

_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)

_, keras_file = tempfile.mkstemp('.h5')
print('Saving model to: ', keras_file)
tf.keras.models.save_model(model, keras_file, include_optimizer=False)
Baseline test accuracy: 0.9814000129699707
Saving model to:  /tmp/tmppd3opqzk.h5

คลัสเตอร์และปรับแต่งโมเดลด้วย 8 คลัสเตอร์

สมัคร cluster_weights() API เพื่อคลัสเตอร์รุ่นก่อนได้รับการฝึกฝนทั้งแสดงให้เห็นและสังเกตประสิทธิภาพในการลดขนาดรูปแบบเมื่อใช้ซิปขณะที่การรักษาความถูกต้อง สำหรับวิธีการที่ดีที่สุดที่จะใช้ API เพื่อให้บรรลุอัตราการบีบอัดที่ดีที่สุดในขณะที่รักษาความถูกต้องเป้าหมายของคุณให้ดูที่ คู่มือที่ครอบคลุมการจัดกลุ่ม

กำหนดโมเดลและใช้ API การทำคลัสเตอร์

โมเดลต้องได้รับการฝึกอบรมล่วงหน้าก่อนที่จะใช้คลัสเตอร์ API

import tensorflow_model_optimization as tfmot

cluster_weights = tfmot.clustering.keras.cluster_weights
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization

clustering_params = {
  'number_of_clusters': 8,
  'cluster_centroids_init': CentroidInitialization.KMEANS_PLUS_PLUS
}

clustered_model = cluster_weights(model, **clustering_params)

# Use smaller learning rate for fine-tuning
opt = tf.keras.optimizers.Adam(learning_rate=1e-5)

clustered_model.compile(
  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
  optimizer=opt,
  metrics=['accuracy'])

clustered_model.summary()
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/array_ops.py:5049: calling gather (from tensorflow.python.ops.array_ops) with validate_indices is deprecated and will be removed in a future version.
Instructions for updating:
The `validate_indices` argument has no effect. Indices are always validated on CPU and never validated on GPU.
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
cluster_reshape (ClusterWeig (None, 28, 28, 1)         0         
_________________________________________________________________
cluster_conv2d (ClusterWeigh (None, 26, 26, 12)        236       
_________________________________________________________________
cluster_max_pooling2d (Clust (None, 13, 13, 12)        0         
_________________________________________________________________
cluster_flatten (ClusterWeig (None, 2028)              0         
_________________________________________________________________
cluster_dense (ClusterWeight (None, 10)                40578     
=================================================================
Total params: 40,814
Trainable params: 20,426
Non-trainable params: 20,388
_________________________________________________________________

ปรับแต่งโมเดลอย่างละเอียดและประเมินความถูกต้องเทียบกับค่าพื้นฐาน

ปรับแต่งโมเดลอย่างละเอียดด้วยคลัสเตอร์ 3 ยุค

# Fine-tune model
clustered_model.fit(
  train_images,
  train_labels,
  epochs=3,
  validation_split=0.1)
Epoch 1/3
1688/1688 [==============================] - 8s 5ms/step - loss: 0.0363 - accuracy: 0.9887 - val_loss: 0.0610 - val_accuracy: 0.9830
Epoch 2/3
1688/1688 [==============================] - 8s 5ms/step - loss: 0.0333 - accuracy: 0.9898 - val_loss: 0.0599 - val_accuracy: 0.9830
Epoch 3/3
1688/1688 [==============================] - 8s 5ms/step - loss: 0.0319 - accuracy: 0.9905 - val_loss: 0.0580 - val_accuracy: 0.9837
<tensorflow.python.keras.callbacks.History at 0x7f240d271990>

กำหนดฟังก์ชันตัวช่วยเพื่อคำนวณและพิมพ์จำนวนการจัดกลุ่มในแต่ละเคอร์เนลของโมเดล

def print_model_weight_clusters(model):

    for layer in model.layers:
        if isinstance(layer, tf.keras.layers.Wrapper):
            weights = layer.trainable_weights
        else:
            weights = layer.weights
        for weight in weights:
            # ignore auxiliary quantization weights
            if "quantize_layer" in weight.name:
                continue
            if "kernel" in weight.name:
                unique_count = len(np.unique(weight))
                print(
                    f"{layer.name}/{weight.name}: {unique_count} clusters "
                )

ตรวจสอบว่าเคอร์เนลรุ่นถูกจัดกลุ่มอย่างถูกต้อง เราต้องลอกกระดาษห่อหุ้มคลัสเตอร์ออกก่อน

stripped_clustered_model = tfmot.clustering.keras.strip_clustering(clustered_model)

print_model_weight_clusters(stripped_clustered_model)
conv2d/kernel:0: 8 clusters 
dense/kernel:0: 8 clusters

สำหรับตัวอย่างนี้ มีการสูญเสียความแม่นยำในการทดสอบเพียงเล็กน้อยหลังการจัดกลุ่ม เมื่อเทียบกับค่าพื้นฐาน

_, clustered_model_accuracy = clustered_model.evaluate(
  test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)
print('Clustered test accuracy:', clustered_model_accuracy)
Baseline test accuracy: 0.9814000129699707
Clustered test accuracy: 0.9800999760627747

ใช้ QAT และ CQAT และตรวจสอบผลกระทบกับคลัสเตอร์โมเดลในทั้งสองกรณี

ต่อไป เราใช้ทั้ง QAT และการรักษาคลัสเตอร์ QAT (CQAT) กับโมเดลคลัสเตอร์ และสังเกตว่า CQAT รักษาคลัสเตอร์น้ำหนักในโมเดลคลัสเตอร์ของคุณ โปรดทราบว่าเราปล้น clustering ห่อจากแบบจำลองของคุณด้วย tfmot.clustering.keras.strip_clustering ก่อนที่จะใช้ CQAT API

# QAT
qat_model = tfmot.quantization.keras.quantize_model(stripped_clustered_model)

qat_model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
print('Train qat model:')
qat_model.fit(train_images, train_labels, batch_size=128, epochs=1, validation_split=0.1)

# CQAT
quant_aware_annotate_model = tfmot.quantization.keras.quantize_annotate_model(
              stripped_clustered_model)
cqat_model = tfmot.quantization.keras.quantize_apply(
              quant_aware_annotate_model,
              tfmot.experimental.combine.Default8BitClusterPreserveQuantizeScheme())

cqat_model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
print('Train cqat model:')
cqat_model.fit(train_images, train_labels, batch_size=128, epochs=1, validation_split=0.1)
Train qat model:
422/422 [==============================] - 4s 8ms/step - loss: 0.0328 - accuracy: 0.9901 - val_loss: 0.0578 - val_accuracy: 0.9840
WARNING:root:Input layer does not contain zero weights, so apply CQAT instead.
WARNING:root:Input layer does not contain zero weights, so apply CQAT instead.
Train cqat model:
WARNING:tensorflow:Gradients do not exist for variables ['conv2d/kernel:0', 'dense/kernel:0'] when minimizing the loss.
WARNING:tensorflow:Gradients do not exist for variables ['conv2d/kernel:0', 'dense/kernel:0'] when minimizing the loss.
WARNING:tensorflow:Gradients do not exist for variables ['conv2d/kernel:0', 'dense/kernel:0'] when minimizing the loss.
WARNING:tensorflow:Gradients do not exist for variables ['conv2d/kernel:0', 'dense/kernel:0'] when minimizing the loss.
422/422 [==============================] - 4s 8ms/step - loss: 0.0310 - accuracy: 0.9908 - val_loss: 0.0592 - val_accuracy: 0.9838
<tensorflow.python.keras.callbacks.History at 0x7f240c60e6d0>
print("QAT Model clusters:")
print_model_weight_clusters(qat_model)
print("CQAT Model clusters:")
print_model_weight_clusters(cqat_model)
QAT Model clusters:
quant_conv2d/conv2d/kernel:0: 108 clusters 
quant_dense/dense/kernel:0: 19931 clusters 
CQAT Model clusters:
quant_conv2d/conv2d/kernel:0: 8 clusters 
quant_dense/dense/kernel:0: 8 clusters

ดูประโยชน์ของการบีบอัดของโมเดล CQAT

กำหนดฟังก์ชันตัวช่วยเพื่อรับไฟล์โมเดลที่บีบอัด

def get_gzipped_model_size(file):
  # It returns the size of the gzipped model in kilobytes.

  _, zipped_file = tempfile.mkstemp('.zip')
  with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
    f.write(file)

  return os.path.getsize(zipped_file)/1000

โปรดทราบว่านี่เป็นรุ่นเล็ก การใช้คลัสเตอร์และ CQAT กับโมเดลการผลิตที่ใหญ่ขึ้นจะทำให้มีการบีบอัดข้อมูลที่มีนัยสำคัญมากขึ้น

# QAT model
converter = tf.lite.TFLiteConverter.from_keras_model(qat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
qat_tflite_model = converter.convert()
qat_model_file = 'qat_model.tflite'
# Save the model.
with open(qat_model_file, 'wb') as f:
    f.write(qat_tflite_model)

# CQAT model
converter = tf.lite.TFLiteConverter.from_keras_model(cqat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
cqat_tflite_model = converter.convert()
cqat_model_file = 'cqat_model.tflite'
# Save the model.
with open(cqat_model_file, 'wb') as f:
    f.write(cqat_tflite_model)

print("QAT model size: ", get_gzipped_model_size(qat_model_file), ' KB')
print("CQAT model size: ", get_gzipped_model_size(cqat_model_file), ' KB')
WARNING:absl:Found untraced functions such as reshape_layer_call_and_return_conditional_losses, reshape_layer_call_fn, conv2d_layer_call_and_return_conditional_losses, conv2d_layer_call_fn, flatten_layer_call_and_return_conditional_losses while saving (showing 5 of 20). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/tmpy_b1e0tx/assets
INFO:tensorflow:Assets written to: /tmp/tmpy_b1e0tx/assets
WARNING:absl:Found untraced functions such as reshape_layer_call_and_return_conditional_losses, reshape_layer_call_fn, conv2d_layer_call_and_return_conditional_losses, conv2d_layer_call_fn, flatten_layer_call_and_return_conditional_losses while saving (showing 5 of 20). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/tmp282pfq0n/assets
INFO:tensorflow:Assets written to: /tmp/tmp282pfq0n/assets
QAT model size:  16.685  KB
CQAT model size:  10.121  KB

ดูความคงอยู่ของความแม่นยำตั้งแต่ TF ถึง TFLite

กำหนดฟังก์ชันตัวช่วยเพื่อประเมินโมเดล TFLite บนชุดข้อมูลทดสอบ

def eval_model(interpreter):
  input_index = interpreter.get_input_details()[0]["index"]
  output_index = interpreter.get_output_details()[0]["index"]

  # Run predictions on every image in the "test" dataset.
  prediction_digits = []
  for i, test_image in enumerate(test_images):
    if i % 1000 == 0:
      print(f"Evaluated on {i} results so far.")
    # Pre-processing: add batch dimension and convert to float32 to match with
    # the model's input data format.
    test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
    interpreter.set_tensor(input_index, test_image)

    # Run inference.
    interpreter.invoke()

    # Post-processing: remove batch dimension and find the digit with highest
    # probability.
    output = interpreter.tensor(output_index)
    digit = np.argmax(output()[0])
    prediction_digits.append(digit)

  print('\n')
  # Compare prediction results with ground truth labels to calculate accuracy.
  prediction_digits = np.array(prediction_digits)
  accuracy = (prediction_digits == test_labels).mean()
  return accuracy

คุณประเมินแบบจำลอง ซึ่งได้รับการจัดกลุ่มและวัดปริมาณ แล้วเห็นความถูกต้องจาก TensorFlow ยังคงอยู่ในแบ็กเอนด์ TFLite

interpreter = tf.lite.Interpreter(cqat_model_file)
interpreter.allocate_tensors()

cqat_test_accuracy = eval_model(interpreter)

print('Clustered and quantized TFLite test_accuracy:', cqat_test_accuracy)
print('Clustered TF test accuracy:', clustered_model_accuracy)
Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.


Clustered and quantized TFLite test_accuracy: 0.9795
Clustered TF test accuracy: 0.9800999760627747

ใช้การหาปริมาณหลังการฝึกและเปรียบเทียบกับแบบจำลอง CQAT

ต่อไป เราใช้การหาปริมาณหลังการฝึก (ไม่มีการปรับละเอียด) กับโมเดลคลัสเตอร์ และตรวจสอบความถูกต้องกับโมเดล CQAT นี่แสดงให้เห็นว่าเหตุใดคุณจึงต้องใช้ CQAT เพื่อปรับปรุงความแม่นยำของแบบจำลองเชิงปริมาณ

ขั้นแรก กำหนดตัวสร้างสำหรับชุดข้อมูลการปรับเทียบจากอิมเมจการฝึก 1,000 ภาพแรก

def mnist_representative_data_gen():
  for image in train_images[:1000]:  
    image = np.expand_dims(image, axis=0).astype(np.float32)
    yield [image]

หาปริมาณของแบบจำลองและเปรียบเทียบความถูกต้องกับแบบจำลอง CQAT ที่ได้มาก่อนหน้านี้ โปรดทราบว่าโมเดลที่วัดปริมาณด้วยการปรับแต่งแบบละเอียดจะให้ความแม่นยำสูงขึ้น

converter = tf.lite.TFLiteConverter.from_keras_model(stripped_clustered_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = mnist_representative_data_gen
post_training_tflite_model = converter.convert()
post_training_model_file = 'post_training_model.tflite'
# Save the model.
with open(post_training_model_file, 'wb') as f:
    f.write(post_training_tflite_model)

# Compare accuracy
interpreter = tf.lite.Interpreter(post_training_model_file)
interpreter.allocate_tensors()

post_training_test_accuracy = eval_model(interpreter)

print('CQAT TFLite test_accuracy:', cqat_test_accuracy)
print('Post-training (no fine-tuning) TF test accuracy:', post_training_test_accuracy)
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
INFO:tensorflow:Assets written to: /tmp/tmp6vw7f4l2/assets
INFO:tensorflow:Assets written to: /tmp/tmp6vw7f4l2/assets
Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.


CQAT TFLite test_accuracy: 0.9795
Post-training (no fine-tuning) TF test accuracy: 0.9804

บทสรุป

ในการกวดวิชานี้คุณได้เรียนรู้วิธีการสร้างรูปแบบกลุ่มโดยใช้ cluster_weights() API และใช้คลัสเตอร์รักษา quantization ฝึกอบรมตระหนักถึง (CQAT) เพื่อรักษากลุ่มในขณะที่ใช้ QAT โมเดล CQAT สุดท้ายถูกเปรียบเทียบกับ QAT เพื่อแสดงว่าคลัสเตอร์ได้รับการอนุรักษ์ไว้ในอดีตและสูญหายในภายหลัง จากนั้น โมเดลจะถูกแปลงเป็น TFLite เพื่อแสดงประโยชน์ของการบีบอัดของ chaining clustering และเทคนิคการเพิ่มประสิทธิภาพโมเดล CQAT และโมเดล TFLite ได้รับการประเมินเพื่อให้แน่ใจว่าความถูกต้องยังคงอยู่ในแบ็กเอนด์ TFLite สุดท้าย โมเดล CQAT ถูกเปรียบเทียบกับโมเดลคลัสเตอร์เชิงปริมาณที่ทำได้โดยใช้ API ควอนไทเซชั่นหลังการฝึก เพื่อแสดงให้เห็นถึงข้อได้เปรียบของ CQAT ในการกู้คืนการสูญเสียความแม่นยำจากการควอนไทเซชั่นปกติ