این صفحه به‌وسیله ‏Cloud Translation API‏ ترجمه شده است.
Switch to English

راهنمای جامع هرس

مشاهده در TensorFlow.org در Google Colab اجرا کنید مشاهده منبع در GitHub دانلود دفترچه یادداشت

به راهنمای جامع هرس وزن کراس خوش آمدید.

این صفحه موارد مختلف استفاده را مستند می کند و نحوه استفاده از API را برای هر یک نشان می دهد. هنگامی که دانستید به کدام API ها نیاز دارید ، پارامترها و جزئیات سطح پایین را در اسناد API پیدا کنید .

  • اگر می خواهید مزایای هرس و موارد پشتیبانی شده را ببینید ، به نمای کلی مراجعه کنید.
  • برای یک نمونه منفرد از انتها به انتها ، به نمونه هرس مراجعه کنید.

موارد استفاده زیر پوشش داده شده است:

  • مدل هرس شده را تعریف و آموزش دهید.
    • متوالی و عملکردی.
    • Keras مدل. مناسب و حلقه های آموزش سفارشی
  • یک مدل هرس شده را بازرسی کرده و از حالت دلخواه خارج کنید.
  • یک مدل هرس شده را پیاده کنید و مزایای فشرده سازی را ببینید.

برای پیکربندی الگوریتم هرس ، به اسناد API tfmot.sparsity.keras.prune_low_magnitude مراجعه کنید.

برپایی

برای یافتن API های مورد نیاز و درک اهداف ، می توانید اجرا کنید اما از خواندن این بخش صرف نظر کنید.

! pip install -q tensorflow-model-optimization

import tensorflow as tf
import numpy as np
import tensorflow_model_optimization as tfmot

%load_ext tensorboard

import tempfile

input_shape = [20]
x_train = np.random.randn(1, 20).astype(np.float32)
y_train = tf.keras.utils.to_categorical(np.random.randn(1), num_classes=20)

def setup_model():
  model = tf.keras.Sequential([
      tf.keras.layers.Dense(20, input_shape=input_shape),
      tf.keras.layers.Flatten()
  ])
  return model

def setup_pretrained_weights():
  model = setup_model()

  model.compile(
      loss=tf.keras.losses.categorical_crossentropy,
      optimizer='adam',
      metrics=['accuracy']
  )

  model.fit(x_train, y_train)

  _, pretrained_weights = tempfile.mkstemp('.tf')

  model.save_weights(pretrained_weights)

  return pretrained_weights

def get_gzipped_model_size(model):
  # Returns size of gzipped model, in bytes.
  import os
  import zipfile

  _, keras_file = tempfile.mkstemp('.h5')
  model.save(keras_file, include_optimizer=False)

  _, zipped_file = tempfile.mkstemp('.zip')
  with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
    f.write(keras_file)

  return os.path.getsize(zipped_file)

setup_model()
pretrained_weights = setup_pretrained_weights()
WARNING: You are using pip version 20.2.2; however, version 20.2.3 is available.
You should consider upgrading via the '/tmpfs/src/tf_docs_env/bin/python -m pip install --upgrade pip' command.
1/1 [==============================] - 0s 2ms/step - loss: 1.1999 - accuracy: 0.0000e+00

مدل را تعریف کنید

مدل کامل هرس (متوالی و عملکردی)

نکاتی برای دقت بهتر مدل:

  • "هرس بعضی از لایه ها" را امتحان کنید تا از هرس لایه هایی که بیشترین دقت را کاهش می دهند ، صرف نظر کنید.
  • به طور كلی بهتر است كه بر خلاف آموزش از ابتدا ، هرس را با هرس انجام دهید.

برای ساخت قطار کل مدل با هرس ، tfmot.sparsity.keras.prune_low_magnitude به مدل اعمال کنید.

base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended.

model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)

model_for_pruning.summary()
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py:200: Layer.add_variable (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
Please use `layer.add_weight` method instead.
Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
prune_low_magnitude_dense_2  (None, 20)                822       
_________________________________________________________________
prune_low_magnitude_flatten_ (None, 20)                1         
=================================================================
Total params: 823
Trainable params: 420
Non-trainable params: 403
_________________________________________________________________

هر لایه را هرس کنید (ترتیبی و عملکردی)

هرس یک مدل می تواند تأثیر منفی بر دقت داشته باشد. می توانید لایه های یک مدل را به طور انتخابی هرس کنید تا معامله بین دقت ، سرعت و اندازه مدل را بررسی کنید.

نکاتی برای دقت بهتر مدل:

  • به طور كلی بهتر است كه بر خلاف آموزش از ابتدا ، هرس را با هرس انجام دهید.
  • سعی کنید به جای لایه های اول ، لایه های بعدی را هرس کنید.
  • از هرس لایه های حیاتی (مثلا مکانیزم توجه) خودداری کنید.

بیشتر :

در مثال زیر ، فقط لایه های Dense را هرس کنید.

# Create a base model
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy

# Helper function uses `prune_low_magnitude` to make only the 
# Dense layers train with pruning.
def apply_pruning_to_dense(layer):
  if isinstance(layer, tf.keras.layers.Dense):
    return tfmot.sparsity.keras.prune_low_magnitude(layer)
  return layer

# Use `tf.keras.models.clone_model` to apply `apply_pruning_to_dense` 
# to the layers of the model.
model_for_pruning = tf.keras.models.clone_model(
    base_model,
    clone_function=apply_pruning_to_dense,
)

model_for_pruning.summary()
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.bias
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.
Model: "sequential_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
prune_low_magnitude_dense_3  (None, 20)                822       
_________________________________________________________________
flatten_3 (Flatten)          (None, 20)                0         
=================================================================
Total params: 822
Trainable params: 420
Non-trainable params: 402
_________________________________________________________________

در حالی که این مثال از نوع لایه برای تصمیم گیری در مورد هرس استفاده می کند ، ساده ترین راه برای هرس یک لایه خاص تنظیم ویژگی name آن و جستجوی آن نام در تابع clone_function .

print(base_model.layers[0].name)
dense_3

دقت مدل بیشتر خوانا اما بالقوه کمتر

این با تنظیم دقیق با هرس سازگار نیست ، به همین دلیل ممکن است از دقت کمتری نسبت به مثالهای بالا که از تنظیم دقیق پشتیبانی می کنند ، برخوردار باشد.

در حالی که prune_low_magnitude می توان هنگام تعریف مدل اولیه اعمال کرد ، بارگذاری اوزان پس از آن در مثالهای زیر کار نمی کند.

مثال عملکردی

# Use `prune_low_magnitude` to make the `Dense` layer train with pruning.
i = tf.keras.Input(shape=(20,))
x = tfmot.sparsity.keras.prune_low_magnitude(tf.keras.layers.Dense(10))(i)
o = tf.keras.layers.Flatten()(x)
model_for_pruning = tf.keras.Model(inputs=i, outputs=o)

model_for_pruning.summary()
Model: "functional_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 20)]              0         
_________________________________________________________________
prune_low_magnitude_dense_4  (None, 10)                412       
_________________________________________________________________
flatten_4 (Flatten)          (None, 10)                0         
=================================================================
Total params: 412
Trainable params: 210
Non-trainable params: 202
_________________________________________________________________

مثال متوالی

# Use `prune_low_magnitude` to make the `Dense` layer train with pruning.
model_for_pruning = tf.keras.Sequential([
  tfmot.sparsity.keras.prune_low_magnitude(tf.keras.layers.Dense(20, input_shape=input_shape)),
  tf.keras.layers.Flatten()
])

model_for_pruning.summary()
Model: "sequential_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
prune_low_magnitude_dense_5  (None, 20)                822       
_________________________________________________________________
flatten_5 (Flatten)          (None, 20)                0         
=================================================================
Total params: 822
Trainable params: 420
Non-trainable params: 402
_________________________________________________________________

لایه سفارشی Keras را هرس کنید یا قسمتهایی از لایه را برای هرس اصلاح کنید

اشتباه رایج: هرس تعصب معمولاً به دقت مدل خیلی آسیب می رساند.

tfmot.sparsity.keras.PrunableLayer دو مورد استفاده را ارائه می دهد:

  1. یک لایه سفارشی Keras را هرس کنید
  2. قسمت های یک لایه داخلی Keras را برای هرس اصلاح کنید.

به عنوان مثال ، API به طور پیش فرض فقط هرس هسته لایه Dense . مثال زیر نیز تعصب را قطع می کند.

class MyDenseLayer(tf.keras.layers.Dense, tfmot.sparsity.keras.PrunableLayer):

  def get_prunable_weights(self):
    # Prune bias also, though that usually harms model accuracy too much.
    return [self.kernel, self.bias]

# Use `prune_low_magnitude` to make the `MyDenseLayer` layer train with pruning.
model_for_pruning = tf.keras.Sequential([
  tfmot.sparsity.keras.prune_low_magnitude(MyDenseLayer(20, input_shape=input_shape)),
  tf.keras.layers.Flatten()
])

model_for_pruning.summary()

Model: "sequential_5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
prune_low_magnitude_my_dense (None, 20)                843       
_________________________________________________________________
flatten_6 (Flatten)          (None, 20)                0         
=================================================================
Total params: 843
Trainable params: 420
Non-trainable params: 423
_________________________________________________________________

مدل قطار

مدل. مناسب

در طول آموزش با tfmot.sparsity.keras.UpdatePruningStep تماس بگیرید.

برای کمک به آموزش اشکال زدایی ، از tfmot.sparsity.keras.PruningSummaries استفاده کنید.

# Define the model.
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)

log_dir = tempfile.mkdtemp()
callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep(),
    # Log sparsity and other metrics in Tensorboard.
    tfmot.sparsity.keras.PruningSummaries(log_dir=log_dir)
]

model_for_pruning.compile(
      loss=tf.keras.losses.categorical_crossentropy,
      optimizer='adam',
      metrics=['accuracy']
)

model_for_pruning.fit(
    x_train,
    y_train,
    callbacks=callbacks,
    epochs=2,
)

%tensorboard --logdir={log_dir}
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.bias
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.
Epoch 1/2
1/1 [==============================] - 0s 3ms/step - loss: 1.2485 - accuracy: 0.0000e+00
Epoch 2/2
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/summary_ops_v2.py:1277: stop (from tensorflow.python.eager.profiler) is deprecated and will be removed after 2020-07-01.
Instructions for updating:
use `tf.profiler.experimental.stop` instead.
1/1 [==============================] - 0s 2ms/step - loss: 1.1999 - accuracy: 0.0000e+00

برای کاربران غیر Colab ، می توانید نتایج اجرای قبلی این بلوک کد را در TensorBoard.dev مشاهده کنید .

حلقه آموزش سفارشی

در طول آموزش با tfmot.sparsity.keras.UpdatePruningStep تماس بگیرید.

برای کمک به آموزش اشکال زدایی ، از پاسخ tfmot.sparsity.keras.PruningSummaries استفاده کنید.

# Define the model.
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)

# Boilerplate
loss = tf.keras.losses.categorical_crossentropy
optimizer = tf.keras.optimizers.Adam()
log_dir = tempfile.mkdtemp()
unused_arg = -1
epochs = 2
batches = 1 # example is hardcoded so that the number of batches cannot change.

# Non-boilerplate.
model_for_pruning.optimizer = optimizer
step_callback = tfmot.sparsity.keras.UpdatePruningStep()
step_callback.set_model(model_for_pruning)
log_callback = tfmot.sparsity.keras.PruningSummaries(log_dir=log_dir) # Log sparsity and other metrics in Tensorboard.
log_callback.set_model(model_for_pruning)

step_callback.on_train_begin() # run pruning callback
for _ in range(epochs):
  log_callback.on_epoch_begin(epoch=unused_arg) # run pruning callback
  for _ in range(batches):
    step_callback.on_train_batch_begin(batch=unused_arg) # run pruning callback

    with tf.GradientTape() as tape:
      logits = model_for_pruning(x_train, training=True)
      loss_value = loss(y_train, logits)
      grads = tape.gradient(loss_value, model_for_pruning.trainable_variables)
      optimizer.apply_gradients(zip(grads, model_for_pruning.trainable_variables))

  step_callback.on_epoch_end(batch=unused_arg) # run pruning callback

%tensorboard --logdir={log_dir}
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.bias
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.

برای کاربران غیر Colab ، می توانید نتایج اجرای قبلی این بلوک کد را در TensorBoard.dev مشاهده کنید .

دقت مدل هرس شده را بهبود ببخشید

ابتدا به اسناد tfmot.sparsity.keras.prune_low_magnitude API نگاه کنید تا بفهمید برنامه هرس چیست و ریاضیات هر نوع برنامه هرس چیست.

نکات :

  • هنگام هرس مدل میزان یادگیری داشته باشید که خیلی زیاد یا خیلی پایین نباشد. برنامه هرس را یک ابر پارامتر در نظر بگیرید.

  • به عنوان یک تست سریع ، سعی کنید هرس یک مدل را برای پراکندگی نهایی در ابتدای آموزش با تنظیم begin_step روی 0 با یک برنامه tfmot.sparsity.keras.ConstantSparsity . ممکن است با نتایج خوب خوش شانس باشید.

  • خیلی مرتب هرس نکنید تا به مدل زمان دهید تا بازیابی شود. برنامه هرس فرکانس پیش فرض مناسبی را فراهم می کند.

  • برای ایده های کلی برای بهبود دقت مدل ، به دنبال نکاتی در مورد (موارد) استفاده خود در بخش "تعریف مدل" باشید.

بازرسی کنید و از حالت دلخواه خارج شوید

شما باید مرحله بهینه ساز را در هنگام checkpoint حفظ کنید. این بدان معنی است که در حالی که می توانید از مدل های Keras HDF5 برای ایست بازرسی استفاده کنید ، نمی توانید از وزنه های Keras HDF5 استفاده کنید.

# Define the model.
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)

_, keras_model_file = tempfile.mkstemp('.h5')

# Checkpoint: saving the optimizer is necessary (include_optimizer=True is the default).
model_for_pruning.save(keras_model_file, include_optimizer=True)
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.bias
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.

موارد فوق به طور کلی اعمال می شود. کد زیر فقط برای قالب مدل HDF5 مورد نیاز است (نه وزن های HDF5 و سایر قالب ها).

# Deserialize model.
with tfmot.sparsity.keras.prune_scope():
  loaded_model = tf.keras.models.load_model(keras_model_file)

loaded_model.summary()
WARNING:tensorflow:No training configuration found in the save file, so the model was *not* compiled. Compile it manually.
Model: "sequential_8"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
prune_low_magnitude_dense_8  (None, 20)                822       
_________________________________________________________________
prune_low_magnitude_flatten_ (None, 20)                1         
=================================================================
Total params: 823
Trainable params: 420
Non-trainable params: 403
_________________________________________________________________

استفاده از مدل هرس شده

صادرات مدل با فشرده سازی اندازه

اشتباه رایج : هر دو strip_pruning و استفاده از الگوریتم استاندارد فشرده سازی (به عنوان مثال از طریق gzip) برای دیدن مزایای فشرده سازی هرس ضروری است.

# Define the model.
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)

# Typically you train the model here.

model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)

print("final model")
model_for_export.summary()

print("\n")
print("Size of gzipped pruned model without stripping: %.2f bytes" % (get_gzipped_model_size(model_for_pruning)))
print("Size of gzipped pruned model with stripping: %.2f bytes" % (get_gzipped_model_size(model_for_export)))
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.bias
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.
final model
Model: "sequential_9"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_9 (Dense)              (None, 20)                420       
_________________________________________________________________
flatten_10 (Flatten)         (None, 20)                0         
=================================================================
Total params: 420
Trainable params: 420
Non-trainable params: 0
_________________________________________________________________


Size of gzipped pruned model without stripping: 3299.00 bytes
Size of gzipped pruned model with stripping: 2876.00 bytes

بهینه سازی های خاص سخت افزار

هنگامی که باطن های مختلف هرس را برای بهبود تأخیر فعال می کنند ، استفاده از پراکندگی بلوک می تواند تأخیر را برای سخت افزار خاصی بهبود بخشد.

افزایش اندازه بلوک ، اوج پراکندگی قابل دستیابی برای دقت مدل هدف را کاهش می دهد. با وجود این ، تأخیر هنوز هم می تواند بهبود یابد.

برای جزئیات بیشتر در مورد آنچه که برای tfmot.sparsity.keras.prune_low_magnitude بلوک پشتیبانی می شود ، به اسناد tfmot.sparsity.keras.prune_low_magnitude API مراجعه کنید.

base_model = setup_model()

# For using intrinsics on a CPU with 128-bit registers, together with 8-bit
# quantized weights, a 1x16 block size is nice because the block perfectly
# fits into the register.
pruning_params = {'block_size': [1, 16]}
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model, **pruning_params)

model_for_pruning.summary()
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.bias
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.
Model: "sequential_10"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
prune_low_magnitude_dense_10 (None, 20)                822       
_________________________________________________________________
prune_low_magnitude_flatten_ (None, 20)                1         
=================================================================
Total params: 823
Trainable params: 420
Non-trainable params: 403
_________________________________________________________________