تدريب مخصص مع tf.distribute.Strategy

عرض على TensorFlow.org تشغيل في Google Colab عرض المصدر على جيثب تحميل دفتر

يوضح هذا البرنامج التعليمي كيفية استخدام tf.distribute.Strategy مع حلقات تدريب مخصصة. سنقوم بتدريب نموذج بسيط لشبكة CNN على مجموعة بيانات الموضة MNIST. تحتوي مجموعة بيانات Fashion MNIST على 60000 صورة قطار بحجم 28 × 28 و 10000 صورة اختبار بحجم 28 × 28.

نحن نستخدم حلقات تدريب مخصصة لتدريب نموذجنا لأنها تمنحنا مرونة وتحكمًا أكبر في التدريب. علاوة على ذلك ، من الأسهل تصحيح أخطاء النموذج وحلقة التدريب.

# Import TensorFlow
import tensorflow as tf

# Helper libraries
import numpy as np
import os

print(tf.__version__)
2.8.0-rc1

قم بتنزيل مجموعة بيانات Fashion MNIST

fashion_mnist = tf.keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# Adding a dimension to the array -> new shape == (28, 28, 1)
# We are doing this because the first layer in our model is a convolutional
# layer and it requires a 4D input (batch_size, height, width, channels).
# batch_size dimension will be added later on.
train_images = train_images[..., None]
test_images = test_images[..., None]

# Getting the images in [0, 1] range.
train_images = train_images / np.float32(255)
test_images = test_images / np.float32(255)

إنشاء استراتيجية لتوزيع المتغيرات والرسم البياني

كيف تعمل استراتيجية tf.distribute.MirroredStrategy ؟

  • يتم تكرار جميع المتغيرات والرسم البياني للنموذج على النسخ المتماثلة.
  • يتم توزيع الإدخال بالتساوي عبر النسخ المتماثلة.
  • تحسب كل نسخة متماثلة الخسارة والتدرجات للإدخال الذي تلقته.
  • تتم مزامنة التدرجات عبر جميع النسخ المتماثلة عن طريق جمعها.
  • بعد المزامنة ، يتم إجراء نفس التحديث على نسخ المتغيرات في كل نسخة متماثلة.
# If the list of devices is not specified in the
# `tf.distribute.MirroredStrategy` constructor, it will be auto-detected.
strategy = tf.distribute.MirroredStrategy()
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
print ('Number of devices: {}'.format(strategy.num_replicas_in_sync))
Number of devices: 1

إعداد خط أنابيب الإدخال

قم بتصدير الرسم البياني والمتغيرات إلى تنسيق SavedModel المحايد للنظام الأساسي. بعد حفظ النموذج الخاص بك ، يمكنك تحميله بالنطاق أو بدونه.

BUFFER_SIZE = len(train_images)

BATCH_SIZE_PER_REPLICA = 64
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

EPOCHS = 10

قم بإنشاء مجموعات البيانات وتوزيعها:

train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).shuffle(BUFFER_SIZE).batch(GLOBAL_BATCH_SIZE) 
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE) 

train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
test_dist_dataset = strategy.experimental_distribute_dataset(test_dataset)
2022-01-26 05:45:53.991501: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:776] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2"
op: "TensorSliceDataset"
input: "Placeholder/_0"
input: "Placeholder/_1"
attr {
  key: "Toutput_types"
  value {
    list {
      type: DT_FLOAT
      type: DT_UINT8
    }
  }
}
attr {
  key: "_cardinality"
  value {
    i: 60000
  }
}
attr {
  key: "is_files"
  value {
    b: false
  }
}
attr {
  key: "metadata"
  value {
    s: "\n\024TensorSliceDataset:0"
  }
}
attr {
  key: "output_shapes"
  value {
    list {
      shape {
        dim {
          size: 28
        }
        dim {
          size: 28
        }
        dim {
          size: 1
        }
      }
      shape {
      }
    }
  }
}
experimental_type {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_DATASET
    args {
      type_id: TFT_PRODUCT
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_FLOAT
        }
      }
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_UINT8
        }
      }
    }
  }
  args {
    type_id: TFT_DATASET
    args {
      type_id: TFT_PRODUCT
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_FLOAT
        }
      }
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_UINT8
        }
      }
    }
  }
}

2022-01-26 05:45:54.034762: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:776] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2"
op: "TensorSliceDataset"
input: "Placeholder/_0"
input: "Placeholder/_1"
attr {
  key: "Toutput_types"
  value {
    list {
      type: DT_FLOAT
      type: DT_UINT8
    }
  }
}
attr {
  key: "_cardinality"
  value {
    i: 10000
  }
}
attr {
  key: "is_files"
  value {
    b: false
  }
}
attr {
  key: "metadata"
  value {
    s: "\n\024TensorSliceDataset:3"
  }
}
attr {
  key: "output_shapes"
  value {
    list {
      shape {
        dim {
          size: 28
        }
        dim {
          size: 28
        }
        dim {
          size: 1
        }
      }
      shape {
      }
    }
  }
}
experimental_type {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_DATASET
    args {
      type_id: TFT_PRODUCT
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_FLOAT
        }
      }
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_UINT8
        }
      }
    }
  }
  args {
    type_id: TFT_DATASET
    args {
      type_id: TFT_PRODUCT
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_FLOAT
        }
      }
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_UINT8
        }
      }
    }
  }
}

قم بإنشاء النموذج

قم بإنشاء نموذج باستخدام tf.keras.Sequential . يمكنك أيضًا استخدام Model Subclassing API للقيام بذلك.

def create_model():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu'),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Conv2D(64, 3, activation='relu'),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10)
    ])

  return model
# Create a checkpoint directory to store the checkpoints.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

تحديد وظيفة الخسارة

عادةً ، في جهاز واحد مع 1 GPU / CPU ، يتم تقسيم الخسارة على عدد الأمثلة في دفعة الإدخال.

لذا ، كيف يجب حساب الخسارة عند استخدام استراتيجية tf.distribute.Strategy ؟

  • على سبيل المثال ، لنفترض أن لديك 4 وحدات معالجة رسومات وحجم دُفعة 64. تم توزيع دفعة واحدة من الإدخال عبر النسخ المتماثلة (4 وحدات معالجة رسومات) ، تحصل كل نسخة متماثلة على إدخال بحجم 16.

  • يقوم النموذج الموجود في كل نسخة متماثلة بتمرير أمامي مع المدخلات الخاصة به ويحسب الخسارة. الآن ، بدلاً من قسمة الخسارة على عدد الأمثلة في المدخلات الخاصة بها (BATCH_SIZE_PER_REPLICA = 16) ، يجب تقسيم الخسارة على GLOBAL_BATCH_SIZE (64).

لماذا فعل هذا؟

  • يجب القيام بذلك لأنه بعد حساب التدرجات اللونية في كل نسخة متماثلة ، تتم مزامنتها عبر النسخ المتماثلة عن طريق جمعها .

كيف يتم القيام بذلك في TensorFlow؟

  • إذا كنت تكتب حلقة تدريب مخصصة ، كما في هذا البرنامج التعليمي ، فيجب عليك جمع الخسائر لكل مثال وتقسيم المجموع على GLOBAL_BATCH_SIZE: scale_loss = tf.reduce_sum(loss) * (1. / GLOBAL_BATCH_SIZE) أو يمكنك استخدام tf.nn.compute_average_loss الذي يأخذ الخسارة لكل مثال وأوزان العينة الاختيارية و GLOBAL_BATCH_SIZE كوسيطات ويعيد الخسارة المقاسة.

  • إذا كنت تستخدم خسائر تسوية في نموذجك ، فأنت بحاجة إلى قياس قيمة الخسارة حسب عدد النسخ المتماثلة. يمكنك القيام بذلك باستخدام دالة tf.nn.scale_regularization_loss .

  • لا يوصى باستخدام tf.reduce_mean . يؤدي القيام بذلك إلى تقسيم الخسارة على الحجم الفعلي لكل نسخة متماثلة والذي قد يختلف من خطوة إلى أخرى.

  • يتم إجراء هذا التخفيض والتحجيم تلقائيًا في model.compile و model.fit

  • في حالة استخدام فئات tf.keras.losses (كما في المثال أدناه) ، يجب تحديد تقليل الخسارة بشكل صريح ليكون واحدًا من NONE أو SUM . AUTO و SUM_OVER_BATCH_SIZE غير مسموح بهما عند الاستخدام مع tf.distribute.Strategy . لا يُسمح بـ AUTO لأن المستخدم يجب أن يفكر صراحة في التخفيض الذي يريده للتأكد من صحته في الحالة الموزعة. SUM_OVER_BATCH_SIZE غير مسموح به لأنه لن يتم القسمة حاليًا إلا على حجم كل دفعة نسخة متماثلة ، ويترك القسمة على عدد النسخ المتماثلة للمستخدم ، الأمر الذي قد يكون من السهل تفويته. لذا بدلاً من ذلك ، نطلب من المستخدم إجراء التخفيض بنفسه صراحةً.

  • إذا كانت labels متعددة الأبعاد ، فحينئذٍ متوسط per_example_loss لكل عينة عبر عدد العناصر في كل عينة. على سبيل المثال ، إذا كان شكل predictions هو (batch_size, H, W, n_classes) وكانت labels (batch_size, H, W) ، فستحتاج إلى تحديث per_example_loss مثل: per_example_loss /= tf.cast(tf.reduce_prod(tf.shape(labels)[1:]), tf.float32)

with strategy.scope():
  # Set reduction to `none` so we can do the reduction afterwards and divide by
  # global batch size.
  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
      from_logits=True,
      reduction=tf.keras.losses.Reduction.NONE)
  def compute_loss(labels, predictions):
    per_example_loss = loss_object(labels, predictions)
    return tf.nn.compute_average_loss(per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)

حدد المقاييس لتتبع الخسارة والدقة

تتعقب هذه المقاييس فقدان الاختبار والتدريب ودقة الاختبار. يمكنك استخدام .result() للحصول على الإحصائيات المتراكمة في أي وقت.

with strategy.scope():
  test_loss = tf.keras.metrics.Mean(name='test_loss')

  train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='train_accuracy')
  test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='test_accuracy')
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

حلقة التدريب

# model, optimizer, and checkpoint must be created under `strategy.scope`.
with strategy.scope():
  model = create_model()

  optimizer = tf.keras.optimizers.Adam()

  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
def train_step(inputs):
  images, labels = inputs

  with tf.GradientTape() as tape:
    predictions = model(images, training=True)
    loss = compute_loss(labels, predictions)

  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

  train_accuracy.update_state(labels, predictions)
  return loss 

def test_step(inputs):
  images, labels = inputs

  predictions = model(images, training=False)
  t_loss = loss_object(labels, predictions)

  test_loss.update_state(t_loss)
  test_accuracy.update_state(labels, predictions)
# `run` replicates the provided computation and runs it
# with the distributed input.
@tf.function
def distributed_train_step(dataset_inputs):
  per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
  return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
                         axis=None)

@tf.function
def distributed_test_step(dataset_inputs):
  return strategy.run(test_step, args=(dataset_inputs,))

for epoch in range(EPOCHS):
  # TRAIN LOOP
  total_loss = 0.0
  num_batches = 0
  for x in train_dist_dataset:
    total_loss += distributed_train_step(x)
    num_batches += 1
  train_loss = total_loss / num_batches

  # TEST LOOP
  for x in test_dist_dataset:
    distributed_test_step(x)

  if epoch % 2 == 0:
    checkpoint.save(checkpoint_prefix)

  template = ("Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, "
              "Test Accuracy: {}")
  print (template.format(epoch+1, train_loss,
                         train_accuracy.result()*100, test_loss.result(),
                         test_accuracy.result()*100))

  test_loss.reset_states()
  train_accuracy.reset_states()
  test_accuracy.reset_states()
Epoch 1, Loss: 0.5106383562088013, Accuracy: 81.77999877929688, Test Loss: 0.39399346709251404, Test Accuracy: 85.79000091552734
Epoch 2, Loss: 0.3362727463245392, Accuracy: 87.91333770751953, Test Loss: 0.35871225595474243, Test Accuracy: 86.7699966430664
Epoch 3, Loss: 0.2928692400455475, Accuracy: 89.2683334350586, Test Loss: 0.2999486029148102, Test Accuracy: 89.04000091552734
Epoch 4, Loss: 0.2605818510055542, Accuracy: 90.41999816894531, Test Loss: 0.28474125266075134, Test Accuracy: 89.47000122070312
Epoch 5, Loss: 0.23641237616539001, Accuracy: 91.32166290283203, Test Loss: 0.26421546936035156, Test Accuracy: 90.41000366210938
Epoch 6, Loss: 0.2192477434873581, Accuracy: 91.90499877929688, Test Loss: 0.2650589942932129, Test Accuracy: 90.4800033569336
Epoch 7, Loss: 0.20016911625862122, Accuracy: 92.66999816894531, Test Loss: 0.25025954842567444, Test Accuracy: 90.9000015258789
Epoch 8, Loss: 0.18381091952323914, Accuracy: 93.26499938964844, Test Loss: 0.2585820257663727, Test Accuracy: 90.95999908447266
Epoch 9, Loss: 0.1699329912662506, Accuracy: 93.67500305175781, Test Loss: 0.26234227418899536, Test Accuracy: 91.0199966430664
Epoch 10, Loss: 0.15756534039974213, Accuracy: 94.16333770751953, Test Loss: 0.25516414642333984, Test Accuracy: 90.93000030517578

أشياء يجب ملاحظتها في المثال أعلاه:

  • نحن نكرر على مجموعة train_dist_dataset و test_dist_dataset باستخدام a for x in ... build.
  • الخسارة المقاسة هي القيمة المرجعة للخطوة distributed_train_step . يتم تجميع هذه القيمة عبر النسخ المتماثلة باستخدام الاستدعاء tf.distribute.Strategy.reduce ثم عبر الدُفعات عن طريق جمع قيمة الإرجاع للمكالمات tf.distribute.Strategy.reduce .
  • يجب تحديث tf.keras.Metrics داخل train_step و test_step التي يتم تنفيذها بواسطة tf.distribute.Strategy.run . * ترجع tf.distribute.Strategy.run النتائج من كل نسخة متماثلة محلية في الإستراتيجية ، وهناك طرق متعددة لاستهلاك هذه النتيجة. يمكنك tf.distribute.Strategy.reduce للحصول على قيمة مجمعة. يمكنك أيضًا إجراء tf.distribute.Strategy.experimental_local_results للحصول على قائمة القيم الموجودة في النتيجة ، واحدة لكل نسخة محلية.

استعادة أحدث نقطة تفتيش والاختبار

نموذج تم التحقق منه باستخدام tf.distribute.Strategy يمكن استعادة الإستراتيجية مع أو بدون إستراتيجية.

eval_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='eval_accuracy')

new_model = create_model()
new_optimizer = tf.keras.optimizers.Adam()

test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE)
@tf.function
def eval_step(images, labels):
  predictions = new_model(images, training=False)
  eval_accuracy(labels, predictions)
checkpoint = tf.train.Checkpoint(optimizer=new_optimizer, model=new_model)
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

for images, labels in test_dataset:
  eval_step(images, labels)

print ('Accuracy after restoring the saved model without strategy: {}'.format(
    eval_accuracy.result()*100))
Accuracy after restoring the saved model without strategy: 91.0199966430664

طرق بديلة للتكرار على مجموعة بيانات

باستخدام التكرارات

إذا كنت ترغب في تكرار عدد معين من الخطوات وليس من خلال مجموعة البيانات بأكملها ، يمكنك إنشاء مكرر باستخدام استدعاء iter واستدعاء الوضوح next في المكرر. يمكنك اختيار التكرار على مجموعة البيانات داخل وخارج وظيفة tf. فيما يلي مقتطف صغير يوضح تكرار مجموعة البيانات خارج دالة tf باستخدام مكرر.

for _ in range(EPOCHS):
  total_loss = 0.0
  num_batches = 0
  train_iter = iter(train_dist_dataset)

  for _ in range(10):
    total_loss += distributed_train_step(next(train_iter))
    num_batches += 1
  average_train_loss = total_loss / num_batches

  template = ("Epoch {}, Loss: {}, Accuracy: {}")
  print (template.format(epoch+1, average_train_loss, train_accuracy.result()*100))
  train_accuracy.reset_states()
Epoch 10, Loss: 0.17486707866191864, Accuracy: 93.4375
Epoch 10, Loss: 0.12386945635080338, Accuracy: 95.3125
Epoch 10, Loss: 0.16411852836608887, Accuracy: 93.90625
Epoch 10, Loss: 0.10728752613067627, Accuracy: 96.40625
Epoch 10, Loss: 0.11865834891796112, Accuracy: 95.625
Epoch 10, Loss: 0.12875251471996307, Accuracy: 95.15625
Epoch 10, Loss: 0.1189488023519516, Accuracy: 95.625
Epoch 10, Loss: 0.1456708014011383, Accuracy: 95.15625
Epoch 10, Loss: 0.12446556240320206, Accuracy: 95.3125
Epoch 10, Loss: 0.1380888819694519, Accuracy: 95.46875

التكرار داخل دالة tf

يمكنك أيضًا التكرار عبر مجموعة train_dist_dataset الإدخال بالكامل داخل دالة tf باستخدام for x in ... build أو عن طريق إنشاء مكررات كما فعلنا أعلاه. يوضح المثال أدناه التفاف حقبة واحدة من التدريب في دالة tf والتكرار عبر train_dist_dataset داخل الوظيفة.

@tf.function
def distributed_train_epoch(dataset):
  total_loss = 0.0
  num_batches = 0
  for x in dataset:
    per_replica_losses = strategy.run(train_step, args=(x,))
    total_loss += strategy.reduce(
      tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
    num_batches += 1
  return total_loss / tf.cast(num_batches, dtype=tf.float32)

for epoch in range(EPOCHS):
  train_loss = distributed_train_epoch(train_dist_dataset)

  template = ("Epoch {}, Loss: {}, Accuracy: {}")
  print (template.format(epoch+1, train_loss, train_accuracy.result()*100))

  train_accuracy.reset_states()
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py:449: UserWarning: To make it possible to preserve tf.data options across serialization boundaries, their implementation has moved to be part of the TensorFlow graph. As a consequence, the options value is in general no longer known at graph construction time. Invoking this method in graph mode retains the legacy behavior of the original implementation, but note that the returned value might not reflect the actual value of the options.
  warnings.warn("To make it possible to preserve tf.data options across "
Epoch 1, Loss: 0.14398494362831116, Accuracy: 94.63999938964844
Epoch 2, Loss: 0.13246288895606995, Accuracy: 94.97333526611328
Epoch 3, Loss: 0.11922841519117355, Accuracy: 95.63833618164062
Epoch 4, Loss: 0.11084160208702087, Accuracy: 95.99333190917969
Epoch 5, Loss: 0.10420522093772888, Accuracy: 96.0816650390625
Epoch 6, Loss: 0.09215126931667328, Accuracy: 96.63500213623047
Epoch 7, Loss: 0.0878651961684227, Accuracy: 96.67666625976562
Epoch 8, Loss: 0.07854588329792023, Accuracy: 97.09333038330078
Epoch 9, Loss: 0.07217177003622055, Accuracy: 97.34833526611328
Epoch 10, Loss: 0.06753655523061752, Accuracy: 97.48999786376953

تتبع فقدان التدريب عبر النسخ المتماثلة

لا نوصي باستخدام tf.metrics.Mean لتتبع فقدان التدريب عبر النسخ المتماثلة المختلفة ، بسبب حساب قياس الخسارة الذي يتم تنفيذه.

على سبيل المثال ، إذا كنت تدير وظيفة تدريبية بالخصائص التالية:

  • نسختان متماثلتان
  • تتم معالجة عينتين على كل نسخة طبق الأصل
  • قيم الخسارة الناتجة: [2 ، 3] و [4 ، 5] على كل نسخة متماثلة
  • حجم الدفعة العالمية = 4

باستخدام مقياس الخسارة ، يمكنك حساب قيمة الخسارة لكل عينة في كل نسخة متماثلة عن طريق إضافة قيم الخسارة ، ثم القسمة على حجم الدُفعة العام. في هذه الحالة: (2 + 3) / 4 = 1.25 و (4 + 5) / 4 = 2.25 .

إذا كنت تستخدم tf.metrics.Mean لتعقب الخسارة عبر نسختين متماثلتين ، فستكون النتيجة مختلفة. في هذا المثال ، ينتهي بك الأمر total 3.50 count 2 ، مما ينتج عنه total / count = 1.75 عندما يتم استدعاء result() في المقياس. يتم احتساب الخسارة باستخدام tf.keras.Metrics بواسطة عامل إضافي يساوي عدد النسخ المتماثلة المتزامنة.

دليل وأمثلة

فيما يلي بعض الأمثلة لاستخدام إستراتيجية التوزيع مع حلقات تدريب مخصصة:

  1. دليل التدريب الموزع
  2. مثال على DenseNet باستخدام MirroredStrategy .
  3. تم تدريب مثال BERT باستخدام MirroredStrategy و TPUStrategy . هذا المثال مفيد بشكل خاص لفهم كيفية التحميل من نقطة تفتيش وإنشاء نقاط تفتيش دورية أثناء التدريب الموزع وما إلى ذلك.
  4. مثال NCF تم تدريبه باستخدام MirroredStrategy يمكن تمكينه باستخدام علامة keras_use_ctl .
  5. تم تدريب مثال NMT باستخدام MirroredStrategy .

المزيد من الأمثلة المدرجة في دليل استراتيجية التوزيع .

الخطوات التالية

  • جرب tf.distribute.Strategy API على نماذجك.
  • قم بزيارة قسم الأداء في الدليل لمعرفة المزيد حول الاستراتيجيات والأدوات الأخرى التي يمكنك استخدامها لتحسين أداء نماذج TensorFlow.