หน้านี้ได้รับการแปลโดย Cloud Translation API
Switch to English

การฝึกอบรมผู้ปฏิบัติงานหลายคนด้วย Keras

ดูใน TensorFlow.org เรียกใช้ใน Google Colab ดูแหล่งที่มาบน GitHub ดาวน์โหลดสมุดบันทึก

ภาพรวม

บทช่วยสอนนี้สาธิตการฝึกอบรมแบบกระจายผู้ปฏิบัติงานหลายคนด้วยโมเดล tf.distribute.Strategy โดยใช้ tf.distribute.Strategy API โดยเฉพาะ tf.distribute.experimental.MultiWorkerMirroredStrategy ด้วยความช่วยเหลือของกลยุทธ์นี้โมเดล Keras ที่ออกแบบมาเพื่อทำงานบนคนงานคนเดียวสามารถทำงานกับคนงานหลายคนได้อย่างราบรื่นโดยมีการเปลี่ยนแปลงรหัสเพียงเล็กน้อย

การฝึกอบรมแบบกระจายใน คู่มือ TensorFlow มีให้สำหรับภาพรวมของกลยุทธ์การแจกจ่ายที่ TensorFlow สนับสนุนสำหรับผู้ที่สนใจในความเข้าใจที่ลึกซึ้งยิ่งขึ้นเกี่ยวกับ tf.distribute.Strategy APIs

ติดตั้ง

ขั้นแรกการนำเข้าที่จำเป็นบางอย่าง

import json
import os
import sys

ก่อนที่จะนำเข้า TensorFlow ให้ทำการเปลี่ยนแปลงสภาพแวดล้อมเล็กน้อย

ปิดการใช้งาน GPU ทั้งหมด ซึ่งจะช่วยป้องกันข้อผิดพลาดที่เกิดจากพนักงานทุกคนพยายามใช้ GPU เดียวกัน สำหรับการใช้งานจริงพนักงานแต่ละคนจะอยู่บนเครื่องอื่น

os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

รีเซ็ตตัวแปรสภาพแวดล้อม TF_CONFIG คุณจะเห็นข้อมูลเพิ่มเติมในภายหลัง

os.environ.pop('TF_CONFIG', None)

ตรวจสอบให้แน่ใจว่าไดเร็กทอรีปัจจุบันอยู่บนเส้นทางของ python ซึ่งจะช่วยให้สมุดบันทึกสามารถนำเข้าไฟล์ที่เขียนโดย %%writefile ภายหลัง

if '.' not in sys.path:
  sys.path.insert(0, '.')

ตอนนี้นำเข้า TensorFlow

import tensorflow as tf

นิยามชุดข้อมูลและแบบจำลอง

จากนั้นสร้างไฟล์ mnist.py ด้วยโมเดลและการตั้งค่าชุดข้อมูลอย่างง่าย ไฟล์ python นี้จะถูกใช้โดยกระบวนการของผู้ปฏิบัติงานในบทช่วยสอนนี้:

%%writefile mnist.py

import os
import tensorflow as tf
import numpy as np

def mnist_dataset(batch_size):
  (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
  # The `x` arrays are in uint8 and have values in the range [0, 255].
  # You need to convert them to float32 with values in the range [0, 1]
  x_train = x_train / np.float32(255)
  y_train = y_train.astype(np.int64)
  train_dataset = tf.data.Dataset.from_tensor_slices(
      (x_train, y_train)).shuffle(60000).repeat().batch(batch_size)
  return train_dataset

def build_and_compile_cnn_model():
  model = tf.keras.Sequential([
      tf.keras.Input(shape=(28, 28)),
      tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
      tf.keras.layers.Conv2D(32, 3, activation='relu'),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(128, activation='relu'),
      tf.keras.layers.Dense(10)
  ])
  model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
      metrics=['accuracy'])
  return model
Writing mnist.py

ลองฝึกแบบจำลองสำหรับช่วงเวลาเล็ก ๆ และสังเกตผลลัพธ์ของคนงานคนเดียวเพื่อให้แน่ใจว่าทุกอย่างทำงานได้อย่างถูกต้อง เมื่อการฝึกดำเนินไปการสูญเสียควรลดลงและความแม่นยำควรเพิ่มขึ้น

import mnist

batch_size = 64
single_worker_dataset = mnist.mnist_dataset(batch_size)
single_worker_model = mnist.build_and_compile_cnn_model()
single_worker_model.fit(single_worker_dataset, epochs=3, steps_per_epoch=70)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step
Epoch 1/3
70/70 [==============================] - 1s 14ms/step - loss: 2.2700 - accuracy: 0.2304
Epoch 2/3
70/70 [==============================] - 1s 14ms/step - loss: 2.1825 - accuracy: 0.4569
Epoch 3/3
70/70 [==============================] - 1s 13ms/step - loss: 2.0803 - accuracy: 0.5958

<tensorflow.python.keras.callbacks.History at 0x7fde81f99160>

การกำหนดค่าผู้ปฏิบัติงานหลายคน

ตอนนี้เข้าสู่โลกของการฝึกอบรมผู้ปฏิบัติงานหลายคน ใน TensorFlow ตัวแปรสภาพแวดล้อม TF_CONFIG จำเป็นสำหรับการฝึกอบรมบนเครื่องหลายเครื่องซึ่งแต่ละเครื่องอาจมีบทบาทที่แตกต่างกัน TF_CONFIG เป็นสตริง JSON ที่ใช้เพื่อระบุคอนฟิกูเรชันคลัสเตอร์บนผู้ปฏิบัติงานแต่ละคนที่เป็นส่วนหนึ่งของคลัสเตอร์

นี่คือตัวอย่างการกำหนดค่า:

tf_config = {
    'cluster': {
        'worker': ['localhost:12345', 'localhost:23456']
    },
    'task': {'type': 'worker', 'index': 0}
}

นี่คือ TF_CONFIG เดียวกันที่ทำให้เป็นอนุกรมเป็นสตริง JSON:

json.dumps(tf_config)
'{"cluster": {"worker": ["localhost:12345", "localhost:23456"]}, "task": {"type": "worker", "index": 0} }'

TF_CONFIG มีสององค์ประกอบ: cluster และ task

  • cluster เหมือนกันสำหรับคนงานทั้งหมดและให้ข้อมูลเกี่ยวกับคลัสเตอร์การฝึกอบรมซึ่งเป็นคำสั่งที่ประกอบด้วยงานประเภทต่างๆเช่น worker ในการฝึกอบรมผู้ปฏิบัติงานหลายคนด้วย MultiWorkerMirroredStrategy มักจะมี worker หนึ่งที่ต้องรับผิดชอบมากกว่าเล็กน้อยเช่นการบันทึกจุดตรวจสอบและการเขียนไฟล์สรุปสำหรับ TensorBoard นอกเหนือจากสิ่งที่ worker ประจำทำ คนงานดังกล่าวเรียกว่า chief คนงานและเป็นธรรมเนียมที่ worker มี index 0 จะได้รับการแต่งตั้งให้เป็นหัวหน้า worker (อันที่จริงนี่คือวิธีการใช้ tf.distribute.Strategy )

  • task ให้ข้อมูลของงานปัจจุบันและแตกต่างกันไปในแต่ละคน ระบุ type และ index ของผู้ปฏิบัติงานนั้น

ในตัวอย่างนี้คุณตั้งค่า type งานเป็น "worker" และ index งานเป็น 0 เครื่องจักรนี้เป็นคนงานเครื่องแรกและจะได้รับการแต่งตั้งให้เป็นหัวหน้าคนงานและทำงานมากกว่าเครื่องอื่น ๆ โปรดทราบว่าเครื่องอื่น ๆ จะต้องมีชุดตัวแปรสภาพแวดล้อม TF_CONFIG เช่นกันและควรมี cluster dict เหมือนกัน แต่ type งานหรือ index งานต่างกันขึ้นอยู่กับบทบาทของเครื่องเหล่านั้น

เพื่อเป็นภาพประกอบบทช่วยสอนนี้จะแสดงวิธีการตั้งค่า TF_CONFIG กับคนงาน 2 คนใน localhost ในทางปฏิบัติผู้ใช้จะสร้างผู้ปฏิบัติงานหลายคนบนที่อยู่ / พอร์ต IP ภายนอกและตั้งค่า TF_CONFIG กับผู้ปฏิบัติงานแต่ละคนอย่างเหมาะสม

ในตัวอย่างนี้คุณจะใช้คนงาน 2 คน TF_CONFIG ของคนงานคนแรกแสดงไว้ด้านบน สำหรับผู้ปฏิบัติงานคนที่สองคุณจะตั้งค่า tf_config['task']['index']=1

ด้านบน tf_config เป็นเพียงตัวแปรท้องถิ่นใน python ในการใช้งานจริงเพื่อกำหนดค่าการฝึกอบรมพจนานุกรมนี้จำเป็นต้องทำให้เป็นอนุกรมเป็น JSON และวางไว้ในตัวแปรสภาพแวดล้อม TF_CONFIG

ตัวแปรสภาพแวดล้อมและกระบวนการย่อยในสมุดบันทึก

กระบวนการย่อยสืบทอดตัวแปรสภาพแวดล้อมจากพาเรนต์ ดังนั้นหากคุณตั้งค่าตัวแปรสภาพแวดล้อมในกระบวนการ jupyter notebook นี้:

os.environ['GREETINGS'] = 'Hello TensorFlow!'

คุณสามารถเข้าถึงตัวแปรสภาพแวดล้อมจากกระบวนการย่อย:

echo ${GREETINGS}
Hello TensorFlow!

ในส่วนถัดไปคุณจะใช้สิ่งนี้เพื่อส่ง TF_CONFIG ไปยังกระบวนการย่อยของผู้ปฏิบัติงาน คุณจะไม่มีวันเปิดงานด้วยวิธีนี้จริง ๆ แต่ก็เพียงพอแล้วสำหรับวัตถุประสงค์ของบทช่วยสอนนี้: เพื่อแสดงให้เห็นถึงตัวอย่างที่มีผู้ปฏิบัติงานจำนวนน้อยที่สุด

เลือกกลยุทธ์ที่เหมาะสม

ใน TensorFlow มีการฝึกอบรมแบบกระจายสองรูปแบบหลัก:

  • การฝึกอบรมแบบซิงโครนัสซึ่งขั้นตอนของการฝึกอบรมจะถูกซิงค์ระหว่างคนงานและแบบจำลองและ
  • การฝึกอบรมแบบอะซิงโครนัสโดยที่ขั้นตอนการฝึกอบรมไม่ได้รับการซิงค์อย่างเคร่งครัด

MultiWorkerMirroredStrategy ซึ่งเป็นกลยุทธ์ที่แนะนำสำหรับการฝึกอบรมผู้ปฏิบัติงานหลายคนแบบซิงโครนัสจะแสดงให้เห็นในคู่มือนี้ ในการฝึกโมเดลให้ใช้อินสแตนซ์ของ tf.distribute.experimental.MultiWorkerMirroredStrategy

MultiWorkerMirroredStrategy สร้างสำเนาของตัวแปรทั้งหมดในเลเยอร์ของโมเดลบนอุปกรณ์แต่ละเครื่องสำหรับคนงานทั้งหมด มันใช้ CollectiveOps ซึ่งเป็น TensorFlow op สำหรับการสื่อสารแบบรวมเพื่อรวมการไล่ระดับสีและทำให้ตัวแปรซิงค์กัน คู่มือ tf.distribute.Strategy มีรายละเอียดเพิ่มเติมเกี่ยวกับกลยุทธ์นี้

strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
WARNING:tensorflow:Collective ops is not configured at program startup. Some performance features may not be enabled.
INFO:tensorflow:Using MirroredStrategy with devices ('/device:CPU:0',)
INFO:tensorflow:Single-worker MultiWorkerMirroredStrategy with local_devices = ('/device:CPU:0',), communication = CollectiveCommunication.AUTO

MultiWorkerMirroredStrategy จัดเตรียมการใช้งานหลายอย่างผ่านพารามิเตอร์ CollectiveCommunication RING สร้างกลุ่มที่ใช้วงแหวนโดยใช้ gRPC เป็นเลเยอร์การสื่อสารข้ามโฮสต์ NCCL ใช้ NCCL ของ Nvidia เพื่อใช้งาน Collectives AUTO เลื่อนตัวเลือกไปที่รันไทม์ ทางเลือกที่ดีที่สุดของการใช้งานร่วมกันขึ้นอยู่กับจำนวนและประเภทของ GPU และการเชื่อมต่อเครือข่ายในคลัสเตอร์

ฝึกโมเดล

ด้วยการรวม tf.distribute.Strategy API เข้ากับ tf.keras การเปลี่ยนแปลงเดียวที่คุณจะทำเพื่อแจกจ่ายการฝึกอบรมให้กับผู้ปฏิบัติงานหลายคนคือการรวมการสร้างโมเดลและ model.compile() ภายใน strategy.scope() ขอบเขตของกลยุทธ์การกระจายจะกำหนดวิธีการและสถานที่ที่สร้างตัวแปรและในกรณีของ MultiWorkerMirroredStrategy ตัวแปรที่สร้างขึ้นคือ MirroredVariable s และจะถูกจำลองแบบในแต่ละคนงาน

with strategy.scope():
  # Model building/compiling need to be within `strategy.scope()`.
  multi_worker_model = mnist.build_and_compile_cnn_model()

ในการทำงานกับ MultiWorkerMirroredStrategy คุณจะต้องเรียกใช้กระบวนการของผู้ปฏิบัติงานและส่ง TF_CONFIG ไปให้พวกเขา

เช่นเดียวกับไฟล์ mnist.py เขียนไว้ก่อนหน้านี้คือ main.py ที่แต่ละคนจะเรียกใช้:

%%writefile main.py

import os
import json

import tensorflow as tf
import mnist

per_worker_batch_size = 64
tf_config = json.loads(os.environ['TF_CONFIG'])
num_workers = len(tf_config['cluster']['worker'])

strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

global_batch_size = per_worker_batch_size * num_workers
multi_worker_dataset = mnist.mnist_dataset(global_batch_size)

with strategy.scope():
  # Model building/compiling need to be within `strategy.scope()`.
  multi_worker_model = mnist.build_and_compile_cnn_model()


multi_worker_model.fit(multi_worker_dataset, epochs=3, steps_per_epoch=70)
Writing main.py

ในข้อมูลโค้ดด้านบนโปรดทราบว่า global_batch_size ซึ่งส่งต่อไปยัง Dataset.batch ถูกตั้งค่าเป็น per_worker_batch_size * num_workers เพื่อให้แน่ใจว่าผู้ปฏิบัติงานแต่ละคนประมวลผลชุดของตัวอย่าง per_worker_batch_size โดยไม่คำนึงถึงจำนวนคนงาน

ไดเร็กทอรีปัจจุบันมีทั้งไฟล์ Python:

ls *.py
main.py
mnist.py

ดังนั้น json-serialize TF_CONFIG และเพิ่มลงในตัวแปรสภาพแวดล้อม:

os.environ['TF_CONFIG'] = json.dumps(tf_config)

ตอนนี้คุณสามารถเรียกใช้กระบวนการของผู้ปฏิบัติงานที่จะเรียกใช้ main.py และใช้ TF_CONFIG :

# first kill any previous runs
%killbgscripts
All background processes were killed.

python main.py &> job_0.log

มีบางสิ่งที่ควรทราบเกี่ยวกับคำสั่งดังกล่าว:

  1. มันใช้ %%bash ซึ่งเป็น "magic" ของโน้ตบุ๊ก เพื่อเรียกใช้คำสั่ง bash
  2. ใช้แฟ --bg เพื่อรันกระบวนการ bash ในเบื้องหลังเนื่องจากผู้ปฏิบัติงานนี้จะไม่ยุติ รอคนงานทั้งหมดก่อนที่จะเริ่ม

กระบวนการของผู้ปฏิบัติงานที่อยู่เบื้องหลังจะไม่พิมพ์ผลลัพธ์ไปยังสมุดบันทึกนี้ดังนั้น &> เปลี่ยนเส้นทางเอาต์พุตไปยังไฟล์เพื่อให้คุณสามารถดูว่าเกิดอะไรขึ้น

ดังนั้นรอสักครู่เพื่อให้กระบวนการเริ่มต้นขึ้น:

import time
time.sleep(10)

ตอนนี้ดูสิ่งที่ส่งออกไปยังไฟล์บันทึกของพนักงานจนถึงตอนนี้:

cat job_0.log
2020-09-12 01:27:14.389683: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.1
2020-09-12 01:27:15.736635: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcuda.so.1
2020-09-12 01:27:16.990546: E tensorflow/stream_executor/cuda/cuda_driver.cc:314] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
2020-09-12 01:27:16.990628: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:169] retrieving CUDA diagnostic information for host: kokoro-gcp-ubuntu-prod-1064447366
2020-09-12 01:27:16.990639: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:176] hostname: kokoro-gcp-ubuntu-prod-1064447366
2020-09-12 01:27:16.990757: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:200] libcuda reported version is: 450.51.5
2020-09-12 01:27:16.990801: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:204] kernel reported version is: 450.51.5
2020-09-12 01:27:16.990811: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:310] kernel version seems to match DSO: 450.51.5
2020-09-12 01:27:16.991256: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN)to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2020-09-12 01:27:17.002048: I tensorflow/core/platform/profile_utils/cpu_utils.cc:104] CPU Frequency: 2000179999 Hz
2020-09-12 01:27:17.002560: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x44b82a0 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2020-09-12 01:27:17.002598: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host, Default Version
2020-09-12 01:27:17.009677: I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:301] Initialize GrpcChannelCache for job worker -> {0 -> localhost:12345, 1 -> localhost:23456}
2020-09-12 01:27:17.010257: I tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc:405] Started server with target: grpc://localhost:12345

บรรทัดสุดท้ายของไฟล์บันทึกควรระบุว่า: Started server with target: grpc://localhost:12345 คนงานคนแรกพร้อมแล้วและกำลังรอให้คนงานคนอื่นพร้อมที่จะดำเนินการต่อ

ดังนั้นอัปเดต tf_config สำหรับกระบวนการของพนักงานคนที่สองที่จะรับ:

85f4464 ดี

ตอนนี้เปิดตัวคนงานคนที่สอง การดำเนินการนี้จะเริ่มการฝึกอบรมเนื่องจากพนักงานทุกคนทำงานอยู่ (ดังนั้นจึงไม่จำเป็นต้องมีเบื้องหลังกระบวนการนี้):

python main.py
Epoch 1/3
70/70 [==============================] - 4s 55ms/step - loss: 2.2622 - accuracy: 0.1663
Epoch 2/3
70/70 [==============================] - 4s 53ms/step - loss: 2.1959 - accuracy: 0.2958
Epoch 3/3
70/70 [==============================] - 4s 55ms/step - loss: 2.1158 - accuracy: 0.4607

2020-09-12 01:27:24.523408: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.1
2020-09-12 01:27:25.851591: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcuda.so.1
2020-09-12 01:27:26.994525: E tensorflow/stream_executor/cuda/cuda_driver.cc:314] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
2020-09-12 01:27:26.994608: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:169] retrieving CUDA diagnostic information for host: kokoro-gcp-ubuntu-prod-1064447366
2020-09-12 01:27:26.994619: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:176] hostname: kokoro-gcp-ubuntu-prod-1064447366
2020-09-12 01:27:26.994733: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:200] libcuda reported version is: 450.51.5
2020-09-12 01:27:26.994779: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:204] kernel reported version is: 450.51.5
2020-09-12 01:27:26.994793: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:310] kernel version seems to match DSO: 450.51.5
2020-09-12 01:27:26.995232: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN)to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2020-09-12 01:27:27.003492: I tensorflow/core/platform/profile_utils/cpu_utils.cc:104] CPU Frequency: 2000179999 Hz
2020-09-12 01:27:27.003991: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x5b7e150 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2020-09-12 01:27:27.004027: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host, Default Version
2020-09-12 01:27:27.010851: I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:301] Initialize GrpcChannelCache for job worker -> {0 -> localhost:12345, 1 -> localhost:23456}
2020-09-12 01:27:27.011365: I tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc:405] Started server with target: grpc://localhost:23456
WARNING:tensorflow:`eval_fn` is not passed in. The `worker_fn` will be used if an "evaluator" task exists in the cluster.
WARNING:tensorflow:`eval_strategy` is not passed in. No distribution strategy will be used for evaluation.
2020-09-12 01:27:27.936589: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:521] In AUTO-mode, and switching to DATA-based sharding, instead of FILE-based sharding as we cannot find appropriate reader dataset op(s) to shard. Error: Found an unshardable source dataset: name: "TensorSliceDataset/_2"
op: "TensorSliceDataset"
input: "Placeholder/_0"
input: "Placeholder/_1"
attr {
  key: "Toutput_types"
  value {
    list {
      type: DT_FLOAT
      type: DT_INT64
    }
  }
}
attr {
  key: "output_shapes"
  value {
    list {
      shape {
        dim {
          size: 28
        }
        dim {
          size: 28
        }
      }
      shape {
      }
    }
  }
}

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:601: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Iterator.get_next_as_optional()` instead.

ตอนนี้ถ้าคุณตรวจสอบบันทึกที่เขียนโดยคนงานคนแรกอีกครั้งคุณจะเห็นว่ามันเข้าร่วมในการฝึกอบรมแบบจำลองนั้น:

cat job_0.log
2020-09-12 01:27:14.389683: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.1
2020-09-12 01:27:15.736635: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcuda.so.1
2020-09-12 01:27:16.990546: E tensorflow/stream_executor/cuda/cuda_driver.cc:314] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
2020-09-12 01:27:16.990628: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:169] retrieving CUDA diagnostic information for host: kokoro-gcp-ubuntu-prod-1064447366
2020-09-12 01:27:16.990639: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:176] hostname: kokoro-gcp-ubuntu-prod-1064447366
2020-09-12 01:27:16.990757: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:200] libcuda reported version is: 450.51.5
2020-09-12 01:27:16.990801: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:204] kernel reported version is: 450.51.5
2020-09-12 01:27:16.990811: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:310] kernel version seems to match DSO: 450.51.5
2020-09-12 01:27:16.991256: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN)to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2020-09-12 01:27:17.002048: I tensorflow/core/platform/profile_utils/cpu_utils.cc:104] CPU Frequency: 2000179999 Hz
2020-09-12 01:27:17.002560: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x44b82a0 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2020-09-12 01:27:17.002598: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host, Default Version
2020-09-12 01:27:17.009677: I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:301] Initialize GrpcChannelCache for job worker -> {0 -> localhost:12345, 1 -> localhost:23456}
2020-09-12 01:27:17.010257: I tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc:405] Started server with target: grpc://localhost:12345
WARNING:tensorflow:`eval_fn` is not passed in. The `worker_fn` will be used if an "evaluator" task exists in the cluster.
WARNING:tensorflow:`eval_strategy` is not passed in. No distribution strategy will be used for evaluation.
2020-09-12 01:27:27.934554: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:521] In AUTO-mode, and switching to DATA-based sharding, instead of FILE-based sharding as we cannot find appropriate reader dataset op(s) to shard. Error: Found an unshardable source dataset: name: "TensorSliceDataset/_2"
op: "TensorSliceDataset"
input: "Placeholder/_0"
input: "Placeholder/_1"
attr {
  key: "Toutput_types"
  value {
    list {
      type: DT_FLOAT
      type: DT_INT64
    }
  }
}
attr {
  key: "output_shapes"
  value {
    list {
      shape {
        dim {
          size: 28
        }
        dim {
          size: 28
        }
      }
      shape {
      }
    }
  }
}

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:601: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Iterator.get_next_as_optional()` instead.
Epoch 1/3
70/70 [==============================] - 4s 55ms/step - loss: 2.2622 - accuracy: 0.1663
Epoch 2/3
70/70 [==============================] - 4s 53ms/step - loss: 2.1959 - accuracy: 0.2958
Epoch 3/3
70/70 [==============================] - 4s 55ms/step - loss: 2.1158 - accuracy: 0.4607

ไม่น่าแปลกใจที่สิ่งนี้ ทำงานช้า กว่าการทดสอบในตอนต้นของบทช่วยสอนนี้ การเรียกใช้คนงานหลายคนในเครื่องเดียวจะเพิ่มค่าโสหุ้ยเท่านั้น เป้าหมายในที่นี้ไม่ได้อยู่ที่การปรับปรุงเวลาการฝึกอบรม แต่เป็นเพียงตัวอย่างของการฝึกอบรมผู้ปฏิบัติงานหลายคน

# Delete the `TF_CONFIG`, and kill any background tasks so they don't affect the next section.
os.environ.pop('TF_CONFIG', None)
%killbgscripts
All background processes were killed.

การฝึกอบรมพนักงานหลายคนในเชิงลึก

จนถึงตอนนี้บทแนะนำนี้ได้แสดงให้เห็นถึงการตั้งค่าผู้ปฏิบัติงานหลายคนขั้นพื้นฐาน ส่วนที่เหลือของเอกสารนี้จะดูรายละเอียดปัจจัยอื่น ๆ ที่อาจเป็นประโยชน์หรือสำคัญสำหรับกรณีการใช้งานจริง

การแยกชุดข้อมูล

ในการฝึกอบรมผู้ปฏิบัติงานหลายคนจำเป็นต้องมีการแบ่งชุดข้อมูลเพื่อให้แน่ใจว่ามีการบรรจบกันและประสิทธิภาพ

ตัวอย่างในส่วนก่อนหน้านี้อาศัยการชาร์ตอัตโนมัติเริ่มต้นที่จัดเตรียมโดย tf.distribute.Strategy API คุณสามารถควบคุมการชาร์ดได้โดยตั้งค่า tf.data.experimental.AutoShardPolicy ของ tf.data.experimental.DistributeOptions หากต้องการเรียนรู้เพิ่มเติมเกี่ยวกับการชาร์ดอัตโนมัติโปรดดู คู่มืออินพุตแบบกระจาย

นี่คือตัวอย่างสั้น ๆ ของวิธีปิดการชาร์ดอัตโนมัติดังนั้นการจำลองแต่ละตัวจึงประมวลผลทุกตัวอย่าง (ไม่แนะนำ):

options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF

global_batch_size = 64
multi_worker_dataset = mnist.mnist_dataset(batch_size=64)
dataset_no_auto_shard = multi_worker_dataset.with_options(options)

การประเมินผล

หากคุณส่ง validation_data ไปยัง model.fit มันจะสลับกันระหว่างการฝึกอบรมและการประเมินผลสำหรับแต่ละยุค การประเมินการใช้ validation_data จะกระจายไปตามกลุ่มคนงานเดียวกันและผลการประเมินจะถูกรวบรวมและพร้อมใช้งานสำหรับคนงานทุกคน เช่นเดียวกับการฝึกอบรมชุดข้อมูลการตรวจสอบความถูกต้องจะถูกแบ่งโดยอัตโนมัติที่ระดับไฟล์ คุณต้องกำหนดขนาดแบทช์ส่วนกลางในชุดข้อมูลการตรวจสอบความถูกต้องและตั้งค่า validation_steps ขอแนะนำให้ใช้ชุดข้อมูลซ้ำสำหรับการประเมิน

หรือคุณสามารถสร้างงานอื่นที่อ่านจุดตรวจเป็นระยะและเรียกใช้การประเมินผล นี่คือสิ่งที่ Estimator ทำ แต่นี่ไม่ใช่วิธีที่แนะนำในการประเมินผลดังนั้นจึงละเว้นรายละเอียด

คาดการณ์

model.predict ปัจจุบันใช้ไม่ได้กับ MultiWorkerMirroredStrategy.

ประสิทธิภาพ

ตอนนี้คุณมีโมเดล Keras ที่ตั้งค่าให้ทำงานในหลายคนด้วย MultiWorkerMirroredStrategy คุณสามารถลองใช้เทคนิคต่อไปนี้เพื่อปรับแต่งประสิทธิภาพของการฝึกอบรมผู้ปฏิบัติงานหลายคนด้วย MultiWorkerMirroredStrategy

  • MultiWorkerMirroredStrategy เสนอ การใช้งานการสื่อสาร โดยรวมที่หลากหลาย RING สร้างกลุ่มที่ใช้วงแหวนโดยใช้ gRPC เป็นเลเยอร์การสื่อสารข้ามโฮสต์ NCCL ใช้ NCCL ของ Nvidia เพื่อใช้งาน Collectives AUTO เลื่อนตัวเลือกไปที่รันไทม์ ทางเลือกที่ดีที่สุดของการใช้งานร่วมกันขึ้นอยู่กับจำนวนและประเภทของ GPU และการเชื่อมต่อเครือข่ายในคลัสเตอร์ ในการแทนที่ตัวเลือกอัตโนมัติให้ระบุค่าที่ถูกต้องให้กับพารามิเตอร์ communication ของตัว MultiWorkerMirroredStrategy เช่น communication=tf.distribute.experimental.CollectiveCommunication.NCCL
  • ส่งตัวแปรเป็น tf.float ถ้าเป็นไปได้ โมเดล ResNet อย่างเป็นทางการมี ตัวอย่าง วิธีการนี้

ความทนทานต่อความผิดพลาด

ในการฝึกอบรมแบบซิงโครนัสคลัสเตอร์จะล้มเหลวหากผู้ปฏิบัติงานคนใดคนหนึ่งล้มเหลวและไม่มีกลไกการกู้คืนความล้มเหลว การใช้ Keras กับ tf.distribute.Strategy มาพร้อมกับข้อดีของการยอมรับข้อผิดพลาดในกรณีที่คนงานเสียชีวิตหรือไม่มั่นคง คุณทำได้โดยรักษาสถานะการฝึกอบรมไว้ในระบบไฟล์แบบกระจายที่คุณเลือกเช่นเมื่อรีสตาร์ทอินสแตนซ์ที่ล้มเหลวก่อนหน้านี้หรือถูกจองไว้ก่อนหน้านี้สถานะการฝึกอบรมจะถูกกู้คืน

เนื่องจากคนงานทั้งหมดจะซิงค์กันในแง่ของขั้นตอนการฝึกอบรมและขั้นตอนต่างๆคนงานคนอื่น ๆ จึงต้องรอให้คนงานที่ล้มเหลวหรือถูกจองไว้ก่อนจึงจะรีสตาร์ทเพื่อดำเนินการต่อ

ModelCheckpoint โทรกลับ

ModelCheckpoint เรียกกลับ ModelCheckpoint ไม่ให้ฟังก์ชันการยอมรับข้อผิดพลาดอีกต่อไปโปรดใช้การโทรกลับ BackupAndRestore แทน

การเรียกกลับ ModelCheckpoint ยังสามารถใช้เพื่อบันทึกจุดตรวจได้ แต่ด้วยเหตุนี้หากการฝึกถูกขัดจังหวะหรือเสร็จสิ้นในการฝึกต่อจากจุดตรวจผู้ใช้ต้องรับผิดชอบในการโหลดโมเดลด้วยตนเอง

ผู้ใช้สามารถเลือกที่จะบันทึกและกู้คืนโมเดล / น้ำหนักนอกการเรียกกลับ ModelCheckpoint

การประหยัดและการโหลดโมเดล

ในการบันทึกโมเดลของคุณโดยใช้ model.save หรือ tf.saved_model.save ปลายทางสำหรับการบันทึกจะต้องแตกต่างกันสำหรับผู้ปฏิบัติงานแต่ละคน สำหรับผู้ปฏิบัติงานที่ไม่ใช่หัวหน้าคุณจะต้องบันทึกโมเดลลงในไดเร็กทอรีชั่วคราวและสำหรับหัวหน้าคุณจะต้องบันทึกลงในไดเร็กทอรีโมเดลที่ให้มา ไดเร็กทอรีชั่วคราวเกี่ยวกับผู้ปฏิบัติงานต้องไม่ซ้ำกันเพื่อป้องกันข้อผิดพลาดที่เกิดจากคนงานหลายคนพยายามเขียนไปยังตำแหน่งเดียวกัน โมเดลที่บันทึกไว้ในไดเร็กทอรีทั้งหมดจะเหมือนกันและโดยทั่วไปแล้วจะต้องอ้างอิงเฉพาะโมเดลที่หัวหน้าบันทึกไว้เพื่อเรียกคืนหรือให้บริการ คุณควรมีตรรกะในการล้างข้อมูลที่ลบไดเรกทอรีชั่วคราวที่คนงานสร้างขึ้นเมื่อการฝึกอบรมของคุณเสร็จสิ้น

เหตุผลที่คุณต้องประหยัดหัวหน้าและคนงานในเวลาเดียวกันเพราะคุณอาจรวมตัวแปรระหว่างการตรวจสอบซึ่งต้องให้ทั้งหัวหน้าและคนงานเข้าร่วมในโปรโตคอลการสื่อสารแบบ allreduce ในทางกลับกันการให้หัวหน้าและคนงานบันทึกลงในไดเร็กทอรีโมเดลเดียวกันจะทำให้เกิดข้อผิดพลาดเนื่องจากการทะเลาะกัน

ด้วย MultiWorkerMirroredStrategy โปรแกรมจะทำงานกับผู้ปฏิบัติงานทุกคนและเพื่อให้ทราบว่าผู้ปฏิบัติงานปัจจุบันเป็นหัวหน้าหรือไม่จะใช้ประโยชน์จากวัตถุตัวแก้ไขคลัสเตอร์ที่มีแอตทริบิวต์ task_type และ task_id task_type จะบอกคุณว่างานปัจจุบันคืออะไร (เช่น 'คนงาน') และ task_id จะบอกคุณถึงตัวระบุของคนงาน ผู้ปฏิบัติงานที่มี id 0 ถูกกำหนดให้เป็นหัวหน้าคนงาน

ในข้อมูลโค้ดด้านล่าง write_filepath ระบุเส้นทางไฟล์ที่จะเขียนซึ่งขึ้นอยู่กับ ID ผู้ปฏิบัติงาน ในกรณีของหัวหน้า (ผู้ปฏิบัติงานที่มี id 0) จะเขียนไปยังเส้นทางไฟล์ต้นฉบับ สำหรับผู้อื่นจะสร้างไดเร็กทอรีชั่วคราว (พร้อม id ในเส้นทางไดเร็กทอรี) เพื่อเขียน:

model_path = '/tmp/keras-model'

def _is_chief(task_type, task_id):
  # If `task_type` is None, this may be operating as single worker, which works
  # effectively as chief.
  return task_type is None or task_type == 'chief' or (
            task_type == 'worker' and task_id == 0)

def _get_temp_dir(dirpath, task_id):
  base_dirpath = 'workertemp_' + str(task_id)
  temp_dir = os.path.join(dirpath, base_dirpath)
  tf.io.gfile.makedirs(temp_dir)
  return temp_dir

def write_filepath(filepath, task_type, task_id):
  dirpath = os.path.dirname(filepath)
  base = os.path.basename(filepath)
  if not _is_chief(task_type, task_id):
    dirpath = _get_temp_dir(dirpath, task_id)
  return os.path.join(dirpath, base)

task_type, task_id = (strategy.cluster_resolver.task_type,
                      strategy.cluster_resolver.task_id)
write_model_path = write_filepath(model_path, task_type, task_id)

ด้วยเหตุนี้คุณก็พร้อมที่จะบันทึก:

multi_worker_model.save(write_model_path)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

INFO:tensorflow:Assets written to: /tmp/keras-model/assets

INFO:tensorflow:Assets written to: /tmp/keras-model/assets

ตามที่อธิบายไว้ข้างต้นควรโหลดโมเดลในภายหลังจากหัวหน้าพา ธ ที่บันทึกไว้เท่านั้นดังนั้นเรามาลบสิ่งชั่วคราวที่ผู้ปฏิบัติงานที่ไม่ใช่หัวหน้าบันทึกไว้:

if not _is_chief(task_type, task_id):
  tf.io.gfile.rmtree(os.path.dirname(write_model_path))

ตอนนี้เมื่อถึงเวลาโหลดเรามาใช้ tf.keras.models.load_model API ที่สะดวกและทำงานต่อไป ที่นี่สมมติว่าใช้คนงานคนเดียวในการโหลดและฝึกอบรมต่อไปซึ่งในกรณีนี้คุณจะไม่เรียก tf.keras.models.load_model ภายใน strategy.scope() อื่น strategy.scope()

loaded_model = tf.keras.models.load_model(model_path)

# Now that the model is restored, and can continue with the training.
loaded_model.fit(single_worker_dataset, epochs=2, steps_per_epoch=20)
Epoch 1/2
20/20 [==============================] - 0s 13ms/step - loss: 2.2929 - accuracy: 0.0000e+00
Epoch 2/2
20/20 [==============================] - 0s 13ms/step - loss: 2.2736 - accuracy: 0.0016

<tensorflow.python.keras.callbacks.History at 0x7fded172d1d0>

การประหยัดและการกู้คืนจุดตรวจ

ในทางกลับกันการตรวจจุดตรวจช่วยให้คุณประหยัดน้ำหนักของโมเดลและเรียกคืนได้โดยไม่ต้องบันทึกทั้งโมเดล ที่นี่คุณจะสร้าง tf.train.Checkpoint หนึ่ง tf.train.Checkpoint ที่ติดตามโมเดลซึ่งจัดการโดย tf.train.CheckpointManager เพื่อให้เฉพาะจุดตรวจสอบล่าสุดเท่านั้นที่จะถูกเก็บรักษาไว้

checkpoint_dir = '/tmp/ckpt'

checkpoint = tf.train.Checkpoint(model=multi_worker_model)
write_checkpoint_dir = write_filepath(checkpoint_dir, task_type, task_id)
checkpoint_manager = tf.train.CheckpointManager(
    checkpoint, directory=write_checkpoint_dir, max_to_keep=1)

เมื่อตั้งค่า CheckpointManager แล้วคุณก็พร้อมที่จะบันทึกและลบจุดตรวจที่ไม่ใช่หัวหน้าคนงานที่บันทึกไว้

checkpoint_manager.save()
if not _is_chief(task_type, task_id):
  tf.io.gfile.rmtree(write_checkpoint_dir)

ตอนนี้เมื่อคุณต้องการกู้คืนคุณสามารถค้นหาจุดตรวจล่าสุดที่บันทึกไว้โดยใช้ฟังก์ชัน tf.train.latest_checkpoint สะดวก หลังจากคืนค่าด่านแล้วคุณสามารถฝึกต่อได้

latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
checkpoint.restore(latest_checkpoint)
multi_worker_model.fit(multi_worker_dataset, epochs=2, steps_per_epoch=20)
Epoch 1/2
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:601: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Iterator.get_next_as_optional()` instead.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:601: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Iterator.get_next_as_optional()` instead.

20/20 [==============================] - 0s 13ms/step - loss: 2.2900 - accuracy: 0.1656
Epoch 2/2
20/20 [==============================] - 0s 13ms/step - loss: 2.2727 - accuracy: 0.1656

<tensorflow.python.keras.callbacks.History at 0x7fded1677dd8>

BackupAndRestore โทรกลับ

การเรียกกลับ BackupAndRestore มีฟังก์ชันการยอมรับข้อผิดพลาดโดยการสำรองข้อมูลโมเดลและหมายเลขยุคปัจจุบันในไฟล์จุดตรวจสอบชั่วคราวภายใต้อาร์กิวเมนต์ backup_dir ไปยัง BackupAndRestore สิ่งนี้จะเสร็จสิ้นในตอนท้ายของแต่ละยุค

เมื่องานหยุดชะงักและเริ่มต้นใหม่การเรียกกลับจะเรียกคืนจุดตรวจสุดท้ายและการฝึกอบรมจะดำเนินต่อไปตั้งแต่จุดเริ่มต้นของยุคที่ถูกขัดจังหวะ การฝึกอบรมบางส่วนที่ทำไปแล้วในยุคที่ยังไม่เสร็จสิ้นก่อนที่จะหยุดชะงักจะถูกโยนทิ้งไปเพื่อไม่ให้ส่งผลต่อสถานะของโมเดลขั้นสุดท้าย

ในการใช้งานให้ระบุอินสแตนซ์ของ tf.keras.callbacks.experimental.BackupAndRestore ที่การ tf.keras.Model.fit()

ด้วย MultiWorkerMirroredStrategy หากผู้ปฏิบัติงานถูกขัดจังหวะคลัสเตอร์ทั้งหมดจะหยุดชั่วคราวจนกว่าผู้ปฏิบัติงานที่ถูกขัดจังหวะจะถูกรีสตาร์ท ผู้ปฏิบัติงานคนอื่นจะเริ่มการทำงานใหม่เช่นกันและผู้ปฏิบัติงานที่ถูกขัดจังหวะจะเข้าร่วมคลัสเตอร์อีกครั้ง จากนั้นผู้ปฏิบัติงานทุกคนจะอ่านไฟล์จุดตรวจที่บันทึกไว้ก่อนหน้านี้และรับสถานะเดิมซึ่งจะช่วยให้คลัสเตอร์กลับมาซิงค์ได้ จากนั้นการฝึกอบรมยังคงดำเนินต่อไป

BackupAndRestore เรียกกลับ BackupAndRestore ใช้ CheckpointManager เพื่อบันทึกและกู้คืนสถานะการฝึกอบรมซึ่งสร้างไฟล์ที่เรียกว่าจุดตรวจที่ติดตามจุดตรวจที่มีอยู่ร่วมกับด่านล่าสุด ด้วยเหตุนี้จึงไม่ควรใช้ backup_dir ซ้ำเพื่อเก็บจุดตรวจอื่น ๆ เพื่อหลีกเลี่ยงการชนชื่อ

ขณะนี้การเรียกกลับ BackupAndRestore สนับสนุนผู้ปฏิบัติงานคนเดียวโดยไม่มีกลยุทธ์ MirroredStrategy และผู้ปฏิบัติงานหลายคนด้วย MultiWorkerMirroredStrategy ด้านล่างนี้เป็นสองตัวอย่างสำหรับทั้งการฝึกอบรมผู้ปฏิบัติงานหลายคนและการฝึกอบรมผู้ปฏิบัติงานคนเดียว

# Multi-worker training with MultiWorkerMirroredStrategy.

callbacks = [tf.keras.callbacks.experimental.BackupAndRestore(backup_dir='/tmp/backup')]
with strategy.scope():
  multi_worker_model = mnist.build_and_compile_cnn_model()
multi_worker_model.fit(multi_worker_dataset,
                       epochs=3,
                       steps_per_epoch=70,
                       callbacks=callbacks)
Epoch 1/3
70/70 [==============================] - 1s 14ms/step - loss: 2.2724 - accuracy: 0.2118
Epoch 2/3
70/70 [==============================] - 1s 14ms/step - loss: 2.2001 - accuracy: 0.4250
Epoch 3/3
70/70 [==============================] - 1s 14ms/step - loss: 2.1129 - accuracy: 0.5683

<tensorflow.python.keras.callbacks.History at 0x7fded0cdea20>

หากคุณตรวจสอบไดเร็กทอรีของ backup_dir คุณระบุใน BackupAndRestore คุณอาจสังเกตเห็นไฟล์จุดตรวจที่สร้างขึ้นชั่วคราว ไฟล์เหล่านั้นจำเป็นสำหรับการกู้คืนอินสแตนซ์ที่หายไปก่อนหน้านี้และไลบรารีจะถูกลบออกเมื่อสิ้นสุด tf.keras.Model.fit() เมื่อออกจากการฝึกอบรมของคุณสำเร็จ

ดูสิ่งนี้ด้วย

  1. การฝึกอบรมแบบกระจายใน คู่มือ TensorFlow จะให้ภาพรวมของกลยุทธ์การกระจายที่มีอยู่
  2. แบบจำลองอย่างเป็นทางการ ซึ่งหลายแบบสามารถกำหนดค่าให้เรียกใช้กลยุทธ์การกระจายได้หลายแบบ
  3. ส่วนประสิทธิภาพ ในคำแนะนำจะให้ข้อมูลเกี่ยวกับกลยุทธ์และ เครื่องมือ อื่น ๆ ที่คุณสามารถใช้เพื่อเพิ่มประสิทธิภาพของโมเดล TensorFlow ของคุณ