![]() | ![]() | ![]() | ![]() |
ภาพรวม
tf.distribute.Strategy
API จัดเตรียมสิ่งที่เป็นนามธรรมสำหรับการกระจายการฝึกอบรมของคุณไปยังหน่วยประมวลผลต่างๆ เป้าหมายคือเพื่อให้ผู้ใช้เปิดใช้งานการฝึกอบรมแบบกระจายโดยใช้แบบจำลองและรหัสการฝึกอบรมที่มีอยู่โดยมีการเปลี่ยนแปลงน้อยที่สุด
บทช่วยสอนนี้ใช้ tf.distribute.MirroredStrategy
ซึ่งทำการจำลองแบบในกราฟพร้อมการฝึกอบรมแบบซิงโครนัสกับ GPU จำนวนมากในเครื่องเดียว โดยพื้นฐานแล้วจะคัดลอกตัวแปรทั้งหมดของโมเดลไปยังโปรเซสเซอร์แต่ละตัว จากนั้นจะใช้ การลดขนาดทั้งหมด เพื่อรวมการไล่ระดับสีจากโปรเซสเซอร์ทั้งหมดและใช้ค่าที่รวมกับสำเนาทั้งหมดของโมเดล
MirroredStrategy
เป็นหนึ่งในกลยุทธ์การกระจายหลายอย่างที่มีอยู่ในแกน TensorFlow คุณสามารถอ่านเกี่ยวกับกลยุทธ์เพิ่มเติมได้ที่ คู่มือกลยุทธ์การจัดจำหน่าย
Keras API
ตัวอย่างนี้ใช้ tf.keras
API เพื่อสร้างโมเดลและลูปการฝึกอบรม สำหรับลูปการฝึกอบรมแบบกำหนดเองโปรดดูที่ tf.distribute.Strategy พร้อม บทแนะนำ การฝึกอบรมลูป
นำเข้าการอ้างอิง
# Import TensorFlow and TensorFlow Datasets
import tensorflow_datasets as tfds
import tensorflow as tf
import os
print(tf.__version__)
2.3.0
ดาวน์โหลดชุดข้อมูล
ดาวน์โหลดชุดข้อมูล MNIST และโหลดจาก TensorFlow Datasets ส่งคืนชุดข้อมูลในรูปแบบ tf.data
การตั้งค่า with_info
เป็น True
จะรวมข้อมูลเมตาสำหรับชุดข้อมูลทั้งหมดซึ่งจะถูกบันทึกไว้ที่นี่เป็น info
เหนือสิ่งอื่นใดออบเจ็กต์ข้อมูลเมตานี้รวมถึงจำนวนรถไฟและตัวอย่างการทดสอบ
datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = datasets['train'], datasets['test']
Downloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /home/kbuilder/tensorflow_datasets/mnist/3.0.1... Warning:absl:Dataset mnist is hosted on GCS. It will automatically be downloaded to your local data directory. If you'd instead prefer to read directly from our public GCS bucket (recommended if you're running on GCP), you can instead pass `try_gcs=True` to `tfds.load` or set `data_dir=gs://tfds-data/datasets`. Dataset mnist downloaded and prepared to /home/kbuilder/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.
กำหนดกลยุทธ์การจัดจำหน่าย
สร้างวัตถุ MirroredStrategy
สิ่งนี้จะจัดการการแจกจ่ายและจัด tf.distribute.MirroredStrategy.scope
จัดการบริบท ( tf.distribute.MirroredStrategy.scope
) เพื่อสร้างโมเดลของคุณภายใน
strategy = tf.distribute.MirroredStrategy()
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',) INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
Number of devices: 1
ตั้งค่าท่อส่งข้อมูล
เมื่อฝึกโมเดลที่มี GPU หลายตัวคุณสามารถใช้พลังประมวลผลเพิ่มเติมได้อย่างมีประสิทธิภาพโดยการเพิ่มขนาดแบทช์ โดยทั่วไปให้ใช้ขนาดแบทช์ที่ใหญ่ที่สุดที่พอดีกับหน่วยความจำ GPU และปรับอัตราการเรียนรู้ให้เหมาะสม
# You can also do info.splits.total_num_examples to get the total
# number of examples in the dataset.
num_train_examples = info.splits['train'].num_examples
num_test_examples = info.splits['test'].num_examples
BUFFER_SIZE = 10000
BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
ค่าพิกเซลซึ่งเป็น 0-255 ต้องปรับให้เป็นค่าปกติเป็นช่วง 0-1 กำหนดมาตราส่วนนี้ในฟังก์ชัน
def scale(image, label):
image = tf.cast(image, tf.float32)
image /= 255
return image, label
ใช้ฟังก์ชันนี้กับข้อมูลการฝึกอบรมและการทดสอบสับเปลี่ยนข้อมูลการฝึกอบรมและจัด กลุ่มสำหรับการฝึกอบรม โปรดสังเกตว่าเรากำลังเก็บแคชข้อมูลการฝึกอบรมไว้ในหน่วยความจำเพื่อปรับปรุงประสิทธิภาพ
train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)
สร้างแบบจำลอง
สร้างและรวบรวมโมเดล Keras ในบริบทของ strategy.scope
with strategy.scope():
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.Dense(10)
])
model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=tf.keras.optimizers.Adam(),
metrics=['accuracy'])
กำหนดการเรียกกลับ
การเรียกกลับที่ใช้ที่นี่คือ:
- TensorBoard : การเรียกกลับนี้เขียนบันทึกสำหรับ TensorBoard ซึ่งช่วยให้คุณเห็นภาพกราฟ
- Model Checkpoint : การเรียกกลับนี้จะบันทึกโมเดลหลังจากทุกยุค
- เครื่องมือจัดกำหนดการอัตราการเรียนรู้ : เมื่อใช้การโทรกลับนี้คุณสามารถกำหนดเวลาให้อัตราการเรียนรู้เปลี่ยนแปลงได้หลังจากทุกยุค / ชุด
เพื่อเป็นภาพประกอบให้เพิ่มการติดต่อกลับเพื่อแสดง อัตราการเรียนรู้ ในสมุดบันทึก
# Define the checkpoint directory to store the checkpoints
checkpoint_dir = './training_checkpoints'
# Name of the checkpoint files
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")
# Function for decaying the learning rate.
# You can define any decay function you need.
def decay(epoch):
if epoch < 3:
return 1e-3
elif epoch >= 3 and epoch < 7:
return 1e-4
else:
return 1e-5
# Callback for printing the LR at the end of each epoch.
class PrintLR(tf.keras.callbacks.Callback):
def on_epoch_end(self, epoch, logs=None):
print('\nLearning rate for epoch {} is {}'.format(epoch + 1,
model.optimizer.lr.numpy()))
callbacks = [
tf.keras.callbacks.TensorBoard(log_dir='./logs'),
tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,
save_weights_only=True),
tf.keras.callbacks.LearningRateScheduler(decay),
PrintLR()
]
ฝึกอบรมและประเมิน
ตอนนี้ให้ฝึกโมเดลตามปกติโดยเรียก fit
กับโมเดลและส่งผ่านชุดข้อมูลที่สร้างขึ้นเมื่อเริ่มบทแนะนำ ขั้นตอนนี้ก็เหมือนกันไม่ว่าคุณจะกระจายการฝึกอบรมหรือไม่
model.fit(train_dataset, epochs=12, callbacks=callbacks)
Epoch 1/12 WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:601: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version. Instructions for updating: Use `tf.data.Iterator.get_next_as_optional()` instead. Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:601: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version. Instructions for updating: Use `tf.data.Iterator.get_next_as_optional()` instead. INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). 1/938 [..............................] - ETA: 0s - loss: 2.3083 - accuracy: 0.0156WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/summary_ops_v2.py:1277: stop (from tensorflow.python.eager.profiler) is deprecated and will be removed after 2020-07-01. Instructions for updating: use `tf.profiler.experimental.stop` instead. Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/summary_ops_v2.py:1277: stop (from tensorflow.python.eager.profiler) is deprecated and will be removed after 2020-07-01. Instructions for updating: use `tf.profiler.experimental.stop` instead. Warning:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0047s vs `on_train_batch_end` time: 0.0316s). Check your callbacks. Warning:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0047s vs `on_train_batch_end` time: 0.0316s). Check your callbacks. 932/938 [============================>.] - ETA: 0s - loss: 0.1947 - accuracy: 0.9441 Learning rate for epoch 1 is 0.0010000000474974513 938/938 [==============================] - 4s 4ms/step - loss: 0.1939 - accuracy: 0.9442 INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). Epoch 2/12 935/938 [============================>.] - ETA: 0s - loss: 0.0636 - accuracy: 0.9811 Learning rate for epoch 2 is 0.0010000000474974513 938/938 [==============================] - 2s 3ms/step - loss: 0.0634 - accuracy: 0.9812 Epoch 3/12 936/938 [============================>.] - ETA: 0s - loss: 0.0438 - accuracy: 0.9864 Learning rate for epoch 3 is 0.0010000000474974513 938/938 [==============================] - 2s 3ms/step - loss: 0.0439 - accuracy: 0.9864 Epoch 4/12 937/938 [============================>.] - ETA: 0s - loss: 0.0234 - accuracy: 0.9936 Learning rate for epoch 4 is 9.999999747378752e-05 938/938 [==============================] - 2s 3ms/step - loss: 0.0234 - accuracy: 0.9936 Epoch 5/12 932/938 [============================>.] - ETA: 0s - loss: 0.0204 - accuracy: 0.9948 Learning rate for epoch 5 is 9.999999747378752e-05 938/938 [==============================] - 3s 3ms/step - loss: 0.0204 - accuracy: 0.9948 Epoch 6/12 919/938 [============================>.] - ETA: 0s - loss: 0.0188 - accuracy: 0.9951 Learning rate for epoch 6 is 9.999999747378752e-05 938/938 [==============================] - 2s 3ms/step - loss: 0.0187 - accuracy: 0.9951 Epoch 7/12 921/938 [============================>.] - ETA: 0s - loss: 0.0172 - accuracy: 0.9960 Learning rate for epoch 7 is 9.999999747378752e-05 938/938 [==============================] - 2s 3ms/step - loss: 0.0171 - accuracy: 0.9960 Epoch 8/12 931/938 [============================>.] - ETA: 0s - loss: 0.0147 - accuracy: 0.9970 Learning rate for epoch 8 is 9.999999747378752e-06 938/938 [==============================] - 2s 3ms/step - loss: 0.0147 - accuracy: 0.9970 Epoch 9/12 938/938 [==============================] - ETA: 0s - loss: 0.0144 - accuracy: 0.9970 Learning rate for epoch 9 is 9.999999747378752e-06 938/938 [==============================] - 2s 3ms/step - loss: 0.0144 - accuracy: 0.9970 Epoch 10/12 924/938 [============================>.] - ETA: 0s - loss: 0.0143 - accuracy: 0.9971 Learning rate for epoch 10 is 9.999999747378752e-06 938/938 [==============================] - 2s 3ms/step - loss: 0.0142 - accuracy: 0.9971 Epoch 11/12 937/938 [============================>.] - ETA: 0s - loss: 0.0140 - accuracy: 0.9972 Learning rate for epoch 11 is 9.999999747378752e-06 938/938 [==============================] - 2s 3ms/step - loss: 0.0140 - accuracy: 0.9972 Epoch 12/12 923/938 [============================>.] - ETA: 0s - loss: 0.0139 - accuracy: 0.9973 Learning rate for epoch 12 is 9.999999747378752e-06 938/938 [==============================] - 2s 3ms/step - loss: 0.0139 - accuracy: 0.9973 <tensorflow.python.keras.callbacks.History at 0x7f50a0d94780>
ดังที่คุณเห็นด้านล่างจุดตรวจจะได้รับการบันทึก
# check the checkpoint directory
ls {checkpoint_dir}
checkpoint ckpt_4.data-00000-of-00001 ckpt_1.data-00000-of-00001 ckpt_4.index ckpt_1.index ckpt_5.data-00000-of-00001 ckpt_10.data-00000-of-00001 ckpt_5.index ckpt_10.index ckpt_6.data-00000-of-00001 ckpt_11.data-00000-of-00001 ckpt_6.index ckpt_11.index ckpt_7.data-00000-of-00001 ckpt_12.data-00000-of-00001 ckpt_7.index ckpt_12.index ckpt_8.data-00000-of-00001 ckpt_2.data-00000-of-00001 ckpt_8.index ckpt_2.index ckpt_9.data-00000-of-00001 ckpt_3.data-00000-of-00001 ckpt_9.index ckpt_3.index
หากต้องการดูประสิทธิภาพของโมเดลให้โหลดจุดตรวจสอบล่าสุดและโทร evaluate
ข้อมูลการทดสอบ
โทร evaluate
ก่อนใช้ชุดข้อมูลที่เหมาะสม
model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))
eval_loss, eval_acc = model.evaluate(eval_dataset)
print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
157/157 [==============================] - 1s 6ms/step - loss: 0.0393 - accuracy: 0.9864 Eval loss: 0.03928377106785774, Eval Accuracy: 0.9864000082015991
หากต้องการดูผลลัพธ์คุณสามารถดาวน์โหลดและดูบันทึก TensorBoard ได้ที่เทอร์มินัล
$ tensorboard --logdir=path/to/log-directory
ls -sh ./logs
total 4.0K 4.0K train
ส่งออกไปยัง SavedModel
ส่งออกกราฟและตัวแปรไปยังรูปแบบ SavedModel แพลตฟอร์มที่ไม่เชื่อเรื่องพระเจ้า หลังจากบันทึกโมเดลของคุณแล้วคุณสามารถโหลดได้โดยมีหรือไม่มีขอบเขต
path = 'saved_model/'
model.save(path, save_format='tf')
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version. Instructions for updating: This property should not be used in TensorFlow 2.0, as updates are applied automatically. Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version. Instructions for updating: This property should not be used in TensorFlow 2.0, as updates are applied automatically. Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version. Instructions for updating: This property should not be used in TensorFlow 2.0, as updates are applied automatically. Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version. Instructions for updating: This property should not be used in TensorFlow 2.0, as updates are applied automatically. INFO:tensorflow:Assets written to: saved_model/assets INFO:tensorflow:Assets written to: saved_model/assets
โหลดโมเดลโดยไม่ต้อง strategy.scope
unreplicated_model = tf.keras.models.load_model(path)
unreplicated_model.compile(
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=tf.keras.optimizers.Adam(),
metrics=['accuracy'])
eval_loss, eval_acc = unreplicated_model.evaluate(eval_dataset)
print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
157/157 [==============================] - 1s 3ms/step - loss: 0.0393 - accuracy: 0.9864 Eval loss: 0.03928377106785774, Eval Accuracy: 0.9864000082015991
โหลดโมเดลด้วย strategy.scope
with strategy.scope():
replicated_model = tf.keras.models.load_model(path)
replicated_model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=tf.keras.optimizers.Adam(),
metrics=['accuracy'])
eval_loss, eval_acc = replicated_model.evaluate(eval_dataset)
print ('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
157/157 [==============================] - 1s 4ms/step - loss: 0.0393 - accuracy: 0.9864 Eval loss: 0.03928377106785774, Eval Accuracy: 0.9864000082015991
ตัวอย่างและบทช่วยสอน
นี่คือตัวอย่างบางส่วนสำหรับการใช้กลยุทธ์การกระจายกับ keras fit / compile:
- ตัวอย่าง Transformer ที่ ฝึกโดยใช้
tf.distribute.MirroredStrategy
- ตัวอย่าง NCF ที่ ฝึกโดยใช้
tf.distribute.MirroredStrategy
ตัวอย่างเพิ่มเติมที่ระบุไว้ใน คู่มือกลยุทธ์การจัดจำหน่าย
ขั้นตอนถัดไป
- อ่าน คู่มือกลยุทธ์การจัดจำหน่าย
- อ่านบทแนะนำ การฝึกอบรมแบบกระจายด้วย Custom Training Loops
- ไปที่ ส่วนประสิทธิภาพ ในคำแนะนำเพื่อเรียนรู้เพิ่มเติมเกี่ยวกับกลยุทธ์และ เครื่องมือ อื่น ๆ ที่คุณสามารถใช้เพื่อเพิ่มประสิทธิภาพของโมเดล TensorFlow ของคุณ