روز جامعه ML 9 نوامبر است! برای به روز رسانی از TensorFlow، JAX به ما بپیوندید، و بیشتر بیشتر بدانید

راهنمای جامع هرس

مشاهده در TensorFlow.org در Google Colab اجرا کنید مشاهده منبع در GitHub دانلود دفترچه یادداشت

به راهنمای جامع هرس وزن Keras خوش آمدید.

این صفحه موارد استفاده مختلف را مستند می کند و نحوه استفاده از API را برای هر یک نشان می دهد. هنگامی که شما می دانید که رابط های برنامه کاربردی شما نیاز دارید، پیدا کردن پارامترها و جزئیات سطح پایین در اسناد API .

  • اگر شما می خواهید برای دیدن مزایای هرس و چه چیزی حمایت می شود، دیدن نمای کلی .
  • برای یک مثال تنها پایان به پایان، ببینید به عنوان مثال هرس .

موارد استفاده زیر تحت پوشش قرار می گیرد:

  • تعریف و آموزش یک مدل هرس شده.
    • متوالی و کاربردی.
    • حلقه های آموزشی Keras model.fit و سفارشی
  • یک مدل هرس شده را بازرسی و ضدعفونی کنید.
  • یک مدل هرس شده را مستقر کرده و مزایای فشرده سازی را مشاهده کنید.

برای پیکربندی از الگوریتم هرس، به مراجعه tfmot.sparsity.keras.prune_low_magnitude اسناد API.

برپایی

برای یافتن API های مورد نیاز و درک اهداف ، می توانید از خواندن این بخش صرف نظر کنید.

! pip install -q tensorflow-model-optimization

import tensorflow as tf
import numpy as np
import tensorflow_model_optimization as tfmot

%load_ext tensorboard

import tempfile

input_shape = [20]
x_train = np.random.randn(1, 20).astype(np.float32)
y_train = tf.keras.utils.to_categorical(np.random.randn(1), num_classes=20)

def setup_model():
  model = tf.keras.Sequential([
      tf.keras.layers.Dense(20, input_shape=input_shape),
      tf.keras.layers.Flatten()
  ])
  return model

def setup_pretrained_weights():
  model = setup_model()

  model.compile(
      loss=tf.keras.losses.categorical_crossentropy,
      optimizer='adam',
      metrics=['accuracy']
  )

  model.fit(x_train, y_train)

  _, pretrained_weights = tempfile.mkstemp('.tf')

  model.save_weights(pretrained_weights)

  return pretrained_weights

def get_gzipped_model_size(model):
  # Returns size of gzipped model, in bytes.
  import os
  import zipfile

  _, keras_file = tempfile.mkstemp('.h5')
  model.save(keras_file, include_optimizer=False)

  _, zipped_file = tempfile.mkstemp('.zip')
  with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
    f.write(keras_file)

  return os.path.getsize(zipped_file)

setup_model()
pretrained_weights = setup_pretrained_weights()

مدل را تعریف کنید

هرس کامل مدل (متوالی و عملکردی)

نکاتی برای دقت بهتر مدل:

  • سعی کنید "هرس چند لایه" را از هرس لایه هایی که بیشترین دقت را کاهش می دهند ، رد کنید.
  • به طور کلی بهتر است با هرس به پایان برسد ، در مقایسه با آموزش از ابتدا.

را به کل مدل قطار با هرس، درخواست tfmot.sparsity.keras.prune_low_magnitude به مدل.

base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended.

model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)

model_for_pruning.summary()
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py:200: Layer.add_variable (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
Please use `layer.add_weight` method instead.
Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
prune_low_magnitude_dense_2  (None, 20)                822       
_________________________________________________________________
prune_low_magnitude_flatten_ (None, 20)                1         
=================================================================
Total params: 823
Trainable params: 420
Non-trainable params: 403
_________________________________________________________________

هرس چند لایه (متوالی و عملکردی)

هرس یک مدل می تواند بر دقت تأثیر منفی بگذارد. شما می توانید لایه های یک مدل را به صورت انتخابی هرس کنید تا تعادل بین دقت ، سرعت و اندازه مدل را بررسی کنید.

نکاتی برای دقت بهتر مدل:

  • به طور کلی بهتر است با هرس به پایان برسد ، در مقایسه با آموزش از ابتدا.
  • سعی کنید لایه های بعدی را به جای لایه های اول هرس کنید.
  • از هرس لایه های مهم (مانند مکانیسم توجه) خودداری کنید.

بیشتر:

در مثال زیر، آلو تنها Dense لایه.

# Create a base model
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy

# Helper function uses `prune_low_magnitude` to make only the 
# Dense layers train with pruning.
def apply_pruning_to_dense(layer):
  if isinstance(layer, tf.keras.layers.Dense):
    return tfmot.sparsity.keras.prune_low_magnitude(layer)
  return layer

# Use `tf.keras.models.clone_model` to apply `apply_pruning_to_dense` 
# to the layers of the model.
model_for_pruning = tf.keras.models.clone_model(
    base_model,
    clone_function=apply_pruning_to_dense,
)

model_for_pruning.summary()
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.bias
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.
Model: "sequential_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
prune_low_magnitude_dense_3  (None, 20)                822       
_________________________________________________________________
flatten_3 (Flatten)          (None, 20)                0         
=================================================================
Total params: 822
Trainable params: 420
Non-trainable params: 402
_________________________________________________________________

در حالی که این مثال استفاده نوع لایه را به تصمیم بگیرند که چه به هرس، ساده ترین راه برای هرس کردن یک لایه خاص این است که مجموعه آن name اموال، و نگاه برای آن نام در clone_function .

print(base_model.layers[0].name)
dense_3

دقت مدل بیشتر قابل خواندن اما به طور بالقوه کمتر است

این با تنظیم دقیق با هرس سازگار نیست ، به همین دلیل ممکن است از دقت کمتری نسبت به مثالهای فوق که از تنظیم دقیق پشتیبانی می کنند ، برخوردار باشد.

در حالی که prune_low_magnitude می توان در حالی که تعریف الگوی اولیه، بارگیری وزن پس می کند در زیر نمونه هایی کار نمی کند اعمال می شود.

مثال عملکردی

# Use `prune_low_magnitude` to make the `Dense` layer train with pruning.
i = tf.keras.Input(shape=(20,))
x = tfmot.sparsity.keras.prune_low_magnitude(tf.keras.layers.Dense(10))(i)
o = tf.keras.layers.Flatten()(x)
model_for_pruning = tf.keras.Model(inputs=i, outputs=o)

model_for_pruning.summary()
Model: "functional_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 20)]              0         
_________________________________________________________________
prune_low_magnitude_dense_4  (None, 10)                412       
_________________________________________________________________
flatten_4 (Flatten)          (None, 10)                0         
=================================================================
Total params: 412
Trainable params: 210
Non-trainable params: 202
_________________________________________________________________

مثال متوالی

# Use `prune_low_magnitude` to make the `Dense` layer train with pruning.
model_for_pruning = tf.keras.Sequential([
  tfmot.sparsity.keras.prune_low_magnitude(tf.keras.layers.Dense(20, input_shape=input_shape)),
  tf.keras.layers.Flatten()
])

model_for_pruning.summary()
Model: "sequential_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
prune_low_magnitude_dense_5  (None, 20)                822       
_________________________________________________________________
flatten_5 (Flatten)          (None, 20)                0         
=================================================================
Total params: 822
Trainable params: 420
Non-trainable params: 402
_________________________________________________________________

هرس لایه Keras سفارشی را انجام دهید یا قسمتهایی از لایه را برای هرس اصلاح کنید

اشتباه رایج: هرس تعصب معمولا دقت مدل آسیب می رساند بیش از حد.

tfmot.sparsity.keras.PrunableLayer در خدمت دو از موارد استفاده:

  1. یک لایه سفارشی Keras را هرس کنید
  2. قسمتهایی از یک لایه داخلی Keras را برای هرس اصلاح کنید.

برای مثال، پیش فرض API به تنها هرس هسته Dense لایه. مثال زیر نیز تعصب را هرس می کند.

class MyDenseLayer(tf.keras.layers.Dense, tfmot.sparsity.keras.PrunableLayer):

  def get_prunable_weights(self):
    # Prune bias also, though that usually harms model accuracy too much.
    return [self.kernel, self.bias]

# Use `prune_low_magnitude` to make the `MyDenseLayer` layer train with pruning.
model_for_pruning = tf.keras.Sequential([
  tfmot.sparsity.keras.prune_low_magnitude(MyDenseLayer(20, input_shape=input_shape)),
  tf.keras.layers.Flatten()
])

model_for_pruning.summary()
Model: "sequential_5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
prune_low_magnitude_my_dense (None, 20)                843       
_________________________________________________________________
flatten_6 (Flatten)          (None, 20)                0         
=================================================================
Total params: 843
Trainable params: 420
Non-trainable params: 423
_________________________________________________________________

مدل قطار

model.fit

تماس با tfmot.sparsity.keras.UpdatePruningStep در طول آموزش پاسخ به تماس.

برای آموزش اشکال زدایی کمک، استفاده از tfmot.sparsity.keras.PruningSummaries پاسخ به تماس.

# Define the model.
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)

log_dir = tempfile.mkdtemp()
callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep(),
    # Log sparsity and other metrics in Tensorboard.
    tfmot.sparsity.keras.PruningSummaries(log_dir=log_dir)
]

model_for_pruning.compile(
      loss=tf.keras.losses.categorical_crossentropy,
      optimizer='adam',
      metrics=['accuracy']
)

model_for_pruning.fit(
    x_train,
    y_train,
    callbacks=callbacks,
    epochs=2,
)

%tensorboard --logdir={log_dir}
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.bias
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.
Epoch 1/2
1/1 [==============================] - 0s 3ms/step - loss: 1.2485 - accuracy: 0.0000e+00
Epoch 2/2
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/summary_ops_v2.py:1277: stop (from tensorflow.python.eager.profiler) is deprecated and will be removed after 2020-07-01.
Instructions for updating:
use `tf.profiler.experimental.stop` instead.
1/1 [==============================] - 0s 2ms/step - loss: 1.1999 - accuracy: 0.0000e+00

برای غیر COLAB کاربران، شما می توانید ببینید که نتایج از اجرای قبلی این بلوک کد در TensorBoard.dev .

حلقه آموزشی سفارشی

تماس با tfmot.sparsity.keras.UpdatePruningStep در طول آموزش پاسخ به تماس.

برای آموزش اشکال زدایی کمک، استفاده از tfmot.sparsity.keras.PruningSummaries پاسخ به تماس.

# Define the model.
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)

# Boilerplate
loss = tf.keras.losses.categorical_crossentropy
optimizer = tf.keras.optimizers.Adam()
log_dir = tempfile.mkdtemp()
unused_arg = -1
epochs = 2
batches = 1 # example is hardcoded so that the number of batches cannot change.

# Non-boilerplate.
model_for_pruning.optimizer = optimizer
step_callback = tfmot.sparsity.keras.UpdatePruningStep()
step_callback.set_model(model_for_pruning)
log_callback = tfmot.sparsity.keras.PruningSummaries(log_dir=log_dir) # Log sparsity and other metrics in Tensorboard.
log_callback.set_model(model_for_pruning)

step_callback.on_train_begin() # run pruning callback
for _ in range(epochs):
  log_callback.on_epoch_begin(epoch=unused_arg) # run pruning callback
  for _ in range(batches):
    step_callback.on_train_batch_begin(batch=unused_arg) # run pruning callback

    with tf.GradientTape() as tape:
      logits = model_for_pruning(x_train, training=True)
      loss_value = loss(y_train, logits)
      grads = tape.gradient(loss_value, model_for_pruning.trainable_variables)
      optimizer.apply_gradients(zip(grads, model_for_pruning.trainable_variables))

  step_callback.on_epoch_end(batch=unused_arg) # run pruning callback

%tensorboard --logdir={log_dir}
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.bias
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.

برای غیر COLAB کاربران، شما می توانید ببینید که نتایج از اجرای قبلی این بلوک کد در TensorBoard.dev .

بهبود دقت مدل هرس شده

اول، در نگاه tfmot.sparsity.keras.prune_low_magnitude اسناد API به درک آنچه یک برنامه هرس و ریاضی از هر نوع هرس برنامه.

نکات:

  • هنگامی که مدل در حال هرس شدن است میزان یادگیری آن زیاد یا زیاد نیست. در نظر بگیرید برنامه هرس می شود hyperparameter.

  • به عنوان یک آزمون سریع، سعی کنید تجربه با هرس یک مدل به sparsity نهایی در ابتدای آموزش با تنظیم begin_step 0 با tfmot.sparsity.keras.ConstantSparsity برنامه. ممکن است با نتایج خوب خوش شانس باشید.

  • اغلب اوقات هرس نکنید تا به مدل زمان بهبودی بدهید. برنامه هرس یک فرکانس به طور پیش فرض مناسب و معقول فراهم می کند.

  • برای ایده های کلی برای بهبود دقت مدل ، در مورد "تعریف مدل" به دنبال نکاتی در مورد موارد استفاده خود باشید.

بازرسی و ضدعفونی کردن

در مرحله بازرسی باید مرحله بهینه ساز را حفظ کنید. این بدان معناست که در حالی که می توانید از مدلهای Keras HDF5 برای بازرسی استفاده کنید ، نمی توانید از وزن Keras HDF5 استفاده کنید.

# Define the model.
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)

_, keras_model_file = tempfile.mkstemp('.h5')

# Checkpoint: saving the optimizer is necessary (include_optimizer=True is the default).
model_for_pruning.save(keras_model_file, include_optimizer=True)
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.bias
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.

موارد فوق به طور کلی صدق می کند. کد زیر فقط برای قالب مدل HDF5 (نه وزن HDF5 و سایر فرمت ها) مورد نیاز است.

# Deserialize model.
with tfmot.sparsity.keras.prune_scope():
  loaded_model = tf.keras.models.load_model(keras_model_file)

loaded_model.summary()
WARNING:tensorflow:No training configuration found in the save file, so the model was *not* compiled. Compile it manually.
Model: "sequential_8"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
prune_low_magnitude_dense_8  (None, 20)                822       
_________________________________________________________________
prune_low_magnitude_flatten_ (None, 20)                1         
=================================================================
Total params: 823
Trainable params: 420
Non-trainable params: 403
_________________________________________________________________

استقرار مدل هرس شده

مدل صادرات با فشرده سازی اندازه

اشتباه رایج: هر دو strip_pruning و استفاده از یک الگوریتم فشرده سازی استاندارد (به عنوان مثال از طریق از gzip) لازم به دیدن مزایای فشرده سازی از هرس می باشد.

# Define the model.
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)

# Typically you train the model here.

model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)

print("final model")
model_for_export.summary()

print("\n")
print("Size of gzipped pruned model without stripping: %.2f bytes" % (get_gzipped_model_size(model_for_pruning)))
print("Size of gzipped pruned model with stripping: %.2f bytes" % (get_gzipped_model_size(model_for_export)))
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.bias
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.
final model
Model: "sequential_9"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_9 (Dense)              (None, 20)                420       
_________________________________________________________________
flatten_10 (Flatten)         (None, 20)                0         
=================================================================
Total params: 420
Trainable params: 420
Non-trainable params: 0
_________________________________________________________________


Size of gzipped pruned model without stripping: 3299.00 bytes
Size of gzipped pruned model with stripping: 2876.00 bytes

بهینه سازی های خاص سخت افزار

هنگامی که پایانه (Backend) مختلف را قادر می سازد هرس به بهبود تاخیر ، با استفاده از sparsity بلوک می تواند برای سخت افزار خاص بهبود تاخیر.

افزایش اندازه بلوک ، اوج پراکنده ای را که برای دقت مدل هدف قابل دستیابی است ، کاهش می دهد. با وجود این ، تأخیر هنوز می تواند بهبود یابد.

برای جزئیات در مورد آنچه برای sparsity بلوک پشتیبانی، نگاه کنید به tfmot.sparsity.keras.prune_low_magnitude اسناد API.

base_model = setup_model()

# For using intrinsics on a CPU with 128-bit registers, together with 8-bit
# quantized weights, a 1x16 block size is nice because the block perfectly
# fits into the register.
pruning_params = {'block_size': [1, 16]}
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model, **pruning_params)

model_for_pruning.summary()
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.bias
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.
Model: "sequential_10"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
prune_low_magnitude_dense_10 (None, 20)                822       
_________________________________________________________________
prune_low_magnitude_flatten_ (None, 20)                1         
=================================================================
Total params: 823
Trainable params: 420
Non-trainable params: 403
_________________________________________________________________