Есть вопрос? Присоединяйтесь к сообществу на форуме TensorFlow. Посетите форум.

Обучение с учетом квантования на примере Keras

Посмотреть на TensorFlow.org Запустить в Google Colab Посмотреть исходный код на GitHub Скачать блокнот

Обзор

Добро пожаловать в комплексный пример обучения квантованию .

Другие страницы

Чтобы узнать, что такое обучение с учетом квантования, и определить, следует ли его использовать (включая то, что поддерживается), см. На странице обзора .

Чтобы быстро найти API-интерфейсы, необходимые для вашего варианта использования (помимо полного квантования модели с 8-битными данными), см. Полное руководство .

Резюме

В этом руководстве вы:

  1. tf.keras модель tf.keras для MNIST с нуля.
  2. Настройте модель, применив обучающий API с учетом квантования, проверьте точность и экспортируйте модель с учетом квантования.
  3. Используйте модель для создания фактически квантованной модели для бэкэнда TFLite.
  4. Посмотрите на неизменную точность TFLite и модели в 4 раза меньшего размера. Чтобы увидеть преимущества задержки на мобильных устройствах, попробуйте примеры TFLite в репозитории приложений TFLite .

Настраивать

 pip uninstall -y tensorflow
 pip install -q tf-nightly
 pip install -q tensorflow-model-optimization
import tempfile
import os

import tensorflow as tf

from tensorflow import keras

Обучите модель для MNIST без обучения с учетом квантования

# Load MNIST dataset
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input image so that each pixel value is between 0 to 1.
train_images = train_images / 255.0
test_images = test_images / 255.0

# Define the model architecture.
model = keras.Sequential([
  keras.layers.InputLayer(input_shape=(28, 28)),
  keras.layers.Reshape(target_shape=(28, 28, 1)),
  keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Flatten(),
  keras.layers.Dense(10)
])

# Train the digit classification model
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(
  train_images,
  train_labels,
  epochs=1,
  validation_split=0.1,
)
1688/1688 [==============================] - 8s 4ms/step - loss: 0.5175 - accuracy: 0.8544 - val_loss: 0.1275 - val_accuracy: 0.9648
<tensorflow.python.keras.callbacks.History at 0x7f37f03b39e8>

Клонирование и точная настройка предварительно обученной модели с обучением с учетом квантования

Определите модель

Вы примените обучение с учетом квантования ко всей модели и увидите это в сводке модели. Все слои теперь имеют префикс "Quant".

Обратите внимание, что результирующая модель учитывает квантование, но не квантуется (например, веса - float32 вместо int8). В следующих разделах показано, как создать квантованную модель из модели с учетом квантования.

В подробном руководстве вы можете увидеть, как квантовать некоторые слои для повышения точности модели.

import tensorflow_model_optimization as tfmot

quantize_model = tfmot.quantization.keras.quantize_model

# q_aware stands for for quantization aware.
q_aware_model = quantize_model(model)

# `quantize_model` requires a recompile.
q_aware_model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

q_aware_model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
quantize_layer (QuantizeLaye (None, 28, 28)            3         
_________________________________________________________________
quant_reshape (QuantizeWrapp (None, 28, 28, 1)         1         
_________________________________________________________________
quant_conv2d (QuantizeWrappe (None, 26, 26, 12)        147       
_________________________________________________________________
quant_max_pooling2d (Quantiz (None, 13, 13, 12)        1         
_________________________________________________________________
quant_flatten (QuantizeWrapp (None, 2028)              1         
_________________________________________________________________
quant_dense (QuantizeWrapper (None, 10)                20295     
=================================================================
Total params: 20,448
Trainable params: 20,410
Non-trainable params: 38
_________________________________________________________________

Обучите и оцените модель относительно базовой линии

Чтобы продемонстрировать точную настройку после обучения модели только для эпохи, выполните точную настройку с обучением с учетом квантования на подмножестве данных обучения.

train_images_subset = train_images[0:1000] # out of 60000
train_labels_subset = train_labels[0:1000]

q_aware_model.fit(train_images_subset, train_labels_subset,
                  batch_size=500, epochs=1, validation_split=0.1)
2/2 [==============================] - 1s 167ms/step - loss: 0.1298 - accuracy: 0.9571 - val_loss: 0.1671 - val_accuracy: 0.9600
<tensorflow.python.keras.callbacks.History at 0x7f376c6db7b8>

В этом примере потеря точности теста после обучения с учетом квантования минимальна или отсутствует по сравнению с базовой линией.

_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)

_, q_aware_model_accuracy = q_aware_model.evaluate(
   test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)
print('Quant test accuracy:', q_aware_model_accuracy)
Baseline test accuracy: 0.9609000086784363
Quant test accuracy: 0.9628999829292297

Создать квантованную модель для бэкэнда TFLite

После этого у вас есть фактически квантованная модель с весами int8 и активациями uint8.

converter = tf.lite.TFLiteConverter.from_keras_model(q_aware_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]

quantized_tflite_model = converter.convert()
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:109: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:109: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
INFO:tensorflow:Assets written to: /tmp/tmpoct8ii0p/assets

Смотрите постоянство точности от TF до TFLite

Определите вспомогательную функцию для оценки модели TF Lite в тестовом наборе данных.

import numpy as np

def evaluate_model(interpreter):
  input_index = interpreter.get_input_details()[0]["index"]
  output_index = interpreter.get_output_details()[0]["index"]

  # Run predictions on every image in the "test" dataset.
  prediction_digits = []
  for i, test_image in enumerate(test_images):
    if i % 1000 == 0:
      print('Evaluated on {n} results so far.'.format(n=i))
    # Pre-processing: add batch dimension and convert to float32 to match with
    # the model's input data format.
    test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
    interpreter.set_tensor(input_index, test_image)

    # Run inference.
    interpreter.invoke()

    # Post-processing: remove batch dimension and find the digit with highest
    # probability.
    output = interpreter.tensor(output_index)
    digit = np.argmax(output()[0])
    prediction_digits.append(digit)

  print('\n')
  # Compare prediction results with ground truth labels to calculate accuracy.
  prediction_digits = np.array(prediction_digits)
  accuracy = (prediction_digits == test_labels).mean()
  return accuracy

Вы оцениваете квантованную модель и видите, что точность от TensorFlow сохраняется до бэкэнда TFLite.

interpreter = tf.lite.Interpreter(model_content=quantized_tflite_model)
interpreter.allocate_tensors()

test_accuracy = evaluate_model(interpreter)

print('Quant TFLite test_accuracy:', test_accuracy)
print('Quant TF test accuracy:', q_aware_model_accuracy)
Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.


Quant TFLite test_accuracy: 0.963
Quant TF test accuracy: 0.9628999829292297

Увидеть в 4 раза меньшую модель из квантования

Вы создаете плавающую модель TFLite, а затем видите, что квантованная модель TFLite в 4 раза меньше.

# Create float TFLite model.
float_converter = tf.lite.TFLiteConverter.from_keras_model(model)
float_tflite_model = float_converter.convert()

# Measure sizes of models.
_, float_file = tempfile.mkstemp('.tflite')
_, quant_file = tempfile.mkstemp('.tflite')

with open(quant_file, 'wb') as f:
  f.write(quantized_tflite_model)

with open(float_file, 'wb') as f:
  f.write(float_tflite_model)

print("Float model in Mb:", os.path.getsize(float_file) / float(2**20))
print("Quantized model in Mb:", os.path.getsize(quant_file) / float(2**20))
INFO:tensorflow:Assets written to: /tmp/tmpy0445k6u/assets
INFO:tensorflow:Assets written to: /tmp/tmpy0445k6u/assets
Float model in Mb: 0.08053970336914062
Quantized model in Mb: 0.02339935302734375

Заключение

В этом руководстве вы увидели, как создавать модели с учетом квантования с помощью TensorFlow Model Optimization Toolkit API, а затем квантованные модели для бэкэнда TFLite.

Вы увидели 4-кратное преимущество сжатия размера модели для модели для MNIST с минимальной разницей в точности. Чтобы увидеть преимущества задержки на мобильных устройствах, попробуйте примеры TFLite в репозитории приложений TFLite .

Мы рекомендуем вам попробовать эту новую возможность, которая может быть особенно важной для развертывания в средах с ограниченными ресурсами.