Multi-worker training with Keras


Overview

This tutorial demonstrates multi-worker distributed training with a Keras model using the tf.distribute.Strategy API. With the strategies designed for multi-worker training, a Keras model that was designed to run on a single worker can seamlessly work on multiple workers with minimal code changes.

The Distributed Training in TensorFlow guide provides an overview of the distribution strategies TensorFlow supports, and is useful for anyone interested in a deeper understanding of the tf.distribute.Strategy APIs.

Setup

First, set up TensorFlow and the necessary imports.

!pip install tf-nightly
import tensorflow_datasets as tfds
import tensorflow as tf
tfds.disable_progress_bar()

Preparing the dataset

Now, let's prepare the MNIST dataset from TensorFlow Datasets. The MNIST dataset comprises 60,000 training examples and 10,000 test examples of the handwritten digits 0-9, formatted as 28x28-pixel monochrome images.

BUFFER_SIZE = 10000
BATCH_SIZE = 64

def make_datasets_unbatched():
  # Scale the MNIST data from the (0, 255] range to the (0., 1.] range
  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255
    return image, label

  datasets, info = tfds.load(name='mnist',
                            with_info=True,
                            as_supervised=True)

  return datasets['train'].map(scale).cache().shuffle(BUFFER_SIZE)

train_datasets = make_datasets_unbatched().batch(BATCH_SIZE)

Build the Keras model

Here we use the tf.keras.Sequential API to build and compile a simple convolutional neural network Keras model to train with our MNIST dataset.

Note: For a more comprehensive walkthrough of building Keras models, see the TensorFlow Keras guide.

def build_and_compile_cnn_model():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10)
  ])
  model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
      metrics=['accuracy'])
  return model

Let's first try training the model for a small number of epochs on a single worker and observe the results to make sure everything works correctly. You should expect to see the loss dropping and the accuracy approaching 1.0 as training progresses.

single_worker_model = build_and_compile_cnn_model()
single_worker_model.fit(x=train_datasets, epochs=3, steps_per_epoch=5)
Epoch 1/3
5/5 [==============================] - 2s 5ms/step - loss: 2.3034 - accuracy: 0.0875
Epoch 2/3
5/5 [==============================] - 0s 4ms/step - loss: 2.3089 - accuracy: 0.0844
Epoch 3/3
5/5 [==============================] - 0s 4ms/step - loss: 2.3126 - accuracy: 0.0875
2022-06-03 20:07:36.608025: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2022-06-03 20:07:36.610575: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
<keras.callbacks.History at 0x7fc29c18fb20>

Multi-worker configuration

Now let's enter the world of multi-worker training. In TensorFlow, the TF_CONFIG environment variable is required for training on multiple machines, each of which may have a different role. TF_CONFIG is used to specify the cluster configuration for each worker that is part of the cluster.

TF_CONFIG has two components: cluster and task. cluster provides information about the training cluster, which is a dict consisting of different types of jobs such as worker. In multi-worker training, there is usually one worker that takes on a little more responsibility, such as saving checkpoints and writing summary files for TensorBoard, in addition to what a regular worker does. Such a worker is referred to as the "chief" worker, and it is customary that the worker with index 0 is appointed as the chief worker (in fact, this is how tf.distribute.Strategy is implemented). task, on the other hand, provides information about the current task.

In this example, we set the task type to "worker" and the task index to 0. This means the machine with this setup is the first worker, which will be appointed as the chief worker and do more work than the other workers. Note that the other machines will need to have the TF_CONFIG environment variable set as well, and it should have the same cluster dict, but a different task type or task index depending on the roles of those machines.

For illustration purposes, this tutorial shows how one may set a TF_CONFIG with 2 workers on localhost. In practice, users would create multiple workers on external IP addresses/ports, and set TF_CONFIG on each worker appropriately.

Warning: Do not execute the following code in Colab. TensorFlow's runtime will attempt to create a gRPC server at the specified IP address and port, which will likely fail.

import json
import os

os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'worker': ["localhost:12345", "localhost:23456"]
    },
    'task': {'type': 'worker', 'index': 0}
})
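
On the other machine in this two-worker example, the same snippet would be used with only the task index changed. As a minimal sketch, consistent with the cluster dict above, the second worker would set:

# Sketch only: TF_CONFIG for the second worker uses the same `cluster` dict,
# but task `index` 1 instead of 0.
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'worker': ["localhost:12345", "localhost:23456"]
    },
    'task': {'type': 'worker', 'index': 1}
})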

Note that while the learning rate is fixed in this example, in general you may need to adjust the learning rate based on the global batch size.
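
For example, one common heuristic (an assumption for illustration, not a rule prescribed by this tutorial) is to scale the base learning rate linearly with the number of workers, since the global batch size grows proportionally:

# Sketch only: linear learning-rate scaling with the number of workers.
# 0.001 is the single-worker rate used in build_and_compile_cnn_model above;
# NUM_WORKERS matches the two-worker TF_CONFIG in this example.
NUM_WORKERS = 2
BASE_LEARNING_RATE = 0.001
scaled_learning_rate = BASE_LEARNING_RATE * NUM_WORKERS

optimizer = tf.keras.optimizers.SGD(learning_rate=scaled_learning_rate)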

Choose the right strategy

In TensorFlow, distributed training consists of synchronous training, where the steps of training are synced across the workers and replicas, and asynchronous training, where the training steps are not strictly synced.

MultiWorkerMirroredStrategy, which is the recommended strategy for synchronous multi-worker training, will be demonstrated in this guide.

To train the model, use an instance of tf.distribute.experimental.MultiWorkerMirroredStrategy. MultiWorkerMirroredStrategy creates copies of all variables in the model's layers on each device across all workers. It uses CollectiveOps, a TensorFlow op for collective communication, to aggregate gradients and keep the variables in sync. The tf.distribute.Strategy guide has more details about this strategy.

strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_46979/349189047.py:1: _CollectiveAllReduceStrategyExperimental.__init__ (from tensorflow.python.distribute.collective_all_reduce_strategy) is deprecated and will be removed in a future version.
Instructions for updating:
use distribute.MultiWorkerMirroredStrategy instead
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_46979/349189047.py:1: _CollectiveAllReduceStrategyExperimental.__init__ (from tensorflow.python.distribute.collective_all_reduce_strategy) is deprecated and will be removed in a future version.
Instructions for updating:
use distribute.MultiWorkerMirroredStrategy instead
WARNING:tensorflow:Collective ops is not configured at program startup. Some performance features may not be enabled.
WARNING:tensorflow:Collective ops is not configured at program startup. Some performance features may not be enabled.
INFO:tensorflow:Single-worker MultiWorkerMirroredStrategy with local_devices = ('/device:GPU:0', '/device:GPU:1', '/device:GPU:2', '/device:GPU:3'), communication = CommunicationImplementation.AUTO
INFO:tensorflow:Single-worker MultiWorkerMirroredStrategy with local_devices = ('/device:GPU:0', '/device:GPU:1', '/device:GPU:2', '/device:GPU:3'), communication = CommunicationImplementation.AUTO

Note: TF_CONFIG is parsed and TensorFlow's GRPC servers are started at the time MultiWorkerMirroredStrategy.__init__() is called, so the TF_CONFIG environment variable must be set before a tf.distribute.Strategy instance is created.

MultiWorkerMirroredStrategy provides multiple implementations via the CollectiveCommunication parameter. RING implements ring-based collectives using gRPC as the cross-host communication layer. NCCL uses Nvidia's NCCL to implement collectives. AUTO defers the choice to the runtime. The best choice of collective implementation depends upon the number and kind of GPUs, and the network interconnect in the cluster.
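
For example, here is a minimal sketch of requesting the NCCL implementation explicitly with the experimental API used in this tutorial (it assumes NCCL-capable NVIDIA GPUs, and this alternative strategy is not used in the rest of the tutorial):

# Sketch only: override the default AUTO and request NCCL collectives.
nccl_strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
    communication=tf.distribute.experimental.CollectiveCommunication.NCCL)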

Train the model with MultiWorkerMirroredStrategy

With the integration of the tf.distribute.Strategy API into tf.keras, the only change you need to make to distribute the training to multiple workers is enclosing the model building and the model.compile() call inside strategy.scope(). The distribution strategy's scope dictates how and where the variables are created, and in the case of MultiWorkerMirroredStrategy, the variables created are MirroredVariables, which are replicated on each of the workers.

Note: In this Colab, the following code can run with the expected results, but this is effectively single-machine training since TF_CONFIG is not set. Once you have set TF_CONFIG in your own example, you should expect speed-ups from training on multiple machines.

NUM_WORKERS = 2
# Since `tf.data.Dataset.batch` expects the global batch size,
# the batch size here is scaled up by the number of workers.
# Previously we used 64, and now it becomes 128.
GLOBAL_BATCH_SIZE = 64 * NUM_WORKERS

# Creation of the dataset needs to happen after the
# MultiWorkerMirroredStrategy object is instantiated.
train_datasets = make_datasets_unbatched().batch(GLOBAL_BATCH_SIZE)
with strategy.scope():
  # Model building/compiling needs to be within `strategy.scope()`.
  multi_worker_model = build_and_compile_cnn_model()

# Keras' `model.fit()` trains the model with the specified number of epochs and
# steps per epoch. Note that the numbers here are for demonstration purposes
# only and are not sufficient to produce a model of high accuracy.
multi_worker_model.fit(x=train_datasets, epochs=3, steps_per_epoch=5)
2022-06-03 20:07:38.409119: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:547] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
Epoch 1/3
INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1
5/5 [==============================] - 8s 15ms/step - loss: 2.3121 - accuracy: 0.0656
Epoch 2/3
5/5 [==============================] - 0s 14ms/step - loss: 2.3034 - accuracy: 0.0922
Epoch 3/3
5/5 [==============================] - 0s 14ms/step - loss: 2.3047 - accuracy: 0.1047
2022-06-03 20:07:46.666346: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2022-06-03 20:07:46.669124: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
<keras.callbacks.History at 0x7fc2842d79d0>

Dataset sharding and batch size

In multi-worker training, sharding the data into multiple parts is needed to ensure convergence and performance. However, note that in the code snippet above, the dataset is passed directly to model.fit() without needing to be sharded; this is because the tf.distribute.Strategy API takes care of dataset sharding automatically in multi-worker training.

If you prefer manual sharding for your training, automatic sharding can be turned off via the tf.data.experimental.DistributeOptions API, as in the snippet below.

options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
train_datasets_no_auto_shard = train_datasets.with_options(options)
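
With auto-sharding turned off, each worker is responsible for feeding its own portion of the data. As a sketch of the corresponding fit() call (assuming you have already split the data appropriately per worker yourself):

# Sketch only: pass the manually handled dataset to `model.fit()`.
multi_worker_model.fit(x=train_datasets_no_auto_shard,
                       epochs=3,
                       steps_per_epoch=5)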

Another thing to note is the batch size for the datasets. In the code snippet above, we use GLOBAL_BATCH_SIZE = 64 * NUM_WORKERS, which is NUM_WORKERS times as large as the size for a single worker, because the effective per-worker batch size is the global batch size (the parameter passed in tf.data.Dataset.batch()) divided by the number of workers. With this change, we keep the per-worker batch size the same as before.
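
To make the arithmetic concrete, here is a quick sketch using the numbers from this tutorial:

# Effective per-worker batch size = global batch size / number of workers.
# With GLOBAL_BATCH_SIZE = 64 * NUM_WORKERS, each worker still processes 64.
per_worker_batch_size = GLOBAL_BATCH_SIZE // NUM_WORKERS
print(per_worker_batch_size)  # 128 // 2 == 64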

Performance

You now have a Keras model that is all set up to run on multiple workers with MultiWorkerMirroredStrategy. You can try the following techniques to tweak the performance of multi-worker training.

  • MultiWorkerMirroredStrategy provides multiple collective communication implementations. RING implements ring-based collectives using gRPC as the cross-host communication layer. NCCL uses Nvidia's NCCL to implement collectives. AUTO defers the choice to the runtime. The best choice of collective implementation depends upon the number and kind of GPUs, and the network interconnect in the cluster. To override the automatic choice, specify a valid value for the communication argument of MultiWorkerMirroredStrategy's constructor, e.g. communication=tf.distribute.experimental.CollectiveCommunication.NCCL.
  • Cast the variables to tf.float if possible. The official ResNet model includes an example of how this can be done.

Fault tolerance

In synchronous training, the cluster will fail if one of the workers fails and no failure-recovery mechanism exists. Using Keras with tf.distribute.Strategy comes with the advantage of fault tolerance in cases where workers die or are otherwise unstable. This is done by preserving the training state in the distributed file system of your choice, so that upon restart of the instance that previously failed or was preempted, the training state is recovered.

Since all the workers are kept in sync in terms of training epochs and steps, the other workers will need to wait for the failed or preempted worker to restart before they can continue.

ModelCheckpoint callback

To take advantage of fault tolerance in multi-worker training, provide an instance of tf.keras.callbacks.ModelCheckpoint when calling tf.keras.Model.fit(). The callback stores the checkpoint and training state in the directory corresponding to the filepath argument of ModelCheckpoint.

# Replace the `filepath` argument with a path in a file system
# accessible by all workers.
callbacks = [tf.keras.callbacks.ModelCheckpoint(filepath='/tmp/keras-ckpt')]
with strategy.scope():
  multi_worker_model = build_and_compile_cnn_model()
multi_worker_model.fit(x=train_datasets,
                       epochs=3,
                       steps_per_epoch=5,
                       callbacks=callbacks)
2022-06-03 20:07:46.828733: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:547] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
Epoch 1/3
INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1
5/5 [==============================] - ETA: 0s - loss: 2.3035 - accuracy: 0.1531
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op while saving (showing 1 of 1). These functions will not be directly callable after loading.
WARNING:tensorflow:Please fix your imports. Module tensorflow.python.training.tracking.trackable_utils has been moved to tensorflow.python.trackable.trackable_utils. The old module will be deleted in version 2.11.
WARNING:tensorflow:Please fix your imports. Module tensorflow.python.training.tracking.trackable_utils has been moved to tensorflow.python.trackable.trackable_utils. The old module will be deleted in version 2.11.
INFO:tensorflow:Assets written to: /tmp/keras-ckpt/assets
INFO:tensorflow:Assets written to: /tmp/keras-ckpt/assets
5/5 [==============================] - 6s 202ms/step - loss: 2.3035 - accuracy: 0.1531
Epoch 2/3
5/5 [==============================] - ETA: 0s - loss: 2.3003 - accuracy: 0.1656
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op while saving (showing 1 of 1). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/keras-ckpt/assets
INFO:tensorflow:Assets written to: /tmp/keras-ckpt/assets
5/5 [==============================] - 1s 166ms/step - loss: 2.3003 - accuracy: 0.1656
Epoch 3/3
5/5 [==============================] - ETA: 0s - loss: 2.2974 - accuracy: 0.1406
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op while saving (showing 1 of 1). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/keras-ckpt/assets
INFO:tensorflow:Assets written to: /tmp/keras-ckpt/assets
5/5 [==============================] - 1s 167ms/step - loss: 2.2974 - accuracy: 0.1406
2022-06-03 20:07:54.384936: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2022-06-03 20:07:54.388366: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
<keras.callbacks.History at 0x7fc2747485b0>

If a worker gets preempted, the whole cluster pauses until the preempted worker is restarted. Once the worker rejoins the cluster, the other workers also restart. Every worker then reads the checkpoint file that was previously saved and picks up its former state, allowing the cluster to get back in sync and continue training.

If you inspect the directory containing the filepath you specified in ModelCheckpoint, you may notice some temporarily generated checkpoint files. Those files are needed to recover previously lost instances, and they will be removed by the library at the end of tf.keras.Model.fit() upon successful completion of your multi-worker training.
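
If you want to inspect or reuse the saved model yourself, one option (a sketch only, assuming the default ModelCheckpoint behavior of saving the full model in SavedModel format at filepath, as in the run above) is to load it back:

# Sketch only: load the full model back from the checkpoint directory.
restored_model = tf.keras.models.load_model('/tmp/keras-ckpt')
restored_model.summary()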

You can also check out

  1. The Distributed Training in TensorFlow guide, which provides an overview of the available distribution strategies.
  2. The official ResNet50 model, which can be trained using MirroredStrategy and MultiWorkerMirroredStrategy.