XNNPACK ile cihaz içi çıkarım için budama

TensorFlow.org'da görüntüleyin Google Colab'da çalıştırın Kaynağı GitHub'da görüntüleyin Not defterini indir

İle Cihazdaki çıkarsama gecikme geliştirmek için budama Keras ağırlıkları üzerindeki kılavuzuna hoş geldiniz XNNPACK .

Bu kılavuz sunuyor yeni tanıtılan kullanım tfmot.sparsity.keras.PruningPolicy API ve kullanan çağdaş CPU'lar çoğunlukla evrışimlı modelleri hızlandırılması için nasıl kullanılabileceğini gösteriyor XNNPACK Seyrek çıkarım .

Kılavuz, model oluşturma sürecinin aşağıdaki adımlarını kapsar:

  • Yoğun taban çizgisini oluşturun ve eğitin
  • Budama ile ince ayar modeli
  • TFLite'a Dönüştür
  • Cihaz içi kıyaslama

Kılavuz, budama ile ince ayar için en iyi uygulamaları kapsamamaktadır. Bu konuyla ilgili daha ayrıntılı bilgi için lütfen kontrol edin kapsamlı bir rehber .

Kurulum

 pip install -q tf-nightly
 pip install -q tensorflow-model-optimization==0.5.1.dev0
import tempfile

import tensorflow as tf
import numpy as np

from tensorflow import keras
import tensorflow_datasets as tfds
import tensorflow_model_optimization as tfmot

%load_ext tensorboard

Yoğun modeli oluşturun ve eğitin

Biz inşa etmek ve üzerinde sınıflandırma görev için basit bir temel CNN tren CIFAR10 veri kümesi.

# Load CIFAR10 dataset.
(ds_train, ds_val, ds_test), ds_info = tfds.load(
    'cifar10',
    split=['train[:90%]', 'train[90%:]', 'test'],
    as_supervised=True,
    with_info=True,
)

# Normalize the input image so that each pixel value is between 0 and 1.
def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.image.convert_image_dtype(image, tf.float32), label

# Load the data in batches of 128 images.
batch_size = 128
def prepare_dataset(ds, buffer_size=None):
  ds = ds.map(normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  ds = ds.cache()
  if buffer_size:
    ds = ds.shuffle(buffer_size)
  ds = ds.batch(batch_size)
  ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
  return ds

ds_train = prepare_dataset(ds_train,
                           buffer_size=ds_info.splits['train'].num_examples)
ds_val = prepare_dataset(ds_val)
ds_test = prepare_dataset(ds_test)

# Build the dense baseline model.
dense_model = keras.Sequential([
    keras.layers.InputLayer(input_shape=(32, 32, 3)),
    keras.layers.ZeroPadding2D(padding=1),
    keras.layers.Conv2D(
        filters=8,
        kernel_size=(3, 3),
        strides=(2, 2),
        padding='valid'),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),
    keras.layers.DepthwiseConv2D(kernel_size=(3, 3), padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),
    keras.layers.Conv2D(filters=16, kernel_size=(1, 1)),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),
    keras.layers.ZeroPadding2D(padding=1),
    keras.layers.DepthwiseConv2D(
        kernel_size=(3, 3), strides=(2, 2), padding='valid'),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),
    keras.layers.Conv2D(filters=32, kernel_size=(1, 1)),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Flatten(),
    keras.layers.Dense(10)
])

# Compile and train the dense model for 10 epochs.
dense_model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer='adam',
    metrics=['accuracy'])

dense_model.fit(
  ds_train,
  epochs=10,
  validation_data=ds_val)

# Evaluate the dense model.
_, dense_model_accuracy = dense_model.evaluate(ds_test, verbose=0)
Epoch 1/10
352/352 [==============================] - 16s 32ms/step - loss: 2.0021 - accuracy: 0.2716 - val_loss: 2.0871 - val_accuracy: 0.2106
Epoch 2/10
352/352 [==============================] - 9s 24ms/step - loss: 1.7056 - accuracy: 0.3779 - val_loss: 1.7434 - val_accuracy: 0.3364
Epoch 3/10
352/352 [==============================] - 8s 24ms/step - loss: 1.6049 - accuracy: 0.4144 - val_loss: 1.6463 - val_accuracy: 0.3834
Epoch 4/10
352/352 [==============================] - 8s 23ms/step - loss: 1.5485 - accuracy: 0.4359 - val_loss: 1.7435 - val_accuracy: 0.3808
Epoch 5/10
352/352 [==============================] - 8s 24ms/step - loss: 1.5099 - accuracy: 0.4516 - val_loss: 1.5217 - val_accuracy: 0.4300
Epoch 6/10
352/352 [==============================] - 9s 24ms/step - loss: 1.4806 - accuracy: 0.4632 - val_loss: 1.5367 - val_accuracy: 0.4404
Epoch 7/10
352/352 [==============================] - 8s 24ms/step - loss: 1.4548 - accuracy: 0.4724 - val_loss: 1.5238 - val_accuracy: 0.4470
Epoch 8/10
352/352 [==============================] - 8s 24ms/step - loss: 1.4401 - accuracy: 0.4782 - val_loss: 1.7590 - val_accuracy: 0.3754
Epoch 9/10
352/352 [==============================] - 8s 24ms/step - loss: 1.4255 - accuracy: 0.4859 - val_loss: 1.4854 - val_accuracy: 0.4598
Epoch 10/10
352/352 [==============================] - 8s 24ms/step - loss: 1.4127 - accuracy: 0.4889 - val_loss: 1.8831 - val_accuracy: 0.3708

Seyrek modeli oluşturun

Talimatına göre, kapsamlı bir rehber , biz uygulamak tfmot.sparsity.keras.prune_low_magnitude parametreler hedef Cihazdaki ivme yoluyla budama yani birlikte işlevini tfmot.sparsity.keras.PruneForLatencyOnXNNPack politikası.

prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

# Compute end step to finish pruning after after 5 epochs.
end_epoch = 5

num_iterations_per_epoch = len(ds_train)
end_step =  num_iterations_per_epoch * end_epoch

# Define parameters for pruning.
pruning_params = {
      'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.25,
                                                               final_sparsity=0.75,
                                                               begin_step=0,
                                                               end_step=end_step),
      'pruning_policy': tfmot.sparsity.keras.PruneForLatencyOnXNNPack()
}

# Try to apply pruning wrapper with pruning policy parameter.
try:
  model_for_pruning = prune_low_magnitude(dense_model, **pruning_params)
except ValueError as e:
  print(e)
Could not find a `GlobalAveragePooling2D` layer with `keepdims = True` in all output branches

Çağrı prune_low_magnitude sonuçları ValueError mesajla Could not find a GlobalAveragePooling2D layer with keepdims = True in all output branches . Mesaj modeli politikası ile budama için desteklenmez belirtir tfmot.sparsity.keras.PruneForLatencyOnXNNPack katman ve özellikle GlobalAveragePooling2D parametre gerektirir keepdims = True . Hadi düzeltme olduğunu ve tekrar başvuruda prune_low_magnitude fonksiyonu.

fixed_dense_model = keras.Sequential([
    keras.layers.InputLayer(input_shape=(32, 32, 3)),
    keras.layers.ZeroPadding2D(padding=1),
    keras.layers.Conv2D(
        filters=8,
        kernel_size=(3, 3),
        strides=(2, 2),
        padding='valid'),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),
    keras.layers.DepthwiseConv2D(kernel_size=(3, 3), padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),
    keras.layers.Conv2D(filters=16, kernel_size=(1, 1)),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),
    keras.layers.ZeroPadding2D(padding=1),
    keras.layers.DepthwiseConv2D(
        kernel_size=(3, 3), strides=(2, 2), padding='valid'),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),
    keras.layers.Conv2D(filters=32, kernel_size=(1, 1)),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),
    keras.layers.GlobalAveragePooling2D(keepdims=True),
    keras.layers.Flatten(),
    keras.layers.Dense(10)
])

# Use the pretrained model for pruning instead of training from scratch.
fixed_dense_model.set_weights(dense_model.get_weights())

# Try to reapply pruning wrapper.
model_for_pruning = prune_low_magnitude(fixed_dense_model, **pruning_params)
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py:2233: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use `layer.add_weight` method instead.
  warnings.warn('`layer.add_variable` is deprecated and '

Çağırma prune_low_magnitude modeli tam olarak desteklenir; yani hatasız tamamladı tfmot.sparsity.keras.PruneForLatencyOnXNNPack politikası ve kullanılarak hızlandırılabilir XNNPACK Seyrek çıkarım .

Seyrek modelde ince ayar yapın

Aşağıdaki budama örneği yoğun modelinin ağırlıkları kullanılarak, biz ince ayarlar seyrek modeli. Modelin ince ayarına %25 seyreklik ile başlıyoruz (ağırlıkların %25'i sıfıra ayarlanmış) ve %75 seyreklik ile bitiriyoruz.

logdir = tempfile.mkdtemp()

callbacks = [
  tfmot.sparsity.keras.UpdatePruningStep(),
  tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]

model_for_pruning.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer='adam',
    metrics=['accuracy'])

model_for_pruning.fit(
  ds_train,
  epochs=15,
  validation_data=ds_val,
  callbacks=callbacks)

# Evaluate the dense model.
_, pruned_model_accuracy = model_for_pruning.evaluate(ds_test, verbose=0)

print('Dense model test accuracy:', dense_model_accuracy)
print('Pruned model test accuracy:', pruned_model_accuracy)
Epoch 1/15
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/array_ops.py:5065: calling gather (from tensorflow.python.ops.array_ops) with validate_indices is deprecated and will be removed in a future version.
Instructions for updating:
The `validate_indices` argument has no effect. Indices are always validated on CPU and never validated on GPU.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/array_ops.py:5065: calling gather (from tensorflow.python.ops.array_ops) with validate_indices is deprecated and will be removed in a future version.
Instructions for updating:
The `validate_indices` argument has no effect. Indices are always validated on CPU and never validated on GPU.
352/352 [==============================] - 10s 25ms/step - loss: 1.4274 - accuracy: 0.4850 - val_loss: 1.5313 - val_accuracy: 0.4336
Epoch 2/15
352/352 [==============================] - 8s 24ms/step - loss: 1.4519 - accuracy: 0.4756 - val_loss: 2.2348 - val_accuracy: 0.3022
Epoch 3/15
352/352 [==============================] - 8s 23ms/step - loss: 1.4864 - accuracy: 0.4622 - val_loss: 1.7750 - val_accuracy: 0.3752
Epoch 4/15
352/352 [==============================] - 8s 24ms/step - loss: 1.4758 - accuracy: 0.4634 - val_loss: 1.7347 - val_accuracy: 0.3742
Epoch 5/15
352/352 [==============================] - 9s 24ms/step - loss: 1.4509 - accuracy: 0.4736 - val_loss: 1.6406 - val_accuracy: 0.4166
Epoch 6/15
352/352 [==============================] - 8s 24ms/step - loss: 1.4345 - accuracy: 0.4788 - val_loss: 1.7445 - val_accuracy: 0.3804
Epoch 7/15
352/352 [==============================] - 9s 24ms/step - loss: 1.4196 - accuracy: 0.4865 - val_loss: 2.5808 - val_accuracy: 0.2624
Epoch 8/15
352/352 [==============================] - 9s 25ms/step - loss: 1.4093 - accuracy: 0.4900 - val_loss: 1.5336 - val_accuracy: 0.4498
Epoch 9/15
352/352 [==============================] - 9s 24ms/step - loss: 1.4023 - accuracy: 0.4940 - val_loss: 1.9210 - val_accuracy: 0.3654
Epoch 10/15
352/352 [==============================] - 9s 24ms/step - loss: 1.3968 - accuracy: 0.4960 - val_loss: 1.5129 - val_accuracy: 0.4406
Epoch 11/15
352/352 [==============================] - 9s 24ms/step - loss: 1.3882 - accuracy: 0.4983 - val_loss: 1.7009 - val_accuracy: 0.3896
Epoch 12/15
352/352 [==============================] - 9s 25ms/step - loss: 1.3807 - accuracy: 0.5020 - val_loss: 2.3179 - val_accuracy: 0.2984
Epoch 13/15
352/352 [==============================] - 8s 24ms/step - loss: 1.3781 - accuracy: 0.5034 - val_loss: 1.6146 - val_accuracy: 0.4324
Epoch 14/15
352/352 [==============================] - 9s 25ms/step - loss: 1.3735 - accuracy: 0.5054 - val_loss: 2.3618 - val_accuracy: 0.3062
Epoch 15/15
352/352 [==============================] - 8s 24ms/step - loss: 1.3748 - accuracy: 0.5040 - val_loss: 1.5962 - val_accuracy: 0.4312
Dense model test accuracy: 0.37400001287460327
Pruned model test accuracy: 0.4334000051021576

Günlükler, seyrekliğin ilerlemesini katman bazında gösterir.

%tensorboard --logdir={logdir}

Budama ile ince ayardan sonra, test doğruluğu, yoğun modele kıyasla mütevazi bir gelişme (%43 ila %44) göstermektedir. Gecikmesinin cihaz ile ilgili karşılaştıralım TFLite kriter .

Model dönüştürme ve kıyaslama

TFLite içine budanmış modeli dönüştürmek için, biz yerine gerek PruneLowMagnitude aracılığıyla orijinal katmanlarla sarmalayıcılarını strip_pruning fonksiyonu. (Budanmış modelin ağırlıkları Aynı zamanda, model_for_pruning ) sıfır, çoğunlukla, biz bir optimizasyon geçerli olabilir tf.lite.Optimize.EXPERIMENTAL_SPARSITY verimli TFLite modeli sonucu depolamak için. Bu optimizasyon bayrağı, yoğun model için gerekli değildir.

converter = tf.lite.TFLiteConverter.from_keras_model(dense_model)
dense_tflite_model = converter.convert()

_, dense_tflite_file = tempfile.mkstemp('.tflite')
with open(dense_tflite_file, 'wb') as f:
  f.write(dense_tflite_model)

model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)

converter = tf.lite.TFLiteConverter.from_keras_model(model_for_export)
converter.optimizations = [tf.lite.Optimize.EXPERIMENTAL_SPARSITY]
pruned_tflite_model = converter.convert()

_, pruned_tflite_file = tempfile.mkstemp('.tflite')
with open(pruned_tflite_file, 'wb') as f:
  f.write(pruned_tflite_model)
INFO:tensorflow:Assets written to: /tmp/tmp9is7dj3q/assets
INFO:tensorflow:Assets written to: /tmp/tmp9is7dj3q/assets
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
INFO:tensorflow:Assets written to: /tmp/tmp9kw8dwue/assets
INFO:tensorflow:Assets written to: /tmp/tmp9kw8dwue/assets

Talimatları takip TFLite Modeli Kıyaslama Aracı , biz aracı inşa Yoğun ve budanmış TFLite modelleri ve kriter cihazda iki modelde ile birlikte Android cihazı yükleyin.

! adb shell /data/local/tmp/benchmark_model \
    --graph=/data/local/tmp/dense_model.tflite \
    --use_xnnpack=true \
    --num_runs=100 \
    --num_threads=1
/bin/bash: adb: command not found
! adb shell /data/local/tmp/benchmark_model \
    --graph=/data/local/tmp/pruned_model.tflite \
    --use_xnnpack=true \
    --num_runs=100 \
    --num_threads=1
/bin/bash: adb: command not found

Pixel 4 Deneyler budanmış modeli için yoğun model ve 12us için 17us ortalama çıkarım sürede sonuçlandı. Cihazdaki kriterler berrak 5US hatta bu gibi küçük modeller için gecikme% 30 iyileşme göstermiştir. Deneyimlerimize göre, daha büyük modeller dayalı MobileNetV3 veya EfficientNet-lite gösteri benzer performans iyileştirmeleri. Hızlanma, 1x1 evrişimlerin genel modele göreli katkısına göre değişir.

Sonuç

Bu öğreticide, TF MOT API ve XNNPack tarafından sunulan yeni işlevleri kullanarak daha hızlı cihaz performansı için seyrek modellerin nasıl oluşturulabileceğini gösteriyoruz. Bu seyrek modeller, kalitelerini korurken veya hatta aşarken, yoğun muadillerinden daha küçük ve daha hızlıdır.

Modellerinizi cihaza dağıtmak için özellikle önemli olabilecek bu yeni özelliği denemenizi öneririz.