Ter uma questão? Conecte-se com a comunidade no Fórum TensorFlow Visite o Fórum

Treinamento ciente de quantização no exemplo Keras

Ver no TensorFlow.org Executar no Google Colab Ver fonte no GitHub Baixar caderno

Visão geral

Bem-vindo a um exemplo de ponta a ponta para treinamento ciente de quantização .

Outras páginas

Para obter uma introdução sobre o que é o treinamento baseado em quantização e para determinar se você deve usá-lo (incluindo o que é compatível), consulte a página de visão geral .

Para encontrar rapidamente as APIs de que você precisa para seu caso de uso (além de quantizar totalmente um modelo com 8 bits), consulte o guia completo .

Resumo

Neste tutorial, você irá:

  1. Treine um modelo tf.keras para MNIST do zero.
  2. Ajuste o modelo aplicando a API de treinamento com reconhecimento de quantização, veja a precisão e exporte um modelo com reconhecimento de quantização.
  3. Use o modelo para criar um modelo realmente quantizado para o back-end TFLite.
  4. Veja a persistência de precisão em TFLite e um modelo 4x menor. Para ver os benefícios da latência no celular, experimente os exemplos de TFLite no repositório de aplicativos TFLite .

Configurar

 pip uninstall -y tensorflow
 pip install -q tf-nightly
 pip install -q tensorflow-model-optimization
import tempfile
import os

import tensorflow as tf

from tensorflow import keras

Treine um modelo para MNIST sem treinamento ciente de quantização

# Load MNIST dataset
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input image so that each pixel value is between 0 to 1.
train_images = train_images / 255.0
test_images = test_images / 255.0

# Define the model architecture.
model = keras.Sequential([
  keras.layers.InputLayer(input_shape=(28, 28)),
  keras.layers.Reshape(target_shape=(28, 28, 1)),
  keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Flatten(),
  keras.layers.Dense(10)
])

# Train the digit classification model
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(
  train_images,
  train_labels,
  epochs=1,
  validation_split=0.1,
)
1688/1688 [==============================] - 8s 4ms/step - loss: 0.5175 - accuracy: 0.8544 - val_loss: 0.1275 - val_accuracy: 0.9648
<tensorflow.python.keras.callbacks.History at 0x7f37f03b39e8>

Clone e ajuste o modelo pré-treinado com treinamento ciente de quantização

Defina o modelo

Você aplicará o treinamento com reconhecimento de quantização a todo o modelo e verá isso no resumo do modelo. Todas as camadas agora são prefixadas por "quant".

Observe que o modelo resultante está ciente da quantização, mas não quantizado (por exemplo, os pesos são float32 em vez de int8). As seções a seguir mostram como criar um modelo quantizado a partir do que reconhece a quantização.

No guia completo , você pode ver como quantizar algumas camadas para melhorias na precisão do modelo.

import tensorflow_model_optimization as tfmot

quantize_model = tfmot.quantization.keras.quantize_model

# q_aware stands for for quantization aware.
q_aware_model = quantize_model(model)

# `quantize_model` requires a recompile.
q_aware_model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

q_aware_model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
quantize_layer (QuantizeLaye (None, 28, 28)            3         
_________________________________________________________________
quant_reshape (QuantizeWrapp (None, 28, 28, 1)         1         
_________________________________________________________________
quant_conv2d (QuantizeWrappe (None, 26, 26, 12)        147       
_________________________________________________________________
quant_max_pooling2d (Quantiz (None, 13, 13, 12)        1         
_________________________________________________________________
quant_flatten (QuantizeWrapp (None, 2028)              1         
_________________________________________________________________
quant_dense (QuantizeWrapper (None, 10)                20295     
=================================================================
Total params: 20,448
Trainable params: 20,410
Non-trainable params: 38
_________________________________________________________________

Treine e avalie o modelo em relação à linha de base

Para demonstrar o ajuste fino depois de treinar o modelo para apenas uma época, faça o ajuste fino com treinamento ciente de quantização em um subconjunto dos dados de treinamento.

train_images_subset = train_images[0:1000] # out of 60000
train_labels_subset = train_labels[0:1000]

q_aware_model.fit(train_images_subset, train_labels_subset,
                  batch_size=500, epochs=1, validation_split=0.1)
2/2 [==============================] - 1s 167ms/step - loss: 0.1298 - accuracy: 0.9571 - val_loss: 0.1671 - val_accuracy: 0.9600
<tensorflow.python.keras.callbacks.History at 0x7f376c6db7b8>

Para este exemplo, há perda mínima ou nenhuma perda na precisão do teste após o treinamento ciente de quantização, em comparação com a linha de base.

_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)

_, q_aware_model_accuracy = q_aware_model.evaluate(
   test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)
print('Quant test accuracy:', q_aware_model_accuracy)
Baseline test accuracy: 0.9609000086784363
Quant test accuracy: 0.9628999829292297

Crie um modelo quantizado para o back-end TFLite

Depois disso, você tem um modelo realmente quantizado com pesos int8 e ativações uint8.

converter = tf.lite.TFLiteConverter.from_keras_model(q_aware_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]

quantized_tflite_model = converter.convert()
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:109: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:109: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
INFO:tensorflow:Assets written to: /tmp/tmpoct8ii0p/assets

Veja a persistência de precisão de TF para TFLite

Defina uma função auxiliar para avaliar o modelo TF Lite no conjunto de dados de teste.

import numpy as np

def evaluate_model(interpreter):
  input_index = interpreter.get_input_details()[0]["index"]
  output_index = interpreter.get_output_details()[0]["index"]

  # Run predictions on every image in the "test" dataset.
  prediction_digits = []
  for i, test_image in enumerate(test_images):
    if i % 1000 == 0:
      print('Evaluated on {n} results so far.'.format(n=i))
    # Pre-processing: add batch dimension and convert to float32 to match with
    # the model's input data format.
    test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
    interpreter.set_tensor(input_index, test_image)

    # Run inference.
    interpreter.invoke()

    # Post-processing: remove batch dimension and find the digit with highest
    # probability.
    output = interpreter.tensor(output_index)
    digit = np.argmax(output()[0])
    prediction_digits.append(digit)

  print('\n')
  # Compare prediction results with ground truth labels to calculate accuracy.
  prediction_digits = np.array(prediction_digits)
  accuracy = (prediction_digits == test_labels).mean()
  return accuracy

Você avalia o modelo quantizado e vê que a precisão do TensorFlow persiste no back-end TFLite.

interpreter = tf.lite.Interpreter(model_content=quantized_tflite_model)
interpreter.allocate_tensors()

test_accuracy = evaluate_model(interpreter)

print('Quant TFLite test_accuracy:', test_accuracy)
print('Quant TF test accuracy:', q_aware_model_accuracy)
Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.


Quant TFLite test_accuracy: 0.963
Quant TF test accuracy: 0.9628999829292297

Veja o modelo 4x menor de quantização

Você cria um modelo TFLite flutuante e então vê que o modelo TFLite quantizado é 4x menor.

# Create float TFLite model.
float_converter = tf.lite.TFLiteConverter.from_keras_model(model)
float_tflite_model = float_converter.convert()

# Measure sizes of models.
_, float_file = tempfile.mkstemp('.tflite')
_, quant_file = tempfile.mkstemp('.tflite')

with open(quant_file, 'wb') as f:
  f.write(quantized_tflite_model)

with open(float_file, 'wb') as f:
  f.write(float_tflite_model)

print("Float model in Mb:", os.path.getsize(float_file) / float(2**20))
print("Quantized model in Mb:", os.path.getsize(quant_file) / float(2**20))
INFO:tensorflow:Assets written to: /tmp/tmpy0445k6u/assets
INFO:tensorflow:Assets written to: /tmp/tmpy0445k6u/assets
Float model in Mb: 0.08053970336914062
Quantized model in Mb: 0.02339935302734375

Conclusão

Neste tutorial, você viu como criar modelos com reconhecimento de quantização com a API TensorFlow Model Optimization Toolkit e, em seguida, modelos quantizados para o back-end TFLite.

Você viu um benefício de compressão de tamanho de modelo 4x para um modelo para MNIST, com diferença mínima de precisão. Para ver os benefícios da latência no celular, experimente os exemplos de TFLite no repositório de aplicativos TFLite .

Recomendamos que você experimente esse novo recurso, que pode ser particularmente importante para implantação em ambientes com recursos limitados.