ترجمت واجهة Cloud Translation API‏ هذه الصفحة.
Switch to English

تدريب متعدد العاملين مع Keras

عرض على TensorFlow.org تشغيل في Google Colab عرض المصدر على جيثب تنزيل دفتر الملاحظات

نظرة عامة

يوضح هذا البرنامج التعليمي التدريب الموزع متعدد العمال باستخدام نموذج tf.distribute.Strategy باستخدام tf.distribute.Strategy API ، على وجه التحديد tf.distribute.experimental.MultiWorkerMirroredStrategy . بمساعدة هذه الإستراتيجية ، يمكن أن يعمل نموذج Keras الذي تم تصميمه للعمل على عامل واحد بسلاسة على عدة عمال بأقل تغيير للرمز.

يتوفر التدريب الموزع في دليل TensorFlow للحصول على نظرة عامة حول استراتيجيات التوزيع التي يدعمها TensorFlow للمهتمين بفهم أعمق لـ tf.distribute.Strategy APIs.

اقامة

أولاً ، قم بإعداد TensorFlow والواردات اللازمة.

 import os
import tensorflow as tf
import numpy as np
 

تحضير مجموعة البيانات

الآن ، دعنا نعد مجموعة بيانات MNIST. تشتمل مجموعة بيانات MNIST على 60.000 مثال تدريبي و 10000 اختبار تجريبي للأرقام المكتوبة بخط اليد من 0 إلى 9 ، بتنسيق صور أحادية اللون بحجم 28 × 28 بكسل. في هذا المثال ، سنأخذ الجزء التدريبي من مجموعات البيانات للتوضيح.

 def mnist_dataset(batch_size):
  (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
  # The `x` arrays are in uint8 and have values in the range [0, 255].
  # We need to convert them to float32 with values in the range [0, 1]
  x_train = x_train / np.float32(255)
  y_train = y_train.astype(np.int64)
  train_dataset = tf.data.Dataset.from_tensor_slices(
      (x_train, y_train)).shuffle(60000).repeat().batch(batch_size)
  return train_dataset
 

قم ببناء نموذج Keras

هنا نستخدم tf.keras.Sequential API لبناء وتجميع شبكات عصبية تلافيفية بسيطة نموذج Keras للتدريب مع مجموعة بيانات MNIST الخاصة بنا.

 def build_and_compile_cnn_model():
  model = tf.keras.Sequential([
      tf.keras.Input(shape=(28, 28)),
      tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
      tf.keras.layers.Conv2D(32, 3, activation='relu'),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(128, activation='relu'),
      tf.keras.layers.Dense(10)
  ])
  model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
      metrics=['accuracy'])
  return model
 

دعونا أولاً نحاول تدريب النموذج على عدد صغير من العهود ونلاحظ النتائج في عامل واحد للتأكد من أن كل شيء يعمل بشكل صحيح. يجب أن تتوقع رؤية انخفاض الخسارة والدقة تقترب من 1.0 مع تقدم الحقبة.

 per_worker_batch_size = 64
single_worker_dataset = mnist_dataset(per_worker_batch_size)
single_worker_model = build_and_compile_cnn_model()
single_worker_model.fit(single_worker_dataset, epochs=3, steps_per_epoch=70)
 
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step
Epoch 1/3
70/70 [==============================] - 0s 2ms/step - loss: 2.2701 - accuracy: 0.2451
Epoch 2/3
70/70 [==============================] - 0s 2ms/step - loss: 2.1827 - accuracy: 0.4777
Epoch 3/3
70/70 [==============================] - 0s 2ms/step - loss: 2.0865 - accuracy: 0.5955

<tensorflow.python.keras.callbacks.History at 0x7fc59381ac50>

تكوين متعدد العمال

الآن دعونا ندخل عالم التدريب متعدد العمال. في TensorFlow ، TF_CONFIG متغير بيئة TF_CONFIG للتدريب على أجهزة متعددة ، كل منها من المحتمل أن يكون له دور مختلف. TF_CONFIG عبارة عن سلسلة JSON تُستخدم لتحديد تكوين الكتلة لكل عامل يمثل جزءًا من المجموعة.

هناك مكونان من TF_CONFIG : cluster task . يوفر cluster معلومات حول مجموعة التدريب ، وهو عبارة عن إملاء يتكون من أنواع مختلفة من الوظائف مثل worker . في التدريب متعدد العمال باستخدام MultiWorkerMirroredStrategy ، عادة ما يكون هناك worker واحد يتحمل مسؤولية أكثر قليلاً مثل حفظ نقطة التفتيش وكتابة ملف ملخص لـ TensorBoard بالإضافة إلى ما يفعله worker العادي. ويشار إلى هذا العامل باعتباره chief عامل، ومن المعتاد أن worker مع index 0 تم تعيينه رئيسا لمجلس الرئيس worker (في الواقع هذه هي الطريقة tf.distribute.Strategy ينفذ). task من ناحية أخرى توفر معلومات عن المهمة الحالية. cluster المكون الأول هي نفسها لجميع العمال ، task المكون الثاني مختلفة على كل عامل وتحدد type هذا العامل index .

في هذا المثال ، قمنا بتعيين type المهمة إلى "worker" index المهام إلى 0 . هذا يعني أن الجهاز الذي لديه مثل هذا الإعداد هو العامل الأول ، والذي سيتم تعيينه كعامل رئيسي ويقوم بعمل أكثر من العمال الآخرين. لاحظ أن الأجهزة الأخرى ستحتاج إلى مجموعة متغيرات بيئة TF_CONFIG أيضًا ، ويجب أن يكون لها نفس إملاء cluster ، ولكن type مهمة أو index مهام مختلف اعتمادًا على أدوار هذه الأجهزة.

لأغراض التوضيح ، يوضح هذا البرنامج التعليمي كيف يمكن للمرء تعيين TF_CONFIG مع عاملين على localhost . من الناحية العملية ، يمكن للمستخدمين إنشاء عدة عمال على عناوين / منافذ IP الخارجية ، وتعيين TF_CONFIG على كل عامل بشكل مناسب.

 os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'worker': ["localhost:12345", "localhost:23456"]
    },
    'task': {'type': 'worker', 'index': 0}
})
 

لاحظ أنه على الرغم من أن معدل التعلم ثابت في هذا المثال ، فقد يكون من الضروري بشكل عام تعديل معدل التعلم بناءً على حجم الدفعة العالمية.

اختر الاستراتيجية الصحيحة

في TensorFlow ، يتكون التدريب الموزع من التدريب المتزامن ، حيث تتم مزامنة خطوات التدريب عبر العمال والنسخ المتماثلة ، والتدريب غير المتزامن ، حيث لا تتم مزامنة خطوات التدريب بشكل صارم.

MultiWorkerMirroredStrategy ، وهي الإستراتيجية الموصى بها للتدريب المتزامن متعدد العاملين ، في هذا الدليل. لتدريب النموذج ، استخدم مثيل tf.distribute.experimental.MultiWorkerMirroredStrategy . يقوم MultiWorkerMirroredStrategy بإنشاء نسخ من جميع المتغيرات في طبقات النموذج على كل جهاز عبر جميع العاملين. يستخدم CollectiveOps ، وهو TensorFlow للتواصل الجماعي ، لتجميع التدرجات والحفاظ على المتغيرات متزامنة. يحتوي دليل tf.distribute.Strategy على مزيد من التفاصيل حول هذه الإستراتيجية.

 strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
 
WARNING:tensorflow:Collective ops is not configured at program startup. Some performance features may not be enabled.
INFO:tensorflow:Using MirroredStrategy with devices ('/device:GPU:0',)
INFO:tensorflow:Single-worker MultiWorkerMirroredStrategy with local_devices = ('/device:GPU:0',), communication = CollectiveCommunication.AUTO

يوفر MultiWorkerMirroredStrategy عمليات تنفيذ متعددة عبر معلمة CollectiveCommunication . تنفذ RING التجمعات القائمة على الحلقة باستخدام gRPC كطبقة اتصال عبر المضيف. تستخدم NCCL Nvidia لتنفيذ المجموعات. AUTO يؤجل الاختيار إلى وقت التشغيل. يعتمد الاختيار الأفضل للتنفيذ الجماعي على عدد ونوع وحدات معالجة الرسومات ووصلات الشبكة في المجموعة.

درب النموذج باستخدام MultiWorkerMirroredStr Strategy

مع دمج tf.distribute.Strategy API في tf.keras ، فإن التغيير الوحيد الذي ستقوم به لتوزيع التدريب على عدة عمال هو إرفاق مبنى النموذج model.compile() داخل model.compile() strategy.scope() . يحدد نطاق استراتيجية التوزيع كيف وأين يتم إنشاء المتغيرات ، وفي حالة MultiWorkerMirroredStrategy ، فإن المتغيرات التي تم إنشاؤها هي MirroredVariable s ، ويتم نسخها على كل عامل.

 num_workers = 4

# Here the batch size scales up by number of workers since 
# `tf.data.Dataset.batch` expects the global batch size. Previously we used 64, 
# and now this becomes 128.
global_batch_size = per_worker_batch_size * num_workers
multi_worker_dataset = mnist_dataset(global_batch_size)

with strategy.scope():
  # Model building/compiling need to be within `strategy.scope()`.
  multi_worker_model = build_and_compile_cnn_model()

# Keras' `model.fit()` trains the model with specified number of epochs and
# number of steps per epoch. Note that the numbers here are for demonstration
# purposes only and may not sufficiently produce a model with good quality.
multi_worker_model.fit(multi_worker_dataset, epochs=3, steps_per_epoch=70)
 
Epoch 1/3
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:601: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Iterator.get_next_as_optional()` instead.
70/70 [==============================] - 0s 3ms/step - loss: 2.2682 - accuracy: 0.2265
Epoch 2/3
70/70 [==============================] - 0s 3ms/step - loss: 2.1714 - accuracy: 0.4954
Epoch 3/3
70/70 [==============================] - 0s 3ms/step - loss: 2.0638 - accuracy: 0.6232

<tensorflow.python.keras.callbacks.History at 0x7fc5f4f062e8>

تقسيم مجموعة البيانات وحجم الدفعة

في التدريب متعدد العمال باستخدام MultiWorkerMirroredStrategy ، يلزم تقسيم مجموعة البيانات لضمان التقارب والأداء. ومع ذلك ، لاحظ أنه في مقتطف الشفرة أعلاه ، يتم تمرير مجموعات البيانات مباشرة إلى model.fit() دون الحاجة إلى جزء ؛ وذلك لأن tf.distribute.Strategy API تتولى عملية تقسيم مجموعة البيانات تلقائيًا. يقوم بتجزئة مجموعة البيانات على مستوى الملف مما قد يؤدي إلى إنشاء أجزاء منحرفة. في الحالات القصوى التي يوجد فيها ملف واحد فقط ، فإن الجزء الأول فقط (أي العامل) سيحصل على بيانات التدريب أو التقييم ونتيجة لذلك سيحصل جميع العمال على أخطاء.

إذا كنت تفضل استخدام ميزة "الدمج اليدوي" للتدريب ، فيمكن إيقاف تشغيل "المشاركة tf.data.experimental.DistributeOptions عبر tf.data.experimental.DistributeOptions api. بشكل ملموس ،

 options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
dataset_no_auto_shard = multi_worker_dataset.with_options(options)
 

شيء آخر يجب ملاحظته هو حجم الدفعة datasets . في مقتطف الشفرة أعلاه ، نستخدم global_batch_size = per_worker_batch_size * num_workers ، وهو عدد num_workers أكبر من الحالة التي كان عليها للعامل الواحد ، لأن الحجم الفعال لكل عامل هو حجم الدفعة العامة (المعلمة التي تم تمريرها في tf.data.Dataset.batch() مقسومًا على عدد العمال ، وبهذا التغيير نحافظ على حجم الدفعة لكل عامل كما كان من قبل.

تقييم

إذا قمت بتمرير validation_data في model.fit ، model.fit بين التدريب والتقييم لكل حقبة. يتم توزيع بيانات validation_data على نفس مجموعة العمال ويتم تجميع نتائج التقييم وإتاحتها لجميع العمال. على غرار التدريب ، يتم تقسيم مجموعة بيانات التحقق تلقائيًا إلى مستوى الملف. تحتاج إلى تعيين حجم دُفعة عامة في مجموعة بيانات validation_steps وتعيين validation_steps . يوصى أيضًا بإجراء مجموعة بيانات متكررة للتقييم.

بدلاً من ذلك ، يمكنك أيضًا إنشاء مهمة أخرى تقرأ نقاط التفتيش بشكل دوري وتقوم بالتقييم. هذا ما يفعله المُقَيِّم. ولكن هذه ليست طريقة موصى بها لإجراء التقييم وبالتالي يتم حذف تفاصيلها.

تنبؤ

لا يعمل model.predict حاليًا مع MultiWorkerMirroredStrategy.

أداء

لديك الآن نموذج Keras الذي تم إعداده بالكامل للعمل في عدة عمال باستخدام MultiWorkerMirroredStrategy . يمكنك تجربة الأساليب التالية MultiWorkerMirroredStrategy أداء التدريب متعدد العاملين باستخدام MultiWorkerMirroredStrategy .

  • يوفر MultiWorkerMirroredStrategy العديد من تطبيقات الاتصالات الجماعية . تنفذ RING التجمعات القائمة على الحلقة باستخدام gRPC كطبقة اتصال عبر المضيف. تستخدم NCCL Nvidia لتنفيذ المجموعات. AUTO يؤجل الاختيار إلى وقت التشغيل. يعتمد الاختيار الأفضل للتنفيذ الجماعي على عدد ونوع وحدات معالجة الرسومات ووصلات الشبكة في المجموعة. لتجاوز الاختيار التلقائي ، حدد قيمة صالحة لمعلمة communication الخاصة MultiWorkerMirroredStrategy ، على سبيل المثال communication=tf.distribute.experimental.CollectiveCommunication.NCCL .
  • قم tf.float المتغيرات على tf.float إن أمكن. يشتمل نموذج ResNet الرسمي على مثال لكيفية القيام بذلك.

التسامح مع الخطأ

في التدريب المتزامن ، ستفشل المجموعة في حالة فشل أحد العمال وعدم وجود آلية لاسترداد الفشل. استخدام Keras مع tf.distribute.Strategy تأتي tf.distribute.Strategy مع ميزة تحمل الخطأ في الحالات التي يموت فيها العمال أو يكونون فيها غير مستقرين. نقوم بذلك عن طريق الحفاظ على حالة التدريب في نظام الملفات الموزعة الذي تختاره ، بحيث يتم استرداد حالة التدريب عند إعادة تشغيل الحالة التي فشلت أو استبقت سابقًا.

نظرًا لأن جميع العمال يتم الاحتفاظ بهم متزامنين من حيث فترات وخطوات التدريب ، فإن العمال الآخرين سيحتاجون إلى الانتظار حتى يعود العامل الفاشل أو المحكوم عليه للمتابعة.

رد اتصال ModelCheckpoint

لم يعد رد ModelCheckpoint يوفر وظيفة التسامح مع الخطأ ، الرجاء استخدام رد الاتصال BackupAndRestore بدلاً من ذلك.

لا يزال من الممكن استخدام رد ModelCheckpoint لحفظ نقاط التفتيش. ولكن مع ذلك ، إذا تم مقاطعة التدريب أو أنهى بنجاح ، من أجل مواصلة التدريب من نقطة التفتيش ، يكون المستخدم مسؤولاً عن تحميل النموذج يدويًا. اختياريًا ، يمكن للمستخدم اختيار حفظ واستعادة النموذج / الأوزان خارج رد ModelCheckpoint .

نموذج حفظ وتحميل

لحفظ النموذج الخاص بك باستخدام model.save أو tf.saved_model.save ، يجب أن تختلف وجهة الحفظ لكل عامل. على العمال غير الرئيسيين ، ستحتاج إلى حفظ النموذج في دليل مؤقت ، وفي الرئيس ، ستحتاج إلى الحفظ في دليل النموذج المقدم. يجب أن تكون الدلائل المؤقتة للعامل فريدة من نوعها لمنع الأخطاء الناتجة عن محاولة عدة عمال الكتابة إلى نفس الموقع. النموذج المحفوظ في جميع الدلائل متطابق وعادةً ما يُشار فقط إلى النموذج الذي حفظه الرئيس للاستعادة أو العرض. نوصي بأن يكون لديك بعض منطق التنظيف الذي يحذف الدلائل المؤقتة التي أنشأها العمال بمجرد الانتهاء من التدريب.

السبب الذي يجعلك تحتاج إلى الادخار على الرئيس والعمال في نفس الوقت ، هو أنك قد تجمع المتغيرات أثناء نقطة التفتيش والتي تتطلب من الرئيس والعمال المشاركة في بروتوكول allreduce للاتصال. من ناحية أخرى ، سيؤدي ترك الرئيس والعاملين في الحفظ إلى نفس دليل النموذج إلى حدوث أخطاء بسبب الخلاف.

باستخدام MultiWorkerMirroredStrategy ، يتم تشغيل البرنامج على كل عامل ، ومن أجل معرفة ما إذا كان العامل الحالي هو الرئيس ، فإننا نستفيد من كائن محلل الكتلة الذي يحتوي على سمات task_type و task_id . يخبرك task_type بالمهمة الحالية (مثل "عامل") ، ويخبرك task_id بمعرف العامل. يتم تعيين العامل برقم التعريف 0 كعامل رئيسي.

في مقتطف الرمز أدناه ، يوفر write_filepath مسار الملف للكتابة ، والذي يعتمد على معرف العامل. في حالة الرئيس (عامل بمعرف 0) ، يكتب إلى مسار الملف الأصلي ؛ بالنسبة للآخرين ، يقوم بإنشاء دليل مؤقت (بمعرف في مسار الدليل) للكتابة فيه:

 model_path = '/tmp/keras-model'

def _is_chief(task_type, task_id):
  # If `task_type` is None, this may be operating as single worker, which works 
  # effectively as chief.
  return task_type is None or task_type == 'chief' or (
            task_type == 'worker' and task_id == 0)

def _get_temp_dir(dirpath, task_id):
  base_dirpath = 'workertemp_' + str(task_id)
  temp_dir = os.path.join(dirpath, base_dirpath)
  tf.io.gfile.makedirs(temp_dir)
  return temp_dir

def write_filepath(filepath, task_type, task_id):
  dirpath = os.path.dirname(filepath)
  base = os.path.basename(filepath)
  if not _is_chief(task_type, task_id):
    dirpath = _get_temp_dir(dirpath, task_id)
  return os.path.join(dirpath, base)

task_type, task_id = (strategy.cluster_resolver.task_type,
                      strategy.cluster_resolver.task_id)
write_model_path = write_filepath(model_path, task_type, task_id)
 

مع ذلك ، أنت الآن جاهز للحفظ:

 multi_worker_model.save(write_model_path)
 
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
INFO:tensorflow:Assets written to: /tmp/keras-model/assets

كما وصفنا أعلاه ، في وقت لاحق على النموذج يجب تحميله فقط من المسار الرئيسي المحفوظ له ، لذلك دعونا نزيل تلك المؤقتة التي حفظها العمال غير الرئيسيين:

 if not _is_chief(task_type, task_id):
  tf.io.gfile.rmtree(os.path.dirname(write_model_path))
 

الآن ، عندما حان وقت التحميل ، دعنا نستخدم واجهة برمجة تطبيقات tf.keras.models.load_model ، ونواصل العمل الإضافي. هنا ، نفترض استخدام عامل واحد فقط لتحميل ومواصلة التدريب ، وفي هذه الحالة لا تتصل tf.keras.models.load_model ضمن strategy.scope() أخرى. نطاق strategy.scope() .

 loaded_model = tf.keras.models.load_model(model_path)

# Now that we have the model restored, and can continue with the training.
loaded_model.fit(single_worker_dataset, epochs=2, steps_per_epoch=20)
 
Epoch 1/2
20/20 [==============================] - 0s 2ms/step - loss: 1.9825 - accuracy: 0.1102
Epoch 2/2
20/20 [==============================] - 0s 2ms/step - loss: 1.9367 - accuracy: 0.1117

<tensorflow.python.keras.callbacks.History at 0x7fc5f4b0d8d0>

حفظ واستعادة نقاط التفتيش

من ناحية أخرى ، تتيح لك نقاط التحقق حفظ أوزان النموذج واستعادتها دون الحاجة إلى حفظ النموذج بأكمله. هنا ، ستنشئ tf.train.Checkpoint الذي يتتبع النموذج ، والذي تتم إدارته بواسطة tf.train.CheckpointManager بحيث يتم الاحتفاظ فقط بنقطة التحقق الأخيرة.

 checkpoint_dir = '/tmp/ckpt'

checkpoint = tf.train.Checkpoint(model=multi_worker_model)
write_checkpoint_dir = write_filepath(checkpoint_dir, task_type, task_id)
checkpoint_manager = tf.train.CheckpointManager(
  checkpoint, directory=write_checkpoint_dir, max_to_keep=1)
 

بمجرد إعداد CheckpointManager ، فأنت الآن جاهز للحفظ وإزالة نقاط التفتيش غير الرئيسية للعمال المحفوظة.

 checkpoint_manager.save()
if not _is_chief(task_type, task_id):
  tf.io.gfile.rmtree(write_checkpoint_dir)
 

الآن ، عندما تحتاج إلى الاستعادة ، يمكنك العثور على أحدث نقطة تحقق محفوظة باستخدام وظيفة tf.train.latest_checkpoint . بعد استعادة نقطة التفتيش ، يمكنك الاستمرار في التدريب.

 latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
checkpoint.restore(latest_checkpoint)
multi_worker_model.fit(multi_worker_dataset, epochs=2, steps_per_epoch=20)
 
Epoch 1/2
20/20 [==============================] - 0s 3ms/step - loss: 1.9841 - accuracy: 0.6561
Epoch 2/2
20/20 [==============================] - 0s 3ms/step - loss: 1.9445 - accuracy: 0.6805

<tensorflow.python.keras.callbacks.History at 0x7fc5f49d9d30>

BackupAndRestore رد الاتصال

يوفر رد اتصال BackupAndRestore وظيفة التسامح مع الخطأ ، من خلال النسخ الاحتياطي للنموذج ورقم العصر الحالي في ملف نقطة تفتيش مؤقت ضمن وسيطة backup_dir إلى BackupAndRestore . يتم ذلك في نهاية كل عصر.

بمجرد مقاطعة الوظائف وإعادة تشغيلها ، يقوم رد الاتصال باستعادة نقطة التفتيش الأخيرة ، ويستمر التدريب من بداية العصر المتقطع. سيتم التخلص من أي تدريب جزئي تم إجراؤه بالفعل في الحقبة غير المكتملة قبل الانقطاع ، بحيث لا يؤثر على حالة النموذج النهائي.

لاستخدامه ، قدم نسخة من tf.keras.callbacks.experimental.BackupAndRestore في استدعاء tf.keras.Model.fit() .

باستخدام MultiWorkerMirroredStrategy ، إذا تمت مقاطعة العامل ، تتوقف الكتلة بالكامل مؤقتًا حتى تتم إعادة تشغيل العامل الذي تمت مقاطعته. كما سيتم إعادة تشغيل العمال الآخرين ، وعاد العامل المتقطع إلى المجموعة. بعد ذلك ، يقرأ كل عامل ملف نقاط التحقق الذي تم حفظه مسبقًا ويلتقط حالته السابقة ، مما يسمح للمجموعة بالعودة إلى المزامنة. ثم يستمر التدريب.

يستخدم رد BackupAndRestore CheckpointManager لحفظ حالة التدريب واستعادتها ، مما يؤدي إلى إنشاء ملف يسمى نقطة التحقق التي تتعقب نقاط التفتيش الحالية مع أحدثها. لهذا السبب ، لا ينبغي إعادة استخدام backup_dir لتخزين نقاط التفتيش الأخرى لتجنب تضارب الأسماء.

في الوقت الحالي ، يدعم رد الاتصال BackupAndRestore العامل المنفرد بدون إستراتيجية ، و MirroredStr Strategy ، ومتعدد العمال مع MultiWorkerMirroredStrategy. فيما يلي مثالان لكل من التدريب متعدد العمال وتدريب العامل الواحد.

 # Multi-worker training with MultiWorkerMirroredStrategy.

callbacks = [tf.keras.callbacks.experimental.BackupAndRestore(backup_dir='/tmp/backup')]
with strategy.scope():
  multi_worker_model = build_and_compile_cnn_model()
multi_worker_model.fit(multi_worker_dataset,
                       epochs=3,
                       steps_per_epoch=70,
                       callbacks=callbacks)
 
Epoch 1/3
70/70 [==============================] - 0s 3ms/step - loss: 2.2837 - accuracy: 0.1836
Epoch 2/3
70/70 [==============================] - 0s 3ms/step - loss: 2.2131 - accuracy: 0.4091
Epoch 3/3
70/70 [==============================] - 0s 3ms/step - loss: 2.1310 - accuracy: 0.5485

<tensorflow.python.keras.callbacks.History at 0x7fc5f49a3080>

إذا قمت بفحص دليل backup_dir الذي حددته في BackupAndRestore ، فقد تلاحظ بعض ملفات نقاط التحقق التي تم إنشاؤها مؤقتًا. هذه الملفات مطلوبة لاستعادة الحالات المفقودة سابقًا ، وستتم إزالتها من قبل المكتبة في نهاية tf.keras.Model.fit() عند الخروج بنجاح من التدريب الخاص بك.

أنظر أيضا

  1. يقدم التدريب الموزع في دليل TensorFlow لمحة عامة عن استراتيجيات التوزيع المتاحة.
  2. نماذج رسمية ، يمكن تكوين العديد منها لتشغيل استراتيجيات توزيع متعددة.
  3. يوفر قسم الأداء في الدليل معلومات حول الاستراتيجيات والأدوات الأخرى التي يمكنك استخدامها لتحسين أداء نماذج TensorFlow الخاصة بك.