O Dia da Comunidade de ML é dia 9 de novembro! Junte-nos para atualização de TensorFlow, JAX, e mais Saiba mais

Exemplo de treinamento ciente de quantização de preservação de cluster (CQAT) Keras

Ver no TensorFlow.org Executar no Google Colab Ver no GitHub Baixar caderno

Visão geral

Este é um fim exemplo final mostrando o uso do cluster preservar API quantização formação consciente (CQAT), parte do pipeline de otimização de colaboração do TensorFlow modelo de otimização do Toolkit.

Outras páginas

Para uma introdução ao gasoduto e outras técnicas disponíveis, consulte a página de visão geral otimização colaborativo .

Conteúdo

No tutorial, você irá:

  1. Treinar um tf.keras modelo para o conjunto de dados MNIST a partir do zero.
  2. Ajuste o modelo com clustering e veja a precisão.
  3. Aplique QAT e observe a perda de clusters.
  4. Aplique o CQAT e observe se o clustering aplicado anteriormente foi preservado.
  5. Gere um modelo TFLite e observe os efeitos da aplicação de CQAT nele.
  6. Compare a precisão do modelo CQAT alcançada com um modelo quantizado usando a quantização pós-treinamento.

Configurar

Você pode executar este Notebook Jupyter em seu local de virtualenv ou colab . Para mais detalhes sobre a criação de dependências, consulte o guia de instalação .

 pip install -q tensorflow-model-optimization
import tensorflow as tf

import numpy as np
import tempfile
import zipfile
import os

Treine um modelo tf.keras para MNIST sem clustering

# Load MNIST dataset
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input image so that each pixel value is between 0 to 1.
train_images = train_images / 255.0
test_images  = test_images / 255.0

model = tf.keras.Sequential([
  tf.keras.layers.InputLayer(input_shape=(28, 28)),
  tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
  tf.keras.layers.Conv2D(filters=12, kernel_size=(3, 3),
                         activation=tf.nn.relu),
  tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(10)
])

# Train the digit classification model
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(
    train_images,
    train_labels,
    validation_split=0.1,
    epochs=10
)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step
Epoch 1/10
1688/1688 [==============================] - 8s 5ms/step - loss: 0.2827 - accuracy: 0.9211 - val_loss: 0.1057 - val_accuracy: 0.9728
Epoch 2/10
1688/1688 [==============================] - 8s 5ms/step - loss: 0.1070 - accuracy: 0.9696 - val_loss: 0.0768 - val_accuracy: 0.9793
Epoch 3/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0805 - accuracy: 0.9765 - val_loss: 0.0663 - val_accuracy: 0.9823
Epoch 4/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0671 - accuracy: 0.9804 - val_loss: 0.0614 - val_accuracy: 0.9840
Epoch 5/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0592 - accuracy: 0.9827 - val_loss: 0.0594 - val_accuracy: 0.9837
Epoch 6/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0537 - accuracy: 0.9842 - val_loss: 0.0610 - val_accuracy: 0.9838
Epoch 7/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0487 - accuracy: 0.9851 - val_loss: 0.0644 - val_accuracy: 0.9827
Epoch 8/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0437 - accuracy: 0.9862 - val_loss: 0.0624 - val_accuracy: 0.9847
Epoch 9/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0399 - accuracy: 0.9882 - val_loss: 0.0595 - val_accuracy: 0.9847
Epoch 10/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0362 - accuracy: 0.9885 - val_loss: 0.0613 - val_accuracy: 0.9835
<tensorflow.python.keras.callbacks.History at 0x7f24a4cc7990>

Avalie o modelo de linha de base e salve-o para uso posterior

_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)

_, keras_file = tempfile.mkstemp('.h5')
print('Saving model to: ', keras_file)
tf.keras.models.save_model(model, keras_file, include_optimizer=False)
Baseline test accuracy: 0.9814000129699707
Saving model to:  /tmp/tmppd3opqzk.h5

Agrupe e ajuste o modelo com 8 clusters

Aplique o cluster_weights() API para agrupar todo o modelo pré-treinados para demonstrar e observar a sua eficácia na redução do tamanho do modelo ao aplicar zip, mantendo a precisão. Para a melhor forma de usar a API para obter a melhor taxa de compressão, mantendo a sua precisão de destino, consulte o guia completo agrupamento .

Defina o modelo e aplique a API de clustering

O modelo precisa ser pré-treinado antes de usar a API de clustering.

import tensorflow_model_optimization as tfmot

cluster_weights = tfmot.clustering.keras.cluster_weights
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization

clustering_params = {
  'number_of_clusters': 8,
  'cluster_centroids_init': CentroidInitialization.KMEANS_PLUS_PLUS
}

clustered_model = cluster_weights(model, **clustering_params)

# Use smaller learning rate for fine-tuning
opt = tf.keras.optimizers.Adam(learning_rate=1e-5)

clustered_model.compile(
  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
  optimizer=opt,
  metrics=['accuracy'])

clustered_model.summary()
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/array_ops.py:5049: calling gather (from tensorflow.python.ops.array_ops) with validate_indices is deprecated and will be removed in a future version.
Instructions for updating:
The `validate_indices` argument has no effect. Indices are always validated on CPU and never validated on GPU.
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
cluster_reshape (ClusterWeig (None, 28, 28, 1)         0         
_________________________________________________________________
cluster_conv2d (ClusterWeigh (None, 26, 26, 12)        236       
_________________________________________________________________
cluster_max_pooling2d (Clust (None, 13, 13, 12)        0         
_________________________________________________________________
cluster_flatten (ClusterWeig (None, 2028)              0         
_________________________________________________________________
cluster_dense (ClusterWeight (None, 10)                40578     
=================================================================
Total params: 40,814
Trainable params: 20,426
Non-trainable params: 20,388
_________________________________________________________________

Ajuste o modelo e avalie a precisão em relação à linha de base

Ajuste o modelo com clustering para 3 épocas.

# Fine-tune model
clustered_model.fit(
  train_images,
  train_labels,
  epochs=3,
  validation_split=0.1)
Epoch 1/3
1688/1688 [==============================] - 8s 5ms/step - loss: 0.0363 - accuracy: 0.9887 - val_loss: 0.0610 - val_accuracy: 0.9830
Epoch 2/3
1688/1688 [==============================] - 8s 5ms/step - loss: 0.0333 - accuracy: 0.9898 - val_loss: 0.0599 - val_accuracy: 0.9830
Epoch 3/3
1688/1688 [==============================] - 8s 5ms/step - loss: 0.0319 - accuracy: 0.9905 - val_loss: 0.0580 - val_accuracy: 0.9837
<tensorflow.python.keras.callbacks.History at 0x7f240d271990>

Defina funções auxiliares para calcular e imprimir o número de agrupamentos em cada kernel do modelo.

def print_model_weight_clusters(model):

    for layer in model.layers:
        if isinstance(layer, tf.keras.layers.Wrapper):
            weights = layer.trainable_weights
        else:
            weights = layer.weights
        for weight in weights:
            # ignore auxiliary quantization weights
            if "quantize_layer" in weight.name:
                continue
            if "kernel" in weight.name:
                unique_count = len(np.unique(weight))
                print(
                    f"{layer.name}/{weight.name}: {unique_count} clusters "
                )

Verifique se os kernels do modelo foram agrupados corretamente. Precisamos remover o wrapper de cluster primeiro.

stripped_clustered_model = tfmot.clustering.keras.strip_clustering(clustered_model)

print_model_weight_clusters(stripped_clustered_model)
conv2d/kernel:0: 8 clusters 
dense/kernel:0: 8 clusters

Para este exemplo, há uma perda mínima na precisão do teste após o agrupamento, em comparação com a linha de base.

_, clustered_model_accuracy = clustered_model.evaluate(
  test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)
print('Clustered test accuracy:', clustered_model_accuracy)
Baseline test accuracy: 0.9814000129699707
Clustered test accuracy: 0.9800999760627747

Aplique QAT e CQAT e verifique o efeito nos clusters de modelo em ambos os casos

Em seguida, aplicamos o QAT e o QAT de preservação de cluster (CQAT) no modelo em cluster e observamos que o CQAT preserva os clusters de peso em seu modelo em cluster. Note que nós despojado wrappers clusterização do seu modelo com tfmot.clustering.keras.strip_clustering antes de aplicar CQAT API.

# QAT
qat_model = tfmot.quantization.keras.quantize_model(stripped_clustered_model)

qat_model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
print('Train qat model:')
qat_model.fit(train_images, train_labels, batch_size=128, epochs=1, validation_split=0.1)

# CQAT
quant_aware_annotate_model = tfmot.quantization.keras.quantize_annotate_model(
              stripped_clustered_model)
cqat_model = tfmot.quantization.keras.quantize_apply(
              quant_aware_annotate_model,
              tfmot.experimental.combine.Default8BitClusterPreserveQuantizeScheme())

cqat_model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
print('Train cqat model:')
cqat_model.fit(train_images, train_labels, batch_size=128, epochs=1, validation_split=0.1)
Train qat model:
422/422 [==============================] - 4s 8ms/step - loss: 0.0328 - accuracy: 0.9901 - val_loss: 0.0578 - val_accuracy: 0.9840
WARNING:root:Input layer does not contain zero weights, so apply CQAT instead.
WARNING:root:Input layer does not contain zero weights, so apply CQAT instead.
Train cqat model:
WARNING:tensorflow:Gradients do not exist for variables ['conv2d/kernel:0', 'dense/kernel:0'] when minimizing the loss.
WARNING:tensorflow:Gradients do not exist for variables ['conv2d/kernel:0', 'dense/kernel:0'] when minimizing the loss.
WARNING:tensorflow:Gradients do not exist for variables ['conv2d/kernel:0', 'dense/kernel:0'] when minimizing the loss.
WARNING:tensorflow:Gradients do not exist for variables ['conv2d/kernel:0', 'dense/kernel:0'] when minimizing the loss.
422/422 [==============================] - 4s 8ms/step - loss: 0.0310 - accuracy: 0.9908 - val_loss: 0.0592 - val_accuracy: 0.9838
<tensorflow.python.keras.callbacks.History at 0x7f240c60e6d0>
print("QAT Model clusters:")
print_model_weight_clusters(qat_model)
print("CQAT Model clusters:")
print_model_weight_clusters(cqat_model)
QAT Model clusters:
quant_conv2d/conv2d/kernel:0: 108 clusters 
quant_dense/dense/kernel:0: 19931 clusters 
CQAT Model clusters:
quant_conv2d/conv2d/kernel:0: 8 clusters 
quant_dense/dense/kernel:0: 8 clusters

Veja os benefícios da compressão do modelo CQAT

Defina a função auxiliar para obter o arquivo de modelo compactado.

def get_gzipped_model_size(file):
  # It returns the size of the gzipped model in kilobytes.

  _, zipped_file = tempfile.mkstemp('.zip')
  with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
    f.write(file)

  return os.path.getsize(zipped_file)/1000

Observe que este é um modelo pequeno. Aplicar clustering e CQAT a um modelo de produção maior resultaria em uma compressão mais significativa.

# QAT model
converter = tf.lite.TFLiteConverter.from_keras_model(qat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
qat_tflite_model = converter.convert()
qat_model_file = 'qat_model.tflite'
# Save the model.
with open(qat_model_file, 'wb') as f:
    f.write(qat_tflite_model)

# CQAT model
converter = tf.lite.TFLiteConverter.from_keras_model(cqat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
cqat_tflite_model = converter.convert()
cqat_model_file = 'cqat_model.tflite'
# Save the model.
with open(cqat_model_file, 'wb') as f:
    f.write(cqat_tflite_model)

print("QAT model size: ", get_gzipped_model_size(qat_model_file), ' KB')
print("CQAT model size: ", get_gzipped_model_size(cqat_model_file), ' KB')
WARNING:absl:Found untraced functions such as reshape_layer_call_and_return_conditional_losses, reshape_layer_call_fn, conv2d_layer_call_and_return_conditional_losses, conv2d_layer_call_fn, flatten_layer_call_and_return_conditional_losses while saving (showing 5 of 20). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/tmpy_b1e0tx/assets
INFO:tensorflow:Assets written to: /tmp/tmpy_b1e0tx/assets
WARNING:absl:Found untraced functions such as reshape_layer_call_and_return_conditional_losses, reshape_layer_call_fn, conv2d_layer_call_and_return_conditional_losses, conv2d_layer_call_fn, flatten_layer_call_and_return_conditional_losses while saving (showing 5 of 20). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/tmp282pfq0n/assets
INFO:tensorflow:Assets written to: /tmp/tmp282pfq0n/assets
QAT model size:  16.685  KB
CQAT model size:  10.121  KB

Veja a persistência de precisão de TF para TFLite

Defina uma função auxiliar para avaliar o modelo TFLite no conjunto de dados de teste.

def eval_model(interpreter):
  input_index = interpreter.get_input_details()[0]["index"]
  output_index = interpreter.get_output_details()[0]["index"]

  # Run predictions on every image in the "test" dataset.
  prediction_digits = []
  for i, test_image in enumerate(test_images):
    if i % 1000 == 0:
      print(f"Evaluated on {i} results so far.")
    # Pre-processing: add batch dimension and convert to float32 to match with
    # the model's input data format.
    test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
    interpreter.set_tensor(input_index, test_image)

    # Run inference.
    interpreter.invoke()

    # Post-processing: remove batch dimension and find the digit with highest
    # probability.
    output = interpreter.tensor(output_index)
    digit = np.argmax(output()[0])
    prediction_digits.append(digit)

  print('\n')
  # Compare prediction results with ground truth labels to calculate accuracy.
  prediction_digits = np.array(prediction_digits)
  accuracy = (prediction_digits == test_labels).mean()
  return accuracy

Você avalia o modelo, que foi agrupado e quantizado, e vê que a precisão do TensorFlow persiste no back-end TFLite.

interpreter = tf.lite.Interpreter(cqat_model_file)
interpreter.allocate_tensors()

cqat_test_accuracy = eval_model(interpreter)

print('Clustered and quantized TFLite test_accuracy:', cqat_test_accuracy)
print('Clustered TF test accuracy:', clustered_model_accuracy)
Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.


Clustered and quantized TFLite test_accuracy: 0.9795
Clustered TF test accuracy: 0.9800999760627747

Aplicar a quantização pós-treinamento e comparar com o modelo CQAT

Em seguida, usamos a quantização pós-treinamento (sem ajuste fino) no modelo agrupado e verificamos sua precisão em relação ao modelo CQAT. Isso demonstra por que você precisaria usar o CQAT para melhorar a precisão do modelo quantizado.

Primeiro, defina um gerador para o conjunto de dados de calibração das primeiras 1000 imagens de treinamento.

def mnist_representative_data_gen():
  for image in train_images[:1000]:  
    image = np.expand_dims(image, axis=0).astype(np.float32)
    yield [image]

Quantize o modelo e compare a precisão com o modelo CQAT adquirido anteriormente. Observe que o modelo quantizado com ajuste fino atinge maior precisão.

converter = tf.lite.TFLiteConverter.from_keras_model(stripped_clustered_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = mnist_representative_data_gen
post_training_tflite_model = converter.convert()
post_training_model_file = 'post_training_model.tflite'
# Save the model.
with open(post_training_model_file, 'wb') as f:
    f.write(post_training_tflite_model)

# Compare accuracy
interpreter = tf.lite.Interpreter(post_training_model_file)
interpreter.allocate_tensors()

post_training_test_accuracy = eval_model(interpreter)

print('CQAT TFLite test_accuracy:', cqat_test_accuracy)
print('Post-training (no fine-tuning) TF test accuracy:', post_training_test_accuracy)
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
INFO:tensorflow:Assets written to: /tmp/tmp6vw7f4l2/assets
INFO:tensorflow:Assets written to: /tmp/tmp6vw7f4l2/assets
Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.


CQAT TFLite test_accuracy: 0.9795
Post-training (no fine-tuning) TF test accuracy: 0.9804

Conclusão

Neste tutorial, você aprendeu como criar um modelo, aglomerar-lo usando as cluster_weights() API, e aplicar o treinamento de reconhecimento de cluster preservar quantização (CQAT) para preservar os clusters ao usar QAT. O modelo CQAT final foi comparado ao QAT para mostrar que os clusters são preservados no primeiro e perdidos no segundo. Em seguida, os modelos foram convertidos para TFLite para mostrar os benefícios de compressão do encadeamento de clustering e técnicas de otimização de modelo CQAT e o modelo TFLite foi avaliado para garantir que a precisão persiste no backend TFLite. Finalmente, o modelo CQAT foi comparado a um modelo agrupado quantizado obtido usando a API de quantização pós-treinamento para demonstrar a vantagem do CQAT em recuperar a perda de precisão da quantização normal.