Trang này được dịch bởi Cloud Translation API.
Switch to English

Đào tạo nhiều công nhân với Keras

Xem trên TensorFlow.org Chạy trong Google Colab Xem nguồn trên GitHub Tải vở

Tổng quat

Hướng dẫn này thể hiện đào tạo phân tán nhiều công nhân với mô hình tf.distribute.Strategy bằng cách sử dụng API tf.distribute.Strategy , cụ thể là tf.distribute.experimental.MultiWorkerMirroredStrategy . Với sự trợ giúp của chiến lược này, một mô hình Keras được thiết kế để chạy trên một công nhân đơn lẻ có thể hoạt động liền mạch trên nhiều công nhân với sự thay đổi mã tối thiểu.

Đào tạo phân tán trong hướng dẫn TensorFlow có sẵn để biết tổng quan về các chiến lược phân phối mà TensorFlow hỗ trợ cho những ai quan tâm đến sự hiểu biết sâu sắc hơn về API tf.distribute.Strategy .

Thiết lập

Đầu tiên, thiết lập TensorFlow và nhập khẩu cần thiết.

 import os
import tensorflow as tf
import numpy as np
 

Chuẩn bị dữ liệu

Bây giờ, hãy chuẩn bị bộ dữ liệu MNIST. Bộ dữ liệu MNIST bao gồm 60.000 ví dụ đào tạo và 10.000 ví dụ thử nghiệm về các chữ số viết tay 0 mộc9, được định dạng là hình ảnh đơn sắc 28x28 pixel. Trong ví dụ này, chúng tôi sẽ lấy phần đào tạo của các bộ dữ liệu để chứng minh.

 def mnist_dataset(batch_size):
  (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
  # The `x` arrays are in uint8 and have values in the range [0, 255].
  # We need to convert them to float32 with values in the range [0, 1]
  x_train = x_train / np.float32(255)
  y_train = y_train.astype(np.int64)
  train_dataset = tf.data.Dataset.from_tensor_slices(
      (x_train, y_train)).shuffle(60000).repeat().batch(batch_size)
  return train_dataset
 

Xây dựng mô hình Keras

Ở đây, chúng tôi sử dụng API tf.keras.Sequential để xây dựng và biên dịch một mạng nơ ron tích chập đơn giản Mô hình Keras để đào tạo với bộ dữ liệu MNIST của chúng tôi.

 def build_and_compile_cnn_model():
  model = tf.keras.Sequential([
      tf.keras.Input(shape=(28, 28)),
      tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
      tf.keras.layers.Conv2D(32, 3, activation='relu'),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(128, activation='relu'),
      tf.keras.layers.Dense(10)
  ])
  model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
      metrics=['accuracy'])
  return model
 

Trước tiên, hãy thử đào tạo mô hình cho một số lượng nhỏ kỷ nguyên và quan sát kết quả trong một công nhân duy nhất để đảm bảo mọi thứ hoạt động chính xác. Bạn sẽ thấy sự sụt giảm và độ chính xác đạt tới 1.0 khi kỷ nguyên tiến bộ.

 per_worker_batch_size = 64
single_worker_dataset = mnist_dataset(per_worker_batch_size)
single_worker_model = build_and_compile_cnn_model()
single_worker_model.fit(single_worker_dataset, epochs=3, steps_per_epoch=70)
 
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step
Epoch 1/3
70/70 [==============================] - 0s 2ms/step - loss: 2.2701 - accuracy: 0.2451
Epoch 2/3
70/70 [==============================] - 0s 2ms/step - loss: 2.1827 - accuracy: 0.4777
Epoch 3/3
70/70 [==============================] - 0s 2ms/step - loss: 2.0865 - accuracy: 0.5955

<tensorflow.python.keras.callbacks.History at 0x7fc59381ac50>

Cấu hình nhiều công nhân

Bây giờ hãy bước vào thế giới đào tạo nhiều công nhân. Trong TensorFlow, biến môi trường TF_CONFIG được yêu cầu để đào tạo trên nhiều máy, mỗi máy có thể có một vai trò khác nhau. TF_CONFIG là một chuỗi JSON được sử dụng để chỉ định cấu hình cụm trên mỗi công nhân là một phần của cụm.

Có hai thành phần của TF_CONFIG : clustertask . cluster cung cấp thông tin về cụm đào tạo, đó là một lệnh bao gồm các loại công việc khác nhau như worker . Trong đào tạo nhiều công nhân với MultiWorkerMirroredStrategy , thường có một worker chịu trách nhiệm hơn một chút như lưu điểm kiểm tra và viết tệp tóm tắt cho TensorBoard ngoài những gì một worker bình thường làm. Công nhân đó được gọi là công nhân chief , và theo thông lệ, workerindex 0 được bổ nhiệm làm worker (thực tế đây là cách tf.distribute.Strategy được thực hiện). task cung cấp thông tin của nhiệm vụ hiện tại. cluster thành phần đầu tiên giống nhau cho tất cả các công nhân và task thành phần thứ hai là khác nhau đối với mỗi công nhân và chỉ định typeindex của công nhân đó.

Trong ví dụ này, chúng tôi đặt type nhiệm vụ thành "worker"index nhiệm vụ thành 0 . Điều này có nghĩa là cỗ máy có cài đặt như vậy là công nhân đầu tiên, sẽ được bổ nhiệm làm công nhân trưởng và làm nhiều công việc hơn các công nhân khác. Lưu ý rằng các máy khác cũng cần phải có biến môi trường TF_CONFIG , và nó phải có cùng một cluster , nhưng type nhiệm vụ hoặc index tác vụ khác nhau tùy thuộc vào vai trò của các máy đó.

Đối với mục đích minh họa, hướng dẫn này cho thấy cách người ta có thể đặt TF_CONFIG với 2 công nhân trên localhost . Trong thực tế, người dùng sẽ tạo nhiều công nhân trên các địa chỉ / cổng IP bên ngoài và đặt TF_CONFIG cho mỗi công nhân một cách thích hợp.

 os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'worker': ["localhost:12345", "localhost:23456"]
    },
    'task': {'type': 'worker', 'index': 0}
})
 

Lưu ý rằng mặc dù tỷ lệ học tập được cố định trong ví dụ này, nhưng nhìn chung có thể cần phải điều chỉnh tỷ lệ học tập dựa trên quy mô lô toàn cầu.

Chọn chiến lược đúng

Trong TensorFlow, đào tạo phân tán bao gồm đào tạo đồng bộ, trong đó các bước đào tạo được đồng bộ hóa giữa các công nhân và bản sao, và đào tạo không đồng bộ, trong đó các bước đào tạo không được đồng bộ hóa nghiêm ngặt.

MultiWorkerMirroredStrategy , là chiến lược được đề xuất cho đào tạo nhiều công nhân đồng bộ, sẽ được trình bày trong hướng dẫn này. Để huấn luyện mô hình, sử dụng một thể hiện của tf.distribute.experimental.MultiWorkerMirroredStrategy . MultiWorkerMirroredStrategy tạo các bản sao của tất cả các biến trong các lớp của mô hình trên mỗi thiết bị trên tất cả các công nhân. Nó sử dụng CollectiveOps , một TensorFlow op để liên lạc tập thể, để tổng hợp các gradient và giữ các biến đồng bộ. Hướng dẫn tf.distribute.Strategy có nhiều chi tiết hơn về chiến lược này.

 strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
 
WARNING:tensorflow:Collective ops is not configured at program startup. Some performance features may not be enabled.
INFO:tensorflow:Using MirroredStrategy with devices ('/device:GPU:0',)
INFO:tensorflow:Single-worker MultiWorkerMirroredStrategy with local_devices = ('/device:GPU:0',), communication = CollectiveCommunication.AUTO

MultiWorkerMirroredStrategy cung cấp nhiều triển khai thông qua tham số CollectiveCommunication thông. RING thực hiện các tập thể dựa trên vòng sử dụng gRPC làm lớp giao tiếp giữa các máy chủ. NCCL sử dụng NCCL của Nvidia để triển khai tập thể. AUTO trì hoãn sự lựa chọn cho thời gian chạy. Sự lựa chọn tốt nhất để thực hiện tập thể phụ thuộc vào số lượng và loại GPU và kết nối mạng trong cụm.

Huấn luyện mô hình với MultiWorkerMirroredStrargety

Với việc tích hợp API tf.distribute.Strategy vào tf.keras , thay đổi duy nhất bạn sẽ thực hiện để phân phối đào tạo cho nhiều công nhân là bao quanh việc xây dựng mô hình và gọi model.compile() bên trong model.compile() strategy.scope() . Phạm vi của chiến lược phân phối chỉ ra cách thức và nơi các biến được tạo và trong trường hợp MultiWorkerMirroredStrategy , các biến được tạo là MirroredVariable s và chúng được sao chép trên mỗi công nhân.

 num_workers = 4

# Here the batch size scales up by number of workers since 
# `tf.data.Dataset.batch` expects the global batch size. Previously we used 64, 
# and now this becomes 128.
global_batch_size = per_worker_batch_size * num_workers
multi_worker_dataset = mnist_dataset(global_batch_size)

with strategy.scope():
  # Model building/compiling need to be within `strategy.scope()`.
  multi_worker_model = build_and_compile_cnn_model()

# Keras' `model.fit()` trains the model with specified number of epochs and
# number of steps per epoch. Note that the numbers here are for demonstration
# purposes only and may not sufficiently produce a model with good quality.
multi_worker_model.fit(multi_worker_dataset, epochs=3, steps_per_epoch=70)
 
Epoch 1/3
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:601: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Iterator.get_next_as_optional()` instead.
70/70 [==============================] - 0s 3ms/step - loss: 2.2682 - accuracy: 0.2265
Epoch 2/3
70/70 [==============================] - 0s 3ms/step - loss: 2.1714 - accuracy: 0.4954
Epoch 3/3
70/70 [==============================] - 0s 3ms/step - loss: 2.0638 - accuracy: 0.6232

<tensorflow.python.keras.callbacks.History at 0x7fc5f4f062e8>

Kho dữ liệu và kích thước lô

Trong đào tạo nhiều công nhân với MultiWorkerMirroredStrategy , việc sắp xếp bộ dữ liệu là cần thiết để đảm bảo sự hội tụ và hiệu suất. Tuy nhiên, lưu ý rằng trong đoạn mã trên, các bộ dữ liệu được truyền trực tiếp đến model.fit() mà không cần phải phân đoạn; điều này là do API tf.distribute.Strategy tự động xử lý việc sắp xếp bộ dữ liệu. Nó phân đoạn dữ liệu ở cấp độ tệp có thể tạo ra các phân đoạn lệch. Trong các trường hợp cực đoan khi chỉ có một tệp, chỉ có phân đoạn đầu tiên (tức là công nhân) sẽ nhận được dữ liệu đào tạo hoặc đánh giá và kết quả là tất cả các công nhân sẽ gặp lỗi.

Nếu bạn thích shending thủ công cho đào tạo của bạn, shending tự động có thể được tắt thông qua tf.data.experimental.DistributeOptions api. Cụ thể,

 options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
dataset_no_auto_shard = multi_worker_dataset.with_options(options)
 

Một điều cần chú ý là kích thước lô cho các datasets . Trong đoạn mã trên, chúng tôi sử dụng global_batch_size = per_worker_batch_size * num_workers , đó là num_workers lần lớn như trường hợp nó là dành cho người lao động duy nhất, bởi vì hiệu quả cho mỗi kích thước nhân hàng loạt là kích thước lô hàng toàn cầu (tham số truyền vào trong tf.data.Dataset.batch() ) chia cho số lượng công nhân và với thay đổi này, chúng tôi sẽ giữ kích thước lô cho mỗi công nhân như trước đây.

Đánh giá

Nếu bạn vượt qua validation_data vào model.fit , nó sẽ luân phiên giữa đào tạo và đánh giá đối với từng thời đại. Việc đánh giá lấy validation_data được phân phối trên cùng một nhóm công nhân và kết quả đánh giá được tổng hợp và có sẵn cho tất cả các công nhân. Tương tự như đào tạo, bộ dữ liệu xác nhận được tự động phân chia ở cấp độ tệp. Bạn cần đặt kích thước lô toàn cầu trong tập dữ liệu xác thực và đặt validation_steps . Một bộ dữ liệu lặp đi lặp lại cũng được khuyến nghị để đánh giá.

Ngoài ra, bạn cũng có thể tạo một tác vụ khác định kỳ đọc điểm kiểm tra và chạy đánh giá. Đây là những gì Công cụ ước tính làm. Nhưng đây không phải là một cách được khuyến nghị để thực hiện đánh giá và do đó chi tiết của nó bị bỏ qua.

Sự dự đoán

Hiện tại model.predict không hoạt động với MultiWorkerMirroredStrategy.

Hiệu suất

Bây giờ bạn có một mô hình Keras được thiết lập để chạy trong nhiều công nhân với MultiWorkerMirroredStrategy . Bạn có thể thử các kỹ thuật sau đây để điều chỉnh hiệu suất đào tạo nhiều nhân viên với MultiWorkerMirroredStrategy .

  • MultiWorkerMirroredStrategy cung cấp nhiều triển khai truyền thông tập thể . RING thực hiện các tập thể dựa trên vòng sử dụng gRPC làm lớp giao tiếp giữa các máy chủ. NCCL sử dụng NCCL của Nvidia để triển khai tập thể. AUTO trì hoãn sự lựa chọn cho thời gian chạy. Sự lựa chọn tốt nhất để thực hiện tập thể phụ thuộc vào số lượng và loại GPU và kết nối mạng trong cụm. Để ghi đè lựa chọn tự động, chỉ định giá trị hợp lệ cho tham số communication của MultiWorkerMirroredStrategy của MultiWorkerMirroredStrategy , ví dụ: communication=tf.distribute.experimental.CollectiveCommunication.NCCL .
  • Truyền các biến thành tf.float nếu có thể. Mô hình ResNet chính thức bao gồm một ví dụ về cách thực hiện việc này.

Chịu lỗi

Trong đào tạo đồng bộ, cụm sẽ thất bại nếu một trong các công nhân thất bại và không có cơ chế phục hồi thất bại. Sử dụng Keras với tf.distribute.Strategy đi kèm với lợi thế về khả năng chịu lỗi trong trường hợp công nhân chết hoặc không ổn định. Chúng tôi thực hiện điều này bằng cách duy trì trạng thái đào tạo trong hệ thống tệp phân tán mà bạn chọn, sao cho khi khởi động lại trường hợp trước đó bị lỗi hoặc bị cấm, trạng thái đào tạo được phục hồi.

Vì tất cả các công nhân được giữ đồng bộ về mặt thời gian đào tạo và các bước, nên các công nhân khác sẽ cần đợi công nhân thất bại hoặc được ưu tiên khởi động lại để tiếp tục.

Gọi lại ModelCheckpoint

ModelCheckpoint lại ModelCheckpoint không còn cung cấp chức năng chịu lỗi, vui lòng sử dụng gọi lại BackupAndRestore thay thế.

Cuộc gọi lại ModelCheckpoint vẫn có thể được sử dụng để lưu điểm kiểm tra. Nhưng với điều này, nếu việc đào tạo bị gián đoạn hoặc kết thúc thành công, để tiếp tục đào tạo từ điểm kiểm tra, người dùng có trách nhiệm tải mô hình theo cách thủ công. Người dùng tùy chọn có thể chọn lưu và khôi phục mô hình / trọng lượng bên ngoài cuộc gọi lại ModelCheckpoint .

Mô hình lưu và tải

Để lưu mô hình của bạn bằng model.save hoặc tf.saved_model.save , đích để lưu cần phải khác nhau đối với mỗi công nhân. Trên các công nhân không phải là giám đốc, bạn sẽ cần lưu mô hình vào một thư mục tạm thời, và trên giám đốc, bạn sẽ cần lưu vào thư mục mô hình được cung cấp. Các thư mục tạm thời về công nhân cần phải là duy nhất để ngăn ngừa lỗi do nhiều công nhân cố gắng ghi vào cùng một vị trí. Mô hình được lưu trong tất cả các thư mục là giống hệt nhau và thông thường chỉ có mô hình được lưu bởi trưởng nên được tham chiếu để khôi phục hoặc phục vụ. Chúng tôi khuyên bạn nên có một số logic dọn dẹp xóa các thư mục tạm thời được tạo bởi các công nhân sau khi quá trình đào tạo của bạn hoàn tất.

Lý do bạn cần tiết kiệm cho giám đốc và công nhân cùng một lúc là bởi vì bạn có thể đang tổng hợp các biến trong quá trình kiểm tra, điều này đòi hỏi cả trưởng và công nhân phải tham gia vào giao thức truyền thông allreduce. Mặt khác, để cho trưởng và công nhân lưu vào cùng một thư mục mô hình sẽ dẫn đến lỗi do sự tranh chấp.

Với MultiWorkerMirroredStrategy , chương trình được chạy trên mọi công nhân và để biết liệu công nhân hiện tại có phải là trưởng hay không, chúng tôi tận dụng đối tượng trình phân giải cụm có thuộc tính task_typetask_id . task_type cho bạn biết công việc hiện tại là gì (ví dụ: 'worker') và task_id cho bạn biết định danh của worker. Công nhân có id 0 được chỉ định là công nhân trưởng.

Trong đoạn mã dưới đây, write_filepath cung cấp đường dẫn tệp để ghi, tùy thuộc vào id worker. Trong trường hợp trưởng (worker với id 0), nó ghi vào đường dẫn tệp gốc; đối với những người khác, nó tạo một thư mục tạm thời (có id trong đường dẫn thư mục) để ghi vào:

 model_path = '/tmp/keras-model'

def _is_chief(task_type, task_id):
  # If `task_type` is None, this may be operating as single worker, which works 
  # effectively as chief.
  return task_type is None or task_type == 'chief' or (
            task_type == 'worker' and task_id == 0)

def _get_temp_dir(dirpath, task_id):
  base_dirpath = 'workertemp_' + str(task_id)
  temp_dir = os.path.join(dirpath, base_dirpath)
  tf.io.gfile.makedirs(temp_dir)
  return temp_dir

def write_filepath(filepath, task_type, task_id):
  dirpath = os.path.dirname(filepath)
  base = os.path.basename(filepath)
  if not _is_chief(task_type, task_id):
    dirpath = _get_temp_dir(dirpath, task_id)
  return os.path.join(dirpath, base)

task_type, task_id = (strategy.cluster_resolver.task_type,
                      strategy.cluster_resolver.task_id)
write_model_path = write_filepath(model_path, task_type, task_id)
 

Với điều đó, bây giờ bạn đã sẵn sàng để lưu:

 multi_worker_model.save(write_model_path)
 
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
INFO:tensorflow:Assets written to: /tmp/keras-model/assets

Như chúng tôi đã mô tả ở trên, sau này trên mô hình chỉ nên được tải từ đường dẫn chính được lưu vào, vì vậy chúng ta hãy xóa những cái tạm thời mà các công nhân không phải là trưởng đã lưu:

 if not _is_chief(task_type, task_id):
  tf.io.gfile.rmtree(os.path.dirname(write_model_path))
 

Bây giờ, khi đến lúc tải, hãy sử dụng API tf.keras.models.load_model tiện lợi và tiếp tục với công việc tiếp theo. Ở đây, chúng tôi giả sử chỉ sử dụng một công nhân duy nhất để tải và tiếp tục đào tạo, trong trường hợp đó bạn không gọi tf.keras.models.load_model trong một strategy.scope() tf.keras.models.load_model strategy.scope() .

 loaded_model = tf.keras.models.load_model(model_path)

# Now that we have the model restored, and can continue with the training.
loaded_model.fit(single_worker_dataset, epochs=2, steps_per_epoch=20)
 
Epoch 1/2
20/20 [==============================] - 0s 2ms/step - loss: 1.9825 - accuracy: 0.1102
Epoch 2/2
20/20 [==============================] - 0s 2ms/step - loss: 1.9367 - accuracy: 0.1117

<tensorflow.python.keras.callbacks.History at 0x7fc5f4b0d8d0>

Lưu điểm kiểm tra và khôi phục

Mặt khác, điểm kiểm tra cho phép bạn lưu trọng lượng của mô hình và khôi phục lại chúng mà không phải lưu toàn bộ mô hình. Tại đây, bạn sẽ tạo một tf.train.Checkpoint theo dõi mô hình, được quản lý bởi tf.train.CheckpointManager để chỉ bảo tồn điểm kiểm tra mới nhất.

 checkpoint_dir = '/tmp/ckpt'

checkpoint = tf.train.Checkpoint(model=multi_worker_model)
write_checkpoint_dir = write_filepath(checkpoint_dir, task_type, task_id)
checkpoint_manager = tf.train.CheckpointManager(
  checkpoint, directory=write_checkpoint_dir, max_to_keep=1)
 

Sau khi CheckpointManager được thiết lập, giờ bạn đã sẵn sàng để lưu và xóa các điểm kiểm tra không phải là nhân viên chính được lưu.

 checkpoint_manager.save()
if not _is_chief(task_type, task_id):
  tf.io.gfile.rmtree(write_checkpoint_dir)
 

Bây giờ, khi bạn cần khôi phục, bạn có thể tìm thấy điểm kiểm tra mới nhất được lưu bằng hàm tf.train.latest_checkpoint thuận tiện. Sau khi khôi phục điểm kiểm tra, bạn có thể tiếp tục đào tạo.

 latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
checkpoint.restore(latest_checkpoint)
multi_worker_model.fit(multi_worker_dataset, epochs=2, steps_per_epoch=20)
 
Epoch 1/2
20/20 [==============================] - 0s 3ms/step - loss: 1.9841 - accuracy: 0.6561
Epoch 2/2
20/20 [==============================] - 0s 3ms/step - loss: 1.9445 - accuracy: 0.6805

<tensorflow.python.keras.callbacks.History at 0x7fc5f49d9d30>

Gọi lại BackupAndRestore

Gọi lại BackupAndRestore cung cấp chức năng chịu lỗi, bằng cách sao lưu mô hình và số epoch hiện tại trong tệp điểm kiểm tra tạm thời theo đối số backup_dir cho BackupAndRestore . Điều này được thực hiện vào cuối mỗi kỷ nguyên.

Khi các công việc bị gián đoạn và khởi động lại, cuộc gọi lại sẽ khôi phục điểm kiểm tra cuối cùng và việc đào tạo tiếp tục từ đầu kỷ nguyên bị gián đoạn. Bất kỳ đào tạo một phần nào đã được thực hiện trong kỷ nguyên chưa hoàn thành trước khi gián đoạn sẽ bị loại bỏ, do đó nó không ảnh hưởng đến trạng thái mô hình cuối cùng.

Để sử dụng nó, hãy cung cấp một ví dụ của tf.keras.callbacks.experimental.BackupAndRestore tại cuộc gọi tf.keras.Model.fit() .

Với MultiWorkerMirroredStrargety, nếu một công nhân bị gián đoạn, toàn bộ cụm tạm dừng cho đến khi công nhân bị gián đoạn được khởi động lại. Các công nhân khác cũng sẽ khởi động lại và công nhân bị gián đoạn tham gia lại cụm. Sau đó, mọi nhân viên đọc tệp điểm kiểm tra đã được lưu trước đó và chọn trạng thái cũ của nó, do đó cho phép cụm lấy lại đồng bộ hóa. Sau đó, việc đào tạo tiếp tục.

BackupAndRestore lại BackupAndRestore sử dụng CheckpointManager để lưu và khôi phục trạng thái đào tạo, tạo ra một tệp gọi là điểm kiểm tra theo dõi các điểm kiểm tra hiện có cùng với điểm kiểm tra mới nhất. Vì lý do này, backup_dir không nên được sử dụng lại để lưu trữ các điểm kiểm tra khác để tránh xung đột tên.

Hiện tại, gọi lại BackupAndRestore hỗ trợ công nhân đơn lẻ không có chiến lược, MirroredStrargety và đa công nhân với MultiWorkerMirroredStrargety. Dưới đây là hai ví dụ cho cả đào tạo nhiều công nhân và đào tạo công nhân đơn lẻ.

 # Multi-worker training with MultiWorkerMirroredStrategy.

callbacks = [tf.keras.callbacks.experimental.BackupAndRestore(backup_dir='/tmp/backup')]
with strategy.scope():
  multi_worker_model = build_and_compile_cnn_model()
multi_worker_model.fit(multi_worker_dataset,
                       epochs=3,
                       steps_per_epoch=70,
                       callbacks=callbacks)
 
Epoch 1/3
70/70 [==============================] - 0s 3ms/step - loss: 2.2837 - accuracy: 0.1836
Epoch 2/3
70/70 [==============================] - 0s 3ms/step - loss: 2.2131 - accuracy: 0.4091
Epoch 3/3
70/70 [==============================] - 0s 3ms/step - loss: 2.1310 - accuracy: 0.5485

<tensorflow.python.keras.callbacks.History at 0x7fc5f49a3080>

Nếu bạn kiểm tra thư mục của backup_dir mà bạn đã chỉ định trong BackupAndRestore , bạn có thể nhận thấy một số tệp điểm kiểm tra được tạo tạm thời. Những tệp đó là cần thiết để khôi phục các trường hợp bị mất trước đó và chúng sẽ bị thư viện xóa vào cuối tf.keras.Model.fit() sau khi thoát khỏi khóa đào tạo của bạn.

Xem thêm

  1. Đào tạo phân tán trong hướng dẫn TensorFlow cung cấp tổng quan về các chiến lược phân phối có sẵn.
  2. Các mô hình chính thức , nhiều trong số đó có thể được cấu hình để chạy nhiều chiến lược phân phối.
  3. Phần Hiệu suất trong hướng dẫn cung cấp thông tin về các chiến lược và công cụ khác mà bạn có thể sử dụng để tối ưu hóa hiệu suất của các mô hình TensorFlow của mình.