![]() | ![]() | ![]() | ![]() |
סקירה כללית
ממשק ה- API של tf.distribute.Strategy
מספק הפשטה להפצת ההדרכה שלך על פני יחידות עיבוד מרובות. המטרה היא לאפשר למשתמשים לאפשר אימונים מבוזרים באמצעות מודלים קיימים וקוד אימונים, עם שינויים מינימליים.
הדרכה זו משתמשת ב- tf.distribute.MirroredStrategy
, שעושה שכפול בגרף עם אימונים סינכרוניים במכשירי GPU רבים במחשב אחד. בעיקרו של דבר, הוא מעתיק את כל המשתנים של המודל לכל מעבד. לאחר מכן, היא משתמשת בהפחתה מלאה כדי לשלב את הדרגתיות מכל המעבדים ומחילה את הערך המשולב על כל עותקי הדגם.
MirroredStrategy
היא אחת מכמה אסטרטגיות ההפצה הקיימות בליבת TensorFlow. תוכלו לקרוא על אסטרטגיות נוספות במדריך לאסטרטגיה להפצה .
ממשק API של קרס
דוגמה זו משתמשת בממשק ה- API של tf.keras
לבניית המודל ולולאת האימון. לולאות אימון מותאמות אישית, עיין במדריך tf.distribute.Strategy עם לולאות אימון .
ייבוא תלות
# Import TensorFlow and TensorFlow Datasets
import tensorflow_datasets as tfds
import tensorflow as tf
import os
print(tf.__version__)
2.3.0
הורד את מערך הנתונים
הורד את מערך הנתונים MNIST וטען אותו ממערכי הנתונים של TensorFlow . זה מחזיר מערך נתונים בפורמט tf.data
.
הגדרה with_info
ל- True
כוללת את המטא נתונים עבור כל מערך הנתונים, שנשמר כאן info
. בין היתר, אובייקט מטא-נתונים זה כולל את מספר דוגמאות הרכבת והמבחנים.
datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = datasets['train'], datasets['test']
Downloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /home/kbuilder/tensorflow_datasets/mnist/3.0.1... WARNING:absl:Dataset mnist is hosted on GCS. It will automatically be downloaded to your local data directory. If you'd instead prefer to read directly from our public GCS bucket (recommended if you're running on GCP), you can instead pass `try_gcs=True` to `tfds.load` or set `data_dir=gs://tfds-data/datasets`. Dataset mnist downloaded and prepared to /home/kbuilder/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.
הגדר אסטרטגיית הפצה
צור אובייקט MirroredStrategy
. זה יטפל בהפצה ויספק מנהל הקשר ( tf.distribute.MirroredStrategy.scope
) לבנות את המודל שלך בפנים.
strategy = tf.distribute.MirroredStrategy()
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',) INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
Number of devices: 1
הגדרת צינור קלט
כשאתה מאמן מודל עם מספר GPUs, אתה יכול להשתמש בכוח המחשוב הנוסף ביעילות על ידי הגדלת גודל האצווה. באופן כללי, השתמש בגודל האצווה הגדול ביותר שמתאים לזיכרון ה- GPU, וכוון את קצב הלמידה בהתאם.
# You can also do info.splits.total_num_examples to get the total
# number of examples in the dataset.
num_train_examples = info.splits['train'].num_examples
num_test_examples = info.splits['test'].num_examples
BUFFER_SIZE = 10000
BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
יש לנרמל את ערכי הפיקסלים, שהם 0-255, לטווח 0-1 . הגדר סולם זה בפונקציה.
def scale(image, label):
image = tf.cast(image, tf.float32)
image /= 255
return image, label
החל פונקציה זו על נתוני האימון והבדיקה, ערבב את נתוני האימון ואצווה אותם לאימון . שים לב שאנו שומרים גם מטמון בזיכרון של נתוני האימון כדי לשפר את הביצועים.
train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)
צור את המודל
צור והרכיב את מודל Keras בהקשר של strategy.scope
.
with strategy.scope():
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.Dense(10)
])
model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=tf.keras.optimizers.Adam(),
metrics=['accuracy'])
הגדר את השיחות החוזרות
השיחות החוזרות המשמשות כאן הן:
- TensorBoard : התקשרות חוזרת זו כותבת יומן עבור TensorBoard המאפשר לך לדמיין את הגרפים.
- נקודת ביקורת מודל : התקשרות חוזרת זו שומרת את המודל לאחר כל תקופה.
- מתזמן שיעורי למידה: באמצעות התקשרות חוזרת זו, אתה יכול לתזמן את שינוי הלמידה לאחר כל תקופה / אצווה.
למטרות המחשה, הוסף התקשרות חוזרת להדפסה כדי להציג את קצב הלמידה במחברת.
# Define the checkpoint directory to store the checkpoints
checkpoint_dir = './training_checkpoints'
# Name of the checkpoint files
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")
# Function for decaying the learning rate.
# You can define any decay function you need.
def decay(epoch):
if epoch < 3:
return 1e-3
elif epoch >= 3 and epoch < 7:
return 1e-4
else:
return 1e-5
# Callback for printing the LR at the end of each epoch.
class PrintLR(tf.keras.callbacks.Callback):
def on_epoch_end(self, epoch, logs=None):
print('\nLearning rate for epoch {} is {}'.format(epoch + 1,
model.optimizer.lr.numpy()))
callbacks = [
tf.keras.callbacks.TensorBoard(log_dir='./logs'),
tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,
save_weights_only=True),
tf.keras.callbacks.LearningRateScheduler(decay),
PrintLR()
]
התאמן והערך
כעת, הכשיר את המודל בצורה הרגילה, קרא את fit
למודל והעביר את מערך הנתונים שנוצר בתחילת ההדרכה. שלב זה זהה בין אם אתה מפיץ את ההדרכה ובין אם לא.
model.fit(train_dataset, epochs=12, callbacks=callbacks)
Epoch 1/12 WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:601: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version. Instructions for updating: Use `tf.data.Iterator.get_next_as_optional()` instead. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:601: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version. Instructions for updating: Use `tf.data.Iterator.get_next_as_optional()` instead. INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). 1/938 [..............................] - ETA: 0s - loss: 2.3083 - accuracy: 0.0156WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/summary_ops_v2.py:1277: stop (from tensorflow.python.eager.profiler) is deprecated and will be removed after 2020-07-01. Instructions for updating: use `tf.profiler.experimental.stop` instead. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/summary_ops_v2.py:1277: stop (from tensorflow.python.eager.profiler) is deprecated and will be removed after 2020-07-01. Instructions for updating: use `tf.profiler.experimental.stop` instead. WARNING:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0047s vs `on_train_batch_end` time: 0.0316s). Check your callbacks. WARNING:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0047s vs `on_train_batch_end` time: 0.0316s). Check your callbacks. 932/938 [============================>.] - ETA: 0s - loss: 0.1947 - accuracy: 0.9441 Learning rate for epoch 1 is 0.0010000000474974513 938/938 [==============================] - 4s 4ms/step - loss: 0.1939 - accuracy: 0.9442 INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). Epoch 2/12 935/938 [============================>.] - ETA: 0s - loss: 0.0636 - accuracy: 0.9811 Learning rate for epoch 2 is 0.0010000000474974513 938/938 [==============================] - 2s 3ms/step - loss: 0.0634 - accuracy: 0.9812 Epoch 3/12 936/938 [============================>.] - ETA: 0s - loss: 0.0438 - accuracy: 0.9864 Learning rate for epoch 3 is 0.0010000000474974513 938/938 [==============================] - 2s 3ms/step - loss: 0.0439 - accuracy: 0.9864 Epoch 4/12 937/938 [============================>.] - ETA: 0s - loss: 0.0234 - accuracy: 0.9936 Learning rate for epoch 4 is 9.999999747378752e-05 938/938 [==============================] - 2s 3ms/step - loss: 0.0234 - accuracy: 0.9936 Epoch 5/12 932/938 [============================>.] - ETA: 0s - loss: 0.0204 - accuracy: 0.9948 Learning rate for epoch 5 is 9.999999747378752e-05 938/938 [==============================] - 3s 3ms/step - loss: 0.0204 - accuracy: 0.9948 Epoch 6/12 919/938 [============================>.] - ETA: 0s - loss: 0.0188 - accuracy: 0.9951 Learning rate for epoch 6 is 9.999999747378752e-05 938/938 [==============================] - 2s 3ms/step - loss: 0.0187 - accuracy: 0.9951 Epoch 7/12 921/938 [============================>.] - ETA: 0s - loss: 0.0172 - accuracy: 0.9960 Learning rate for epoch 7 is 9.999999747378752e-05 938/938 [==============================] - 2s 3ms/step - loss: 0.0171 - accuracy: 0.9960 Epoch 8/12 931/938 [============================>.] - ETA: 0s - loss: 0.0147 - accuracy: 0.9970 Learning rate for epoch 8 is 9.999999747378752e-06 938/938 [==============================] - 2s 3ms/step - loss: 0.0147 - accuracy: 0.9970 Epoch 9/12 938/938 [==============================] - ETA: 0s - loss: 0.0144 - accuracy: 0.9970 Learning rate for epoch 9 is 9.999999747378752e-06 938/938 [==============================] - 2s 3ms/step - loss: 0.0144 - accuracy: 0.9970 Epoch 10/12 924/938 [============================>.] - ETA: 0s - loss: 0.0143 - accuracy: 0.9971 Learning rate for epoch 10 is 9.999999747378752e-06 938/938 [==============================] - 2s 3ms/step - loss: 0.0142 - accuracy: 0.9971 Epoch 11/12 937/938 [============================>.] - ETA: 0s - loss: 0.0140 - accuracy: 0.9972 Learning rate for epoch 11 is 9.999999747378752e-06 938/938 [==============================] - 2s 3ms/step - loss: 0.0140 - accuracy: 0.9972 Epoch 12/12 923/938 [============================>.] - ETA: 0s - loss: 0.0139 - accuracy: 0.9973 Learning rate for epoch 12 is 9.999999747378752e-06 938/938 [==============================] - 2s 3ms/step - loss: 0.0139 - accuracy: 0.9973 <tensorflow.python.keras.callbacks.History at 0x7f50a0d94780>
כפי שניתן לראות למטה, נקודות הביקורת נשמרות.
# check the checkpoint directory
ls {checkpoint_dir}
checkpoint ckpt_4.data-00000-of-00001 ckpt_1.data-00000-of-00001 ckpt_4.index ckpt_1.index ckpt_5.data-00000-of-00001 ckpt_10.data-00000-of-00001 ckpt_5.index ckpt_10.index ckpt_6.data-00000-of-00001 ckpt_11.data-00000-of-00001 ckpt_6.index ckpt_11.index ckpt_7.data-00000-of-00001 ckpt_12.data-00000-of-00001 ckpt_7.index ckpt_12.index ckpt_8.data-00000-of-00001 ckpt_2.data-00000-of-00001 ckpt_8.index ckpt_2.index ckpt_9.data-00000-of-00001 ckpt_3.data-00000-of-00001 ckpt_9.index ckpt_3.index
כדי לראות איך המודל לבצע, לטעון את המחסום השיחה האחרונה evaluate
על נתוני הבדיקה.
התקשר evaluate
כמו קודם באמצעות מערכי נתונים מתאימים.
model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))
eval_loss, eval_acc = model.evaluate(eval_dataset)
print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
157/157 [==============================] - 1s 6ms/step - loss: 0.0393 - accuracy: 0.9864 Eval loss: 0.03928377106785774, Eval Accuracy: 0.9864000082015991
כדי לראות את הפלט, אתה יכול להוריד ולהציג את יומני TensorBoard במסוף.
$ tensorboard --logdir=path/to/log-directory
ls -sh ./logs
total 4.0K 4.0K train
ייצא ל- SavedModel
ייצא את הגרף ואת המשתנים לפורמט SavedModel לפלטפורמה-אגנוסטית. לאחר שמירת המודל שלך, תוכל לטעון אותו עם או בלי ההיקף.
path = 'saved_model/'
model.save(path, save_format='tf')
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version. Instructions for updating: This property should not be used in TensorFlow 2.0, as updates are applied automatically. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version. Instructions for updating: This property should not be used in TensorFlow 2.0, as updates are applied automatically. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version. Instructions for updating: This property should not be used in TensorFlow 2.0, as updates are applied automatically. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version. Instructions for updating: This property should not be used in TensorFlow 2.0, as updates are applied automatically. INFO:tensorflow:Assets written to: saved_model/assets INFO:tensorflow:Assets written to: saved_model/assets
טען את המודל ללא strategy.scope
.
unreplicated_model = tf.keras.models.load_model(path)
unreplicated_model.compile(
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=tf.keras.optimizers.Adam(),
metrics=['accuracy'])
eval_loss, eval_acc = unreplicated_model.evaluate(eval_dataset)
print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
157/157 [==============================] - 1s 3ms/step - loss: 0.0393 - accuracy: 0.9864 Eval loss: 0.03928377106785774, Eval Accuracy: 0.9864000082015991
טען את המודל strategy.scope
.
with strategy.scope():
replicated_model = tf.keras.models.load_model(path)
replicated_model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=tf.keras.optimizers.Adam(),
metrics=['accuracy'])
eval_loss, eval_acc = replicated_model.evaluate(eval_dataset)
print ('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
157/157 [==============================] - 1s 4ms/step - loss: 0.0393 - accuracy: 0.9864 Eval loss: 0.03928377106785774, Eval Accuracy: 0.9864000082015991
דוגמאות והדרכות
להלן מספר דוגמאות לשימוש באסטרטגיית הפצה עם התאמת / הידור של keras:
- שנאי למשל מאומנים באמצעות
tf.distribute.MirroredStrategy
- דוגמה ל- NCF שהוכשרה באמצעות
tf.distribute.MirroredStrategy
.
דוגמאות נוספות המופיעות במדריך האסטרטגיה להפצה
הצעדים הבאים
- קרא את מדריך אסטרטגיית ההפצה .
- קרא את הדרכת ההדרכה המבוזרת עם לולאות אימון מותאמות אישית .
- בקר בסעיף הביצועים במדריך כדי ללמוד עוד על אסטרטגיות וכלים אחרים שבהם תוכל להשתמש כדי לייעל את הביצועים של דגמי TensorFlow שלך.