Ngày Cộng đồng ML là ngày 9 tháng 11! Tham gia với chúng tôi để cập nhật từ TensorFlow, JAX, và nhiều hơn nữa Tìm hiểu thêm

Ví dụ về cắt tỉa trong Keras

Xem trên TensorFlow.org Chạy trong Google Colab Xem nguồn trên GitHub Tải xuống sổ ghi chép

Tổng quat

Chào mừng bạn đến với một ví dụ từ đầu đến cuối cho việc cắt tỉa trọng lượng dựa trên độ lớn.

Những trang khác

Để biết phần giới thiệu về cắt tỉa là gì và để xác định xem bạn có nên sử dụng nó (bao gồm cả những gì được hỗ trợ) hay không, hãy xem trang tổng quan .

Để nhanh chóng tìm thấy các API bạn cần cho trường hợp sử dụng của mình (ngoài việc cắt tỉa hoàn toàn một mô hình có độ thưa thớt 80%), hãy xem hướng dẫn toàn diện .

Tóm lược

Trong hướng dẫn này, bạn sẽ:

  1. Đào tạo mô hình tf.keras cho MNIST từ đầu.
  2. Tinh chỉnh mô hình bằng cách áp dụng API cắt tỉa và xem độ chính xác.
  3. Tạo mô hình TF và TFLite nhỏ hơn 3 lần từ việc cắt tỉa.
  4. Tạo mô hình TFLite nhỏ hơn 10 lần từ việc kết hợp cắt tỉa và lượng tử hóa sau đào tạo.
  5. Xem độ chính xác lâu dài từ TF sang TFLite.

Thiết lập

 pip install -q tensorflow-model-optimization
import tempfile
import os

import tensorflow as tf
import numpy as np

from tensorflow import keras

%load_ext tensorboard

Đào tạo một mô hình cho MNIST mà không cần cắt tỉa

# Load MNIST dataset
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input image so that each pixel value is between 0 and 1.
train_images = train_images / 255.0
test_images = test_images / 255.0

# Define the model architecture.
model = keras.Sequential([
  keras.layers.InputLayer(input_shape=(28, 28)),
  keras.layers.Reshape(target_shape=(28, 28, 1)),
  keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Flatten(),
  keras.layers.Dense(10)
])

# Train the digit classification model
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(
  train_images,
  train_labels,
  epochs=4,
  validation_split=0.1,
)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step
Epoch 1/4
1688/1688 [==============================] - 10s 6ms/step - loss: 0.2785 - accuracy: 0.9220 - val_loss: 0.1031 - val_accuracy: 0.9740
Epoch 2/4
1688/1688 [==============================] - 9s 5ms/step - loss: 0.1063 - accuracy: 0.9691 - val_loss: 0.0782 - val_accuracy: 0.9790
Epoch 3/4
1688/1688 [==============================] - 9s 5ms/step - loss: 0.0815 - accuracy: 0.9765 - val_loss: 0.0788 - val_accuracy: 0.9775
Epoch 4/4
1688/1688 [==============================] - 9s 5ms/step - loss: 0.0689 - accuracy: 0.9797 - val_loss: 0.0633 - val_accuracy: 0.9840
<tensorflow.python.keras.callbacks.History at 0x7f146fbd8bd0>

Đánh giá độ chính xác của thử nghiệm cơ bản và lưu mô hình để sử dụng sau này.

_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)

_, keras_file = tempfile.mkstemp('.h5')
tf.keras.models.save_model(model, keras_file, include_optimizer=False)
print('Saved baseline model to:', keras_file)
Baseline test accuracy: 0.9775999784469604
Saved baseline model to: /tmp/tmpjj6swf59.h5

Tinh chỉnh mô hình được đào tạo trước bằng cách cắt tỉa

Xác định mô hình

Bạn sẽ áp dụng cách cắt tỉa cho toàn bộ mô hình và xem điều này trong phần tóm tắt mô hình.

Trong ví dụ này, bạn bắt đầu mô hình với độ thưa thớt 50% (trọng số là 50%) và kết thúc với độ thưa thớt 80%.

Trong hướng dẫn toàn diện , bạn có thể xem cách cắt bớt một số lớp để cải thiện độ chính xác của mô hình.

import tensorflow_model_optimization as tfmot

prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

# Compute end step to finish pruning after 2 epochs.
batch_size = 128
epochs = 2
validation_split = 0.1 # 10% of training set will be used for validation set. 

num_images = train_images.shape[0] * (1 - validation_split)
end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs

# Define model for pruning.
pruning_params = {
      'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.50,
                                                               final_sparsity=0.80,
                                                               begin_step=0,
                                                               end_step=end_step)
}

model_for_pruning = prune_low_magnitude(model, **pruning_params)

# `prune_low_magnitude` requires a recompile.
model_for_pruning.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model_for_pruning.summary()
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py:2191: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use `layer.add_weight` method instead.
  warnings.warn('`layer.add_variable` is deprecated and '
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
prune_low_magnitude_reshape  (None, 28, 28, 1)         1         
_________________________________________________________________
prune_low_magnitude_conv2d ( (None, 26, 26, 12)        230       
_________________________________________________________________
prune_low_magnitude_max_pool (None, 13, 13, 12)        1         
_________________________________________________________________
prune_low_magnitude_flatten  (None, 2028)              1         
_________________________________________________________________
prune_low_magnitude_dense (P (None, 10)                40572     
=================================================================
Total params: 40,805
Trainable params: 20,410
Non-trainable params: 20,395
_________________________________________________________________

Đào tạo và đánh giá mô hình so với đường cơ sở

Tinh chỉnh với việc cắt tỉa trong hai kỷ nguyên.

tfmot.sparsity.keras.UpdatePruningStep là bắt buộc trong quá trình đào tạo và tfmot.sparsity.keras.PruningSummaries cung cấp nhật ký để theo dõi tiến trình và gỡ lỗi.

logdir = tempfile.mkdtemp()

callbacks = [
  tfmot.sparsity.keras.UpdatePruningStep(),
  tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]

model_for_pruning.fit(train_images, train_labels,
                  batch_size=batch_size, epochs=epochs, validation_split=validation_split,
                  callbacks=callbacks)
Epoch 1/2
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/array_ops.py:5049: calling gather (from tensorflow.python.ops.array_ops) with validate_indices is deprecated and will be removed in a future version.
Instructions for updating:
The `validate_indices` argument has no effect. Indices are always validated on CPU and never validated on GPU.
  3/422 [..............................] - ETA: 12s - loss: 0.0628 - accuracy: 0.9896  WARNING:tensorflow:Callback method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0075s vs `on_train_batch_end` time: 0.0076s). Check your callbacks.
422/422 [==============================] - 5s 9ms/step - loss: 0.0797 - accuracy: 0.9771 - val_loss: 0.0828 - val_accuracy: 0.9790
Epoch 2/2
422/422 [==============================] - 3s 8ms/step - loss: 0.0971 - accuracy: 0.9741 - val_loss: 0.0839 - val_accuracy: 0.9775
<tensorflow.python.keras.callbacks.History at 0x7f12e4502910>

Đối với ví dụ này, độ chính xác của bài kiểm tra sau khi cắt tỉa bị mất đi tối thiểu so với đường cơ sở.

_, model_for_pruning_accuracy = model_for_pruning.evaluate(
   test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy) 
print('Pruned test accuracy:', model_for_pruning_accuracy)
Baseline test accuracy: 0.9775999784469604
Pruned test accuracy: 0.972100019454956

Các bản ghi cho thấy sự tiến triển của sự thưa thớt trên cơ sở mỗi lớp.

%tensorboard --logdir={logdir}

Đối với người dùng không phải Colab, bạn có thể xemkết quả của lần chạy khối mã nàytrước đó trên TensorBoard.dev .

Tạo các mô hình nhỏ hơn gấp 3 lần từ việc cắt tỉa

Cả tfmot.sparsity.keras.strip_pruning và việc áp dụng thuật toán nén tiêu chuẩn (ví dụ: qua gzip) đều cần thiết để thấy được lợi ích nén của việc lược bớt.

  • strip_pruning là cần thiết vì nó loại bỏ mọi tf.Variable mà việc cắt tỉa chỉ cần trong quá trình đào tạo, nếu không sẽ thêm vào kích thước mô hình trong quá trình suy luận
  • Việc áp dụng thuật toán nén tiêu chuẩn là cần thiết vì các ma trận trọng số được tuần tự hóa có cùng kích thước với chúng trước khi cắt bớt. Tuy nhiên, việc cắt bớt làm cho hầu hết các số không có trọng số, điều này được thêm vào sự dư thừa mà các thuật toán có thể sử dụng để nén mô hình hơn nữa.

Đầu tiên, tạo một mô hình có thể nén cho TensorFlow.

model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)

_, pruned_keras_file = tempfile.mkstemp('.h5')
tf.keras.models.save_model(model_for_export, pruned_keras_file, include_optimizer=False)
print('Saved pruned Keras model to:', pruned_keras_file)
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
Saved pruned Keras model to: /tmp/tmp22u333hk.h5

Sau đó, tạo một mô hình có thể nén cho TFLite.

converter = tf.lite.TFLiteConverter.from_keras_model(model_for_export)
pruned_tflite_model = converter.convert()

_, pruned_tflite_file = tempfile.mkstemp('.tflite')

with open(pruned_tflite_file, 'wb') as f:
  f.write(pruned_tflite_model)

print('Saved pruned TFLite model to:', pruned_tflite_file)
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
INFO:tensorflow:Assets written to: /tmp/tmp51falze0/assets
Saved pruned TFLite model to: /tmp/tmpehx5la2i.tflite

Xác định một chức năng trợ giúp để thực sự nén các mô hình thông qua gzip và đo kích thước đã nén.

def get_gzipped_model_size(file):
  # Returns size of gzipped model, in bytes.
  import os
  import zipfile

  _, zipped_file = tempfile.mkstemp('.zip')
  with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
    f.write(file)

  return os.path.getsize(zipped_file)

So sánh và thấy rằng các mô hình nhỏ hơn 3 lần từ việc cắt tỉa.

print("Size of gzipped baseline Keras model: %.2f bytes" % (get_gzipped_model_size(keras_file)))
print("Size of gzipped pruned Keras model: %.2f bytes" % (get_gzipped_model_size(pruned_keras_file)))
print("Size of gzipped pruned TFlite model: %.2f bytes" % (get_gzipped_model_size(pruned_tflite_file)))
Size of gzipped baseline Keras model: 78211.00 bytes
Size of gzipped pruned Keras model: 25797.00 bytes
Size of gzipped pruned TFlite model: 24995.00 bytes

Tạo một mô hình nhỏ hơn 10 lần từ việc kết hợp cắt tỉa và lượng tử hóa

Bạn có thể áp dụng lượng tử hóa sau đào tạo cho mô hình đã cắt bớt để có thêm lợi ích.

converter = tf.lite.TFLiteConverter.from_keras_model(model_for_export)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
quantized_and_pruned_tflite_model = converter.convert()

_, quantized_and_pruned_tflite_file = tempfile.mkstemp('.tflite')

with open(quantized_and_pruned_tflite_file, 'wb') as f:
  f.write(quantized_and_pruned_tflite_model)

print('Saved quantized and pruned TFLite model to:', quantized_and_pruned_tflite_file)

print("Size of gzipped baseline Keras model: %.2f bytes" % (get_gzipped_model_size(keras_file)))
print("Size of gzipped pruned and quantized TFlite model: %.2f bytes" % (get_gzipped_model_size(quantized_and_pruned_tflite_file)))
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
INFO:tensorflow:Assets written to: /tmp/tmp6tzu3z87/assets
INFO:tensorflow:Assets written to: /tmp/tmp6tzu3z87/assets
Saved quantized and pruned TFLite model to: /tmp/tmp0mvlkin1.tflite
Size of gzipped baseline Keras model: 78211.00 bytes
Size of gzipped pruned and quantized TFlite model: 8031.00 bytes

Xem độ chính xác lâu dài từ TF sang TFLite

Xác định một chức năng trợ giúp để đánh giá mô hình TF Lite trên tập dữ liệu thử nghiệm.

import numpy as np

def evaluate_model(interpreter):
  input_index = interpreter.get_input_details()[0]["index"]
  output_index = interpreter.get_output_details()[0]["index"]

  # Run predictions on ever y image in the "test" dataset.
  prediction_digits = []
  for i, test_image in enumerate(test_images):
    if i % 1000 == 0:
      print('Evaluated on {n} results so far.'.format(n=i))
    # Pre-processing: add batch dimension and convert to float32 to match with
    # the model's input data format.
    test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
    interpreter.set_tensor(input_index, test_image)

    # Run inference.
    interpreter.invoke()

    # Post-processing: remove batch dimension and find the digit with highest
    # probability.
    output = interpreter.tensor(output_index)
    digit = np.argmax(output()[0])
    prediction_digits.append(digit)

  print('\n')
  # Compare prediction results with ground truth labels to calculate accuracy.
  prediction_digits = np.array(prediction_digits)
  accuracy = (prediction_digits == test_labels).mean()
  return accuracy

Bạn đánh giá mô hình đã được lược bớt và lượng tử hóa và thấy rằng độ chính xác từ TensorFlow vẫn tồn tại đối với phần phụ trợ TFLite.

interpreter = tf.lite.Interpreter(model_content=quantized_and_pruned_tflite_model)
interpreter.allocate_tensors()

test_accuracy = evaluate_model(interpreter)

print('Pruned and quantized TFLite test_accuracy:', test_accuracy)
print('Pruned TF test accuracy:', model_for_pruning_accuracy)
Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.


Pruned and quantized TFLite test_accuracy: 0.9722
Pruned TF test accuracy: 0.972100019454956

Phần kết luận

Trong hướng dẫn này, bạn đã biết cách tạo các mô hình thưa thớt với API bộ công cụ tối ưu hóa mô hình TensorFlow cho cả TensorFlow và TFLite. Sau đó, bạn kết hợp cắt tỉa với lượng hóa sau đào tạo để có thêm lợi ích.

Bạn đã tạo một mô hình nhỏ hơn 10 lần cho MNIST với sự khác biệt về độ chính xác tối thiểu.

Chúng tôi khuyến khích bạn thử khả năng mới này, khả năng này có thể đặc biệt quan trọng để triển khai trong các môi trường hạn chế về tài nguyên.