ترجمت واجهة Cloud Translation API‏ هذه الصفحة.
Switch to English

التدريب الموزع مع Keras

عرض على TensorFlow.org تشغيل في Google Colab عرض المصدر على جيثب تحميل دفتر

نظرة عامة

توفر tf.distribute.Strategy API فكرة مجردة لتوزيع tf.distribute.Strategy عبر وحدات معالجة متعددة. الهدف هو السماح للمستخدمين بتمكين التدريب الموزع باستخدام النماذج الحالية ورمز التدريب ، مع الحد الأدنى من التغييرات.

يستخدم هذا البرنامج التعليمي tf.distribute.MirroredStrategy ، التي تقوم tf.distribute.MirroredStrategy المتماثل في الرسم البياني مع التدريب المتزامن على العديد من وحدات معالجة الرسومات على جهاز واحد. بشكل أساسي ، يقوم بنسخ جميع متغيرات النموذج لكل معالج. بعد ذلك ، تستخدم تقنية all-Red لدمج التدرجات اللونية من جميع المعالجات وتطبيق القيمة المجمعة على جميع نسخ النموذج.

MirroredStrategy هي واحدة من عدة استراتيجيات توزيع متاحة في TensorFlow core. يمكنك قراءة المزيد من الاستراتيجيات في دليل إستراتيجية التوزيع .

واجهة برمجة تطبيقات Keras

يستخدم هذا المثال واجهة برمجة تطبيقات tf.keras لبناء النموذج وحلقة التدريب. للحصول على حلقات تدريب مخصصة ، راجع البرنامج التعليمي tf.distribute.Strategy مع حلقات التدريب .

تبعيات الاستيراد

# Import TensorFlow and TensorFlow Datasets

import tensorflow_datasets as tfds
import tensorflow as tf

import os
print(tf.__version__)
2.3.0

قم بتنزيل مجموعة البيانات

قم بتنزيل مجموعة بيانات MNIST وحملها من مجموعات بيانات TensorFlow . يؤدي هذا إلى إرجاع مجموعة بيانات بتنسيق tf.data .

with_info الإعداد with_info إلى True البيانات الوصفية لمجموعة البيانات بأكملها ، والتي يتم حفظها هنا info . من بين أشياء أخرى ، يتضمن كائن البيانات الوصفية هذا عدد أمثلة القطار والاختبار.

datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)

mnist_train, mnist_test = datasets['train'], datasets['test']
Downloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /home/kbuilder/tensorflow_datasets/mnist/3.0.1...

Warning:absl:Dataset mnist is hosted on GCS. It will automatically be downloaded to your
local data directory. If you'd instead prefer to read directly from our public
GCS bucket (recommended if you're running on GCP), you can instead pass
`try_gcs=True` to `tfds.load` or set `data_dir=gs://tfds-data/datasets`.


Dataset mnist downloaded and prepared to /home/kbuilder/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.

تحديد استراتيجية التوزيع

إنشاء كائن MirroredStrategy . سيعالج هذا التوزيع ، ويوفر مدير سياق ( tf.distribute.MirroredStrategy.scope ) لبناء نموذجك بالداخل.

strategy = tf.distribute.MirroredStrategy()
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
Number of devices: 1

إعداد خط أنابيب الإدخال

عند تدريب نموذج باستخدام وحدات معالجة رسومات متعددة ، يمكنك استخدام قوة الحوسبة الإضافية بشكل فعال عن طريق زيادة حجم الدُفعة. بشكل عام ، استخدم أكبر حجم للدفعة يناسب ذاكرة وحدة معالجة الرسومات ، وقم بضبط معدل التعلم وفقًا لذلك.

# You can also do info.splits.total_num_examples to get the total
# number of examples in the dataset.

num_train_examples = info.splits['train'].num_examples
num_test_examples = info.splits['test'].num_examples

BUFFER_SIZE = 10000

BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

يجب تسوية قيم البكسل ، التي تتراوح من 0 إلى 255 ، إلى النطاق 0-1 . حدد هذا المقياس في دالة.

def scale(image, label):
  image = tf.cast(image, tf.float32)
  image /= 255

  return image, label

قم بتطبيق هذه الوظيفة على بيانات التدريب والاختبار ، وخلط بيانات التدريب ، ثم قم بدفعها للتدريب . لاحظ أننا نحتفظ أيضًا بذاكرة تخزين مؤقت لبيانات التدريب في الذاكرة لتحسين الأداء.

train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

قم بإنشاء النموذج

خلق وتجميع نموذج Keras في سياق strategy.scope .

with strategy.scope():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10)
  ])

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=['accuracy'])

تحديد عمليات الاسترجاعات

عمليات الاسترجاعات المستخدمة هنا هي:

  • TensorBoard : يكتب رد الاتصال هذا سجلًا لـ TensorBoard والذي يسمح لك بتصور الرسوم البيانية.
  • Model Checkpoint : يحفظ رد الاتصال هذا النموذج بعد كل حقبة.
  • جدولة معدل التعلم : باستخدام رد الاتصال هذا ، يمكنك جدولة معدل التعلم للتغيير بعد كل فترة / دفعة.

لأغراض التوضيح ، أضف رد اتصال طباعة لعرض معدل التعلم في دفتر الملاحظات.

# Define the checkpoint directory to store the checkpoints

checkpoint_dir = './training_checkpoints'
# Name of the checkpoint files
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")
# Function for decaying the learning rate.
# You can define any decay function you need.
def decay(epoch):
  if epoch < 3:
    return 1e-3
  elif epoch >= 3 and epoch < 7:
    return 1e-4
  else:
    return 1e-5
# Callback for printing the LR at the end of each epoch.
class PrintLR(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    print('\nLearning rate for epoch {} is {}'.format(epoch + 1,
                                                      model.optimizer.lr.numpy()))
callbacks = [
    tf.keras.callbacks.TensorBoard(log_dir='./logs'),
    tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,
                                       save_weights_only=True),
    tf.keras.callbacks.LearningRateScheduler(decay),
    PrintLR()
]

تدريب وتقييم

الآن ، قم بتدريب النموذج بالطريقة المعتادة ، واستدعاء fit على النموذج والمرور في مجموعة البيانات التي تم إنشاؤها في بداية البرنامج التعليمي. هذه الخطوة هي نفسها سواء كنت تقوم بتوزيع التدريب أم لا.

model.fit(train_dataset, epochs=12, callbacks=callbacks)
Epoch 1/12
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:601: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Iterator.get_next_as_optional()` instead.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:601: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Iterator.get_next_as_optional()` instead.

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

  1/938 [..............................] - ETA: 0s - loss: 2.3083 - accuracy: 0.0156WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/summary_ops_v2.py:1277: stop (from tensorflow.python.eager.profiler) is deprecated and will be removed after 2020-07-01.
Instructions for updating:
use `tf.profiler.experimental.stop` instead.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/summary_ops_v2.py:1277: stop (from tensorflow.python.eager.profiler) is deprecated and will be removed after 2020-07-01.
Instructions for updating:
use `tf.profiler.experimental.stop` instead.

Warning:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0047s vs `on_train_batch_end` time: 0.0316s). Check your callbacks.

Warning:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0047s vs `on_train_batch_end` time: 0.0316s). Check your callbacks.

932/938 [============================>.] - ETA: 0s - loss: 0.1947 - accuracy: 0.9441
Learning rate for epoch 1 is 0.0010000000474974513
938/938 [==============================] - 4s 4ms/step - loss: 0.1939 - accuracy: 0.9442
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

Epoch 2/12
935/938 [============================>.] - ETA: 0s - loss: 0.0636 - accuracy: 0.9811
Learning rate for epoch 2 is 0.0010000000474974513
938/938 [==============================] - 2s 3ms/step - loss: 0.0634 - accuracy: 0.9812
Epoch 3/12
936/938 [============================>.] - ETA: 0s - loss: 0.0438 - accuracy: 0.9864
Learning rate for epoch 3 is 0.0010000000474974513
938/938 [==============================] - 2s 3ms/step - loss: 0.0439 - accuracy: 0.9864
Epoch 4/12
937/938 [============================>.] - ETA: 0s - loss: 0.0234 - accuracy: 0.9936
Learning rate for epoch 4 is 9.999999747378752e-05
938/938 [==============================] - 2s 3ms/step - loss: 0.0234 - accuracy: 0.9936
Epoch 5/12
932/938 [============================>.] - ETA: 0s - loss: 0.0204 - accuracy: 0.9948
Learning rate for epoch 5 is 9.999999747378752e-05
938/938 [==============================] - 3s 3ms/step - loss: 0.0204 - accuracy: 0.9948
Epoch 6/12
919/938 [============================>.] - ETA: 0s - loss: 0.0188 - accuracy: 0.9951
Learning rate for epoch 6 is 9.999999747378752e-05
938/938 [==============================] - 2s 3ms/step - loss: 0.0187 - accuracy: 0.9951
Epoch 7/12
921/938 [============================>.] - ETA: 0s - loss: 0.0172 - accuracy: 0.9960
Learning rate for epoch 7 is 9.999999747378752e-05
938/938 [==============================] - 2s 3ms/step - loss: 0.0171 - accuracy: 0.9960
Epoch 8/12
931/938 [============================>.] - ETA: 0s - loss: 0.0147 - accuracy: 0.9970
Learning rate for epoch 8 is 9.999999747378752e-06
938/938 [==============================] - 2s 3ms/step - loss: 0.0147 - accuracy: 0.9970
Epoch 9/12
938/938 [==============================] - ETA: 0s - loss: 0.0144 - accuracy: 0.9970
Learning rate for epoch 9 is 9.999999747378752e-06
938/938 [==============================] - 2s 3ms/step - loss: 0.0144 - accuracy: 0.9970
Epoch 10/12
924/938 [============================>.] - ETA: 0s - loss: 0.0143 - accuracy: 0.9971
Learning rate for epoch 10 is 9.999999747378752e-06
938/938 [==============================] - 2s 3ms/step - loss: 0.0142 - accuracy: 0.9971
Epoch 11/12
937/938 [============================>.] - ETA: 0s - loss: 0.0140 - accuracy: 0.9972
Learning rate for epoch 11 is 9.999999747378752e-06
938/938 [==============================] - 2s 3ms/step - loss: 0.0140 - accuracy: 0.9972
Epoch 12/12
923/938 [============================>.] - ETA: 0s - loss: 0.0139 - accuracy: 0.9973
Learning rate for epoch 12 is 9.999999747378752e-06
938/938 [==============================] - 2s 3ms/step - loss: 0.0139 - accuracy: 0.9973

<tensorflow.python.keras.callbacks.History at 0x7f50a0d94780>

كما ترى أدناه ، يتم حفظ نقاط التفتيش.

# check the checkpoint directory
!ls {checkpoint_dir}
checkpoint           ckpt_4.data-00000-of-00001
ckpt_1.data-00000-of-00001   ckpt_4.index
ckpt_1.index             ckpt_5.data-00000-of-00001
ckpt_10.data-00000-of-00001  ckpt_5.index
ckpt_10.index            ckpt_6.data-00000-of-00001
ckpt_11.data-00000-of-00001  ckpt_6.index
ckpt_11.index            ckpt_7.data-00000-of-00001
ckpt_12.data-00000-of-00001  ckpt_7.index
ckpt_12.index            ckpt_8.data-00000-of-00001
ckpt_2.data-00000-of-00001   ckpt_8.index
ckpt_2.index             ckpt_9.data-00000-of-00001
ckpt_3.data-00000-of-00001   ckpt_9.index
ckpt_3.index

لمعرفة كيفية أداء النموذج ، قم بتحميل أحدث نقطة فحص واستدعاء evaluate على بيانات الاختبار.

evaluate المكالمة كما كان من قبل باستخدام مجموعات البيانات المناسبة.

model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))

eval_loss, eval_acc = model.evaluate(eval_dataset)

print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
157/157 [==============================] - 1s 6ms/step - loss: 0.0393 - accuracy: 0.9864
Eval loss: 0.03928377106785774, Eval Accuracy: 0.9864000082015991

لمشاهدة المخرجات ، يمكنك تنزيل سجلات TensorBoard وعرضها على الجهاز.

$ tensorboard --logdir=path/to/log-directory
ls -sh ./logs
total 4.0K
4.0K train

تصدير إلى SavedModel

تصدير الرسم البياني والمتغيرات إلى تنسيق SavedModel الحيادي للنظام الأساسي. بعد حفظ النموذج ، يمكنك تحميله بالنطاق أو بدونه.

path = 'saved_model/'
model.save(path, save_format='tf')
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

INFO:tensorflow:Assets written to: saved_model/assets

INFO:tensorflow:Assets written to: saved_model/assets

قم بتحميل النموذج بدون strategy.scope .

unreplicated_model = tf.keras.models.load_model(path)

unreplicated_model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.Adam(),
    metrics=['accuracy'])

eval_loss, eval_acc = unreplicated_model.evaluate(eval_dataset)

print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
157/157 [==============================] - 1s 3ms/step - loss: 0.0393 - accuracy: 0.9864
Eval loss: 0.03928377106785774, Eval Accuracy: 0.9864000082015991

قم بتحميل النموذج بـ strategy.scope .

with strategy.scope():
  replicated_model = tf.keras.models.load_model(path)
  replicated_model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                           optimizer=tf.keras.optimizers.Adam(),
                           metrics=['accuracy'])

  eval_loss, eval_acc = replicated_model.evaluate(eval_dataset)
  print ('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
157/157 [==============================] - 1s 4ms/step - loss: 0.0393 - accuracy: 0.9864
Eval loss: 0.03928377106785774, Eval Accuracy: 0.9864000082015991

أمثلة وبرامج تعليمية

فيما يلي بعض الأمثلة لاستخدام إستراتيجية التوزيع مع keras fit / compile:

  1. تم تدريب مثال المحولات باستخدام tf.distribute.MirroredStrategy
  2. مثال NCF تم تدريبه باستخدام tf.distribute.MirroredStrategy .

المزيد من الأمثلة المدرجة في دليل استراتيجية التوزيع

الخطوات التالية