![]() | ![]() | ![]() | ![]() |
Aperçu
Bienvenue dans un exemple de bout en bout d' élagage de poids basé sur la magnitude.
Autres pages
Pour une introduction à ce qu'est l'élagage et pour déterminer si vous devez l'utiliser (y compris ce qui est pris en charge), consultez la page de présentation.
Pour trouver rapidement les API dont vous avez besoin pour votre cas d'utilisation (au-delà de l'élagage complet d'un modèle avec une parcimonie de 80%), consultez le guide complet .
Sommaire
Dans ce tutoriel, vous allez:
- Former un modèle
tf.keras
pour MNIST à partir de zéro. - Affinez le modèle en appliquant l'API d'élagage et voyez la précision.
- Créez des modèles TF et TFLite 3x plus petits à partir de l'élagage.
- Créez un modèle TFLite 10x plus petit en combinant l'élagage et la quantification post-entraînement.
- Voir la persistance de la précision de TF à TFLite.
Installer
pip install -q tensorflow-model-optimization
import tempfile
import os
import tensorflow as tf
import numpy as np
from tensorflow import keras
%load_ext tensorboard
Former un modèle pour MNIST sans taille
# Load MNIST dataset
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
# Normalize the input image so that each pixel value is between 0 to 1.
train_images = train_images / 255.0
test_images = test_images / 255.0
# Define the model architecture.
model = keras.Sequential([
keras.layers.InputLayer(input_shape=(28, 28)),
keras.layers.Reshape(target_shape=(28, 28, 1)),
keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
keras.layers.MaxPooling2D(pool_size=(2, 2)),
keras.layers.Flatten(),
keras.layers.Dense(10)
])
# Train the digit classification model
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
model.fit(
train_images,
train_labels,
epochs=4,
validation_split=0.1,
)
Epoch 1/4 1688/1688 [==============================] - 7s 4ms/step - loss: 0.3422 - accuracy: 0.9004 - val_loss: 0.1760 - val_accuracy: 0.9498 Epoch 2/4 1688/1688 [==============================] - 7s 4ms/step - loss: 0.1813 - accuracy: 0.9457 - val_loss: 0.1176 - val_accuracy: 0.9698 Epoch 3/4 1688/1688 [==============================] - 7s 4ms/step - loss: 0.1220 - accuracy: 0.9648 - val_loss: 0.0864 - val_accuracy: 0.9770 Epoch 4/4 1688/1688 [==============================] - 7s 4ms/step - loss: 0.0874 - accuracy: 0.9740 - val_loss: 0.0763 - val_accuracy: 0.9787 <tensorflow.python.keras.callbacks.History at 0x7f32cbeb9550>
Évaluez la précision du test de base et enregistrez le modèle pour une utilisation ultérieure.
_, baseline_model_accuracy = model.evaluate(
test_images, test_labels, verbose=0)
print('Baseline test accuracy:', baseline_model_accuracy)
_, keras_file = tempfile.mkstemp('.h5')
tf.keras.models.save_model(model, keras_file, include_optimizer=False)
print('Saved baseline model to:', keras_file)
Baseline test accuracy: 0.972599983215332 Saved baseline model to: /tmp/tmp6quew9ig.h5
Modèle pré-formé avec élagage
Définir le modèle
Vous appliquerez la taille à l'ensemble du modèle et le verrez dans le résumé du modèle.
Dans cet exemple, vous démarrez le modèle avec une parcimonie de 50% (50% de zéros en poids) et vous terminez avec une parcimonie de 80%.
Dans le guide complet , vous pouvez voir comment élaguer certaines couches pour améliorer la précision du modèle.
import tensorflow_model_optimization as tfmot
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude
# Compute end step to finish pruning after 2 epochs.
batch_size = 128
epochs = 2
validation_split = 0.1 # 10% of training set will be used for validation set.
num_images = train_images.shape[0] * (1 - validation_split)
end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs
# Define model for pruning.
pruning_params = {
'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.50,
final_sparsity=0.80,
begin_step=0,
end_step=end_step)
}
model_for_pruning = prune_low_magnitude(model, **pruning_params)
# `prune_low_magnitude` requires a recompile.
model_for_pruning.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
model_for_pruning.summary()
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py:220: Layer.add_variable (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version. Instructions for updating: Please use `layer.add_weight` method instead. Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= prune_low_magnitude_reshape (None, 28, 28, 1) 1 _________________________________________________________________ prune_low_magnitude_conv2d ( (None, 26, 26, 12) 230 _________________________________________________________________ prune_low_magnitude_max_pool (None, 13, 13, 12) 1 _________________________________________________________________ prune_low_magnitude_flatten (None, 2028) 1 _________________________________________________________________ prune_low_magnitude_dense (P (None, 10) 40572 ================================================================= Total params: 40,805 Trainable params: 20,410 Non-trainable params: 20,395 _________________________________________________________________
Former et évaluer le modèle par rapport à la ligne de base
Ajustez avec la taille pour deux époques.
tfmot.sparsity.keras.UpdatePruningStep
est requis pendant la formation, et tfmot.sparsity.keras.PruningSummaries
fournit des journaux pour suivre la progression et le débogage.
logdir = tempfile.mkdtemp()
callbacks = [
tfmot.sparsity.keras.UpdatePruningStep(),
tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]
model_for_pruning.fit(train_images, train_labels,
batch_size=batch_size, epochs=epochs, validation_split=validation_split,
callbacks=callbacks)
Epoch 1/2 1/422 [..............................] - ETA: 0s - loss: 0.2689 - accuracy: 0.8984WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/summary_ops_v2.py:1277: stop (from tensorflow.python.eager.profiler) is deprecated and will be removed after 2020-07-01. Instructions for updating: use `tf.profiler.experimental.stop` instead. 422/422 [==============================] - 3s 7ms/step - loss: 0.1105 - accuracy: 0.9691 - val_loss: 0.1247 - val_accuracy: 0.9682 Epoch 2/2 422/422 [==============================] - 3s 6ms/step - loss: 0.1197 - accuracy: 0.9667 - val_loss: 0.0969 - val_accuracy: 0.9763 <tensorflow.python.keras.callbacks.History at 0x7f32422a9550>
Pour cet exemple, il y a une perte minimale de précision du test après l'élagage, par rapport à la ligne de base.
_, model_for_pruning_accuracy = model_for_pruning.evaluate(
test_images, test_labels, verbose=0)
print('Baseline test accuracy:', baseline_model_accuracy)
print('Pruned test accuracy:', model_for_pruning_accuracy)
Baseline test accuracy: 0.972599983215332 Pruned test accuracy: 0.9689000248908997
Les journaux montrent la progression de la parcimonie sur une base par couche.
%tensorboard --logdir={logdir}
Pour les utilisateurs non-Colab, vous pouvez voirles résultats d'une précédente exécution de ce bloc de code sur TensorBoard.dev .
Créez des modèles 3x plus petits à partir de la taille
tfmot.sparsity.keras.strip_pruning
et l'application d'un algorithme de compression standard (par exemple via gzip) sont nécessaires pour voir les avantages de la compression de l'élagage.
-
strip_pruning
est nécessaire car il supprime chaque tf.Variable dont l'élagage n'a besoin que pendant l'entraînement, ce qui augmenterait autrement la taille du modèle lors de l'inférence - L'application d'un algorithme de compression standard est nécessaire car les matrices de poids sérialisées ont la même taille qu'avant l'élagage. Cependant, l'élagage rend la plupart des poids des zéros, ce qui ajoute une redondance que les algorithmes peuvent utiliser pour compresser davantage le modèle.
Tout d'abord, créez un modèle compressible pour TensorFlow.
model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)
_, pruned_keras_file = tempfile.mkstemp('.h5')
tf.keras.models.save_model(model_for_export, pruned_keras_file, include_optimizer=False)
print('Saved pruned Keras model to:', pruned_keras_file)
Saved pruned Keras model to: /tmp/tmpu92n0irx.h5
Ensuite, créez un modèle compressible pour TFLite.
converter = tf.lite.TFLiteConverter.from_keras_model(model_for_export)
pruned_tflite_model = converter.convert()
_, pruned_tflite_file = tempfile.mkstemp('.tflite')
with open(pruned_tflite_file, 'wb') as f:
f.write(pruned_tflite_model)
print('Saved pruned TFLite model to:', pruned_tflite_file)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version. Instructions for updating: This property should not be used in TensorFlow 2.0, as updates are applied automatically. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version. Instructions for updating: This property should not be used in TensorFlow 2.0, as updates are applied automatically. INFO:tensorflow:Assets written to: /tmp/tmpunez1uhy/assets Saved pruned TFLite model to: /tmp/tmp9oa2swr6.tflite
Définissez une fonction d'assistance pour compresser réellement les modèles via gzip et mesurer la taille zippée.
def get_gzipped_model_size(file):
# Returns size of gzipped model, in bytes.
import os
import zipfile
_, zipped_file = tempfile.mkstemp('.zip')
with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
f.write(file)
return os.path.getsize(zipped_file)
Comparez et voyez que les modèles sont 3 fois plus petits à partir de la taille.
print("Size of gzipped baseline Keras model: %.2f bytes" % (get_gzipped_model_size(keras_file)))
print("Size of gzipped pruned Keras model: %.2f bytes" % (get_gzipped_model_size(pruned_keras_file)))
print("Size of gzipped pruned TFlite model: %.2f bytes" % (get_gzipped_model_size(pruned_tflite_file)))
Size of gzipped baseline Keras model: 78048.00 bytes Size of gzipped pruned Keras model: 25680.00 bytes Size of gzipped pruned TFlite model: 24946.00 bytes
Créez un modèle 10x plus petit en combinant l'élagage et la quantification
Vous pouvez appliquer la quantification post-entraînement au modèle élagué pour des avantages supplémentaires.
converter = tf.lite.TFLiteConverter.from_keras_model(model_for_export)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
quantized_and_pruned_tflite_model = converter.convert()
_, quantized_and_pruned_tflite_file = tempfile.mkstemp('.tflite')
with open(quantized_and_pruned_tflite_file, 'wb') as f:
f.write(quantized_and_pruned_tflite_model)
print('Saved quantized and pruned TFLite model to:', quantized_and_pruned_tflite_file)
print("Size of gzipped baseline Keras model: %.2f bytes" % (get_gzipped_model_size(keras_file)))
print("Size of gzipped pruned and quantized TFlite model: %.2f bytes" % (get_gzipped_model_size(quantized_and_pruned_tflite_file)))
INFO:tensorflow:Assets written to: /tmp/tmpf68nyuwr/assets INFO:tensorflow:Assets written to: /tmp/tmpf68nyuwr/assets Saved quantized and pruned TFLite model to: /tmp/tmp85dhxupl.tflite Size of gzipped baseline Keras model: 78048.00 bytes Size of gzipped pruned and quantized TFlite model: 7663.00 bytes
Voir la persistance de la précision de TF à TFLite
Définissez une fonction d'assistance pour évaluer le modèle TF Lite sur l'ensemble de données de test.
import numpy as np
def evaluate_model(interpreter):
input_index = interpreter.get_input_details()[0]["index"]
output_index = interpreter.get_output_details()[0]["index"]
# Run predictions on ever y image in the "test" dataset.
prediction_digits = []
for i, test_image in enumerate(test_images):
if i % 1000 == 0:
print('Evaluated on {n} results so far.'.format(n=i))
# Pre-processing: add batch dimension and convert to float32 to match with
# the model's input data format.
test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
interpreter.set_tensor(input_index, test_image)
# Run inference.
interpreter.invoke()
# Post-processing: remove batch dimension and find the digit with highest
# probability.
output = interpreter.tensor(output_index)
digit = np.argmax(output()[0])
prediction_digits.append(digit)
print('\n')
# Compare prediction results with ground truth labels to calculate accuracy.
prediction_digits = np.array(prediction_digits)
accuracy = (prediction_digits == test_labels).mean()
return accuracy
Vous évaluez le modèle élagué et quantifié et constatez que la précision de TensorFlow persiste dans le backend TFLite.
interpreter = tf.lite.Interpreter(model_content=quantized_and_pruned_tflite_model)
interpreter.allocate_tensors()
test_accuracy = evaluate_model(interpreter)
print('Pruned and quantized TFLite test_accuracy:', test_accuracy)
print('Pruned TF test accuracy:', model_for_pruning_accuracy)
Evaluated on 0 results so far. Evaluated on 1000 results so far. Evaluated on 2000 results so far. Evaluated on 3000 results so far. Evaluated on 4000 results so far. Evaluated on 5000 results so far. Evaluated on 6000 results so far. Evaluated on 7000 results so far. Evaluated on 8000 results so far. Evaluated on 9000 results so far. Pruned and quantized TFLite test_accuracy: 0.9692 Pruned TF test accuracy: 0.9689000248908997
Conclusion
Dans ce didacticiel, vous avez vu comment créer des modèles fragmentés avec l'API TensorFlow Model Optimization Toolkit pour TensorFlow et TFLite. Vous avez ensuite combiné la taille avec la quantification post-entraînement pour des avantages supplémentaires.
Vous avez créé un modèle 10x plus petit pour MNIST, avec une différence de précision minimale.
Nous vous encourageons à essayer cette nouvelle fonctionnalité, qui peut être particulièrement importante pour le déploiement dans des environnements à ressources limitées.