Distributed training with Keras

Overview

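The tf.distribute.Strategy API provides an abstraction for distributing training across multiple processing units. It lets you carry out distributed training with existing models and training code, with only minimal changes.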

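This tutorial demonstrates how to use tf.distribute.MirroredStrategy to perform in-graph replication with synchronous training on multiple GPUs on a single machine. The strategy essentially copies all of the model's variables to each processor. It then uses all-reduce to combine the gradients from all processors and applies the combined value to every copy of the model.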
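
The replicate, all-reduce, apply cycle described above can be illustrated with a small pure-Python simulation. This is only a sketch under stated assumptions, not how TensorFlow implements it: the one-parameter model y = w * x, the helper names (local_gradient, all_reduce_sum, train_step), and the choice to scale each local loss by the global batch size are all illustrative.

```python
# Toy simulation of one synchronous data-parallel step. Not TensorFlow code.
# Model: y = w * x with squared error; every replica holds its own copy of w.

def local_gradient(w, shard, global_batch_size):
    """Gradient of the squared error over one shard, scaled by the GLOBAL batch size."""
    return sum(2.0 * (w * x - y) * x for x, y in shard) / global_batch_size

def all_reduce_sum(values):
    """Every replica receives the same sum of all per-replica values."""
    total = sum(values)
    return [total for _ in values]

def train_step(replica_weights, global_batch, learning_rate=0.1):
    num_replicas = len(replica_weights)
    # Split the global batch evenly across the replicas.
    shards = [global_batch[i::num_replicas] for i in range(num_replicas)]
    grads = [local_gradient(w, shard, len(global_batch))
             for w, shard in zip(replica_weights, shards)]
    combined = all_reduce_sum(grads)
    # Each replica applies the identical combined gradient to its own copy.
    return [w - learning_rate * g for w, g in zip(replica_weights, combined)]

global_batch = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]  # samples of y = 2x
replica_weights = [0.0, 0.0]  # two replicas with identical initial copies
for _ in range(50):
    replica_weights = train_step(replica_weights, global_batch)
print(replica_weights)
```

Because every replica starts from the same value and applies the same combined gradient, the copies never drift apart, and the summed per-replica gradients equal the gradient a single device would compute over the whole global batch.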

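You will use the tf.keras APIs to build the model and Model.fit to train it. (To learn about distributed training with a custom training loop and MirroredStrategy, see the Custom training with tf.distribute.Strategy tutorial.)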

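MirroredStrategy trains your model on multiple GPUs on a single machine. For synchronous training on many GPUs across multiple workers, use tf.distribute.MultiWorkerMirroredStrategy with Keras Model.fit or a custom training loop. For other options and the remaining strategies, refer to the Distributed training with TensorFlow guide.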

Setup

import tensorflow_datasets as tfds
import tensorflow as tf

import os

# Load the TensorBoard notebook extension.
%load_ext tensorboard
2022-08-31 00:09:26.597478: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2022-08-31 00:09:27.287215: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvrtc.so.11.1: cannot open shared object file: No such file or directory
2022-08-31 00:09:27.287448: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvrtc.so.11.1: cannot open shared object file: No such file or directory
2022-08-31 00:09:27.287460: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
print(tf.__version__)
2.10.0-rc3

Download the dataset

Load the MNIST dataset from TensorFlow Datasets. This returns a dataset in the tf.data format.

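Setting the with_info argument to True includes the metadata for the entire dataset, which is saved here to info. Among other things, this metadata object includes the number of training and test examples.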

datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)

mnist_train, mnist_test = datasets['train'], datasets['test']

Define the distribution strategy

Create a MirroredStrategy object. It handles the distribution and provides a context manager (MirroredStrategy.scope) inside which you build your model.

strategy = tf.distribute.MirroredStrategy()
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
Number of devices: 4

Set up the input pipeline

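When training a model with multiple GPUs, you can use the extra computing power effectively by increasing the batch size. In general, use the largest batch size that fits in GPU memory and tune the learning rate accordingly.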

# You can also do info.splits.total_num_examples to get the total
# number of examples in the dataset.

num_train_examples = info.splits['train'].num_examples
num_test_examples = info.splits['test'].num_examples

BUFFER_SIZE = 10000

BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
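
As a quick check of the arithmetic, the following sketch reproduces the global batch size for the 4 GPUs reported above and derives the number of steps per epoch for MNIST (60,000 training and 10,000 test examples). The helper names are illustrative. The linear learning-rate scaling at the end is a common heuristic and an assumption here, not something this tutorial applies; it keeps Adam's default learning rate.

```python
import math

def global_batch_size(per_replica_batch_size, num_replicas):
    # Each replica processes its own slice of every global batch.
    return per_replica_batch_size * num_replicas

def steps_per_epoch(num_examples, batch_size):
    # Dataset.batch keeps the final partial batch by default, hence the ceiling.
    return math.ceil(num_examples / batch_size)

def linearly_scaled_lr(base_lr, num_replicas):
    # Assumption: linear scaling heuristic, offered only as a starting point for tuning.
    return base_lr * num_replicas

NUM_REPLICAS = 4  # matches the "Number of devices: 4" output above
batch_size = global_batch_size(64, NUM_REPLICAS)

print(batch_size)                          # 256
print(steps_per_epoch(60000, batch_size))  # 235
print(steps_per_epoch(10000, batch_size))  # 40
print(linearly_scaled_lr(1e-3, NUM_REPLICAS))
```

The 235 training steps and 40 evaluation steps match the progress bars printed by Model.fit and Model.evaluate later in this tutorial.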

Define a function that normalizes the image pixel values from the [0, 255] range to the [0, 1] range (feature scaling):

def scale(image, label):
  image = tf.cast(image, tf.float32)
  image /= 255

  return image, label

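Apply this scale function to the training and test data, then use the tf.data.Dataset APIs to shuffle the training data (Dataset.shuffle) and batch it (Dataset.batch). Note that you also keep an in-memory cache of the training data to improve performance (Dataset.cache).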

train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)
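
To build intuition for what the shuffle and batch stages do, here is a simplified pure-Python model of a buffered shuffle followed by batching. It only approximates the documented behavior of Dataset.shuffle (fill a buffer, emit a random element, refill from the input) and is not the tf.data implementation; the function names are illustrative.

```python
import random

def buffered_shuffle(items, buffer_size, seed=None):
    """Yield items in a partially random order using a fixed-size buffer."""
    rng = random.Random(seed)
    buffer = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)  # still filling the buffer
            continue
        # Emit a random buffered element and put the new item in its slot.
        index = rng.randrange(buffer_size)
        yield buffer[index]
        buffer[index] = item
    # Input exhausted: drain the remaining elements in random order.
    rng.shuffle(buffer)
    yield from buffer

def batch(items, batch_size):
    """Group items into lists of batch_size, keeping the final partial batch."""
    current = []
    for item in items:
        current.append(item)
        if len(current) == batch_size:
            yield current
            current = []
    if current:
        yield current

examples = range(10)
batches = list(batch(buffered_shuffle(examples, buffer_size=3, seed=0), batch_size=4))
print(batches)
```

With a buffer smaller than the dataset, as with BUFFER_SIZE = 10000 for the 60,000 MNIST training examples, the first emitted element can only come from the first buffer_size inputs, so the shuffle is not perfectly uniform.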

Create the model

Within the context of Strategy.scope, create and compile the model using the Keras API:

with strategy.scope():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10)
  ])

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=['accuracy'])
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
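
Since MirroredStrategy keeps a copy of every variable on each GPU, it can help to know how many variables and parameters the model has. The sketch below works this out by hand from the layer definitions above, assuming the Keras defaults of 'valid' padding for Conv2D and a 2x2 pool for MaxPooling2D. The dictionary keys are illustrative labels only.

```python
def conv2d_params(kernel_size, in_channels, filters):
    # One kernel_size x kernel_size x in_channels kernel per filter, plus one bias each.
    return kernel_size * kernel_size * in_channels * filters + filters

def dense_params(inputs, units):
    # Weight matrix plus one bias per unit.
    return inputs * units + units

conv_out = 28 - 3 + 1          # Conv2D(32, 3), 'valid' padding: 28x28x1 -> 26x26x32
pooled = conv_out // 2         # MaxPooling2D(), 2x2 pool: 26x26x32 -> 13x13x32
flat = pooled * pooled * 32    # Flatten: 13 * 13 * 32 = 5408 features

layer_params = {
    'conv2d': conv2d_params(3, 1, 32),
    'dense_64': dense_params(flat, 64),
    'dense_10': dense_params(64, 10),
}
total_params = sum(layer_params.values())
num_variables = 2 * len(layer_params)  # one kernel and one bias per layer

print(layer_params)
print(total_params)   # 347146
print(num_variables)  # 6
```

The six trainable variables are consistent with the "6 all-reduces" line that appears in the training log further down: one gradient is combined across replicas per variable.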

Define the callbacks

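Define the following Keras callbacks: tf.keras.callbacks.TensorBoard, which writes logs that TensorBoard can visualize; tf.keras.callbacks.ModelCheckpoint, which saves the model weights after every epoch; and tf.keras.callbacks.LearningRateScheduler, which changes the learning rate according to a schedule function.

For illustrative purposes, add a custom callback called PrintLR that displays the learning rate in the notebook.

Note: Use the BackupAndRestore callback instead of ModelCheckpoint as the main mechanism for restoring the training state when restarting after a job failure. Because BackupAndRestore only supports eager mode, consider using ModelCheckpoint in graph mode.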

# Define the checkpoint directory to store the checkpoints.
checkpoint_dir = './training_checkpoints'
# Define the name of the checkpoint files.
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")
# Define a function for decaying the learning rate.
# You can define any decay function you need.
def decay(epoch):
  if epoch < 3:
    return 1e-3
  elif epoch >= 3 and epoch < 7:
    return 1e-4
  else:
    return 1e-5
# Define a callback for printing the learning rate at the end of each epoch.
class PrintLR(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    print('\nLearning rate for epoch {} is {}'.format(
        epoch + 1, model.optimizer.lr.numpy()))
# Put all the callbacks together.
callbacks = [
    tf.keras.callbacks.TensorBoard(log_dir='./logs'),
    tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,
                                       save_weights_only=True),
    tf.keras.callbacks.LearningRateScheduler(decay),
    PrintLR()
]
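
To see what the schedule produces before training, the decay function can be evaluated on its own. The sketch below repeats its thresholds in standalone form and tabulates the learning rate for the 12 epochs used later. Keras passes a zero-based epoch index to the schedule, while PrintLR displays epoch + 1. The as_float32 helper is an illustrative addition that rounds a value to 32-bit precision, which is why the training log prints 0.0010000000474974513 rather than 0.001.

```python
import struct

def decay(epoch):
    # Same thresholds as the decay function above (epoch is zero-based).
    if epoch < 3:
        return 1e-3
    elif epoch < 7:
        return 1e-4
    else:
        return 1e-5

def as_float32(value):
    # Round-trip through a 32-bit float, the precision used by the optimizer variable.
    return struct.unpack('f', struct.pack('f', value))[0]

EPOCHS = 12
schedule = {epoch + 1: decay(epoch) for epoch in range(EPOCHS)}
for displayed_epoch, lr in schedule.items():
    print('Epoch {:2d}: learning rate {:g} (float32: {!r})'.format(
        displayed_epoch, lr, as_float32(lr)))
```

Displayed epochs 1 to 3 use 1e-3, epochs 4 to 7 use 1e-4, and epochs 8 to 12 use 1e-5, which is the pattern visible in the training output.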

Train and evaluate

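Now, train the model in the usual way by calling Keras Model.fit on the model and passing in the dataset created at the beginning of the tutorial. This step is the same whether or not you are distributing the training.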

EPOCHS = 12

model.fit(train_dataset, epochs=EPOCHS, callbacks=callbacks)
2022-08-31 00:09:33.652107: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:547] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
Epoch 1/12
INFO:tensorflow:batch_all_reduce: 6 all-reduces with algorithm = nccl, num_packs = 1
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
1/235 [..............................] - ETA: 29:39 - loss: 2.3001 - accuracy: 0.1719
WARNING:tensorflow:Callback method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0068s vs `on_train_batch_end` time: 0.0104s). Check your callbacks.
229/235 [============================>.] - ETA: 0s - loss: 0.3490 - accuracy: 0.9006
Learning rate for epoch 1 is 0.0010000000474974513
235/235 [==============================] - 9s 7ms/step - loss: 0.3444 - accuracy: 0.9021 - lr: 0.0010
Epoch 2/12
233/235 [============================>.] - ETA: 0s - loss: 0.1153 - accuracy: 0.9671
Learning rate for epoch 2 is 0.0010000000474974513
235/235 [==============================] - 2s 6ms/step - loss: 0.1151 - accuracy: 0.9672 - lr: 0.0010
Epoch 3/12
233/235 [============================>.] - ETA: 0s - loss: 0.0750 - accuracy: 0.9790
Learning rate for epoch 3 is 0.0010000000474974513
235/235 [==============================] - 1s 6ms/step - loss: 0.0749 - accuracy: 0.9790 - lr: 0.0010
Epoch 4/12
233/235 [============================>.] - ETA: 0s - loss: 0.0533 - accuracy: 0.9859
Learning rate for epoch 4 is 9.999999747378752e-05
235/235 [==============================] - 2s 6ms/step - loss: 0.0532 - accuracy: 0.9859 - lr: 1.0000e-04
Epoch 5/12
235/235 [==============================] - ETA: 0s - loss: 0.0505 - accuracy: 0.9869
Learning rate for epoch 5 is 9.999999747378752e-05
235/235 [==============================] - 2s 6ms/step - loss: 0.0505 - accuracy: 0.9869 - lr: 1.0000e-04
Epoch 6/12
231/235 [============================>.] - ETA: 0s - loss: 0.0487 - accuracy: 0.9874
Learning rate for epoch 6 is 9.999999747378752e-05
235/235 [==============================] - 2s 6ms/step - loss: 0.0486 - accuracy: 0.9874 - lr: 1.0000e-04
Epoch 7/12
233/235 [============================>.] - ETA: 0s - loss: 0.0469 - accuracy: 0.9878
Learning rate for epoch 7 is 9.999999747378752e-05
235/235 [==============================] - 2s 6ms/step - loss: 0.0470 - accuracy: 0.9877 - lr: 1.0000e-04
Epoch 8/12
230/235 [============================>.] - ETA: 0s - loss: 0.0451 - accuracy: 0.9883
Learning rate for epoch 8 is 9.999999747378752e-06
235/235 [==============================] - 2s 6ms/step - loss: 0.0448 - accuracy: 0.9884 - lr: 1.0000e-05
Epoch 9/12
227/235 [===========================>..] - ETA: 0s - loss: 0.0444 - accuracy: 0.9885
Learning rate for epoch 9 is 9.999999747378752e-06
235/235 [==============================] - 2s 6ms/step - loss: 0.0444 - accuracy: 0.9886 - lr: 1.0000e-05
Epoch 10/12
231/235 [============================>.] - ETA: 0s - loss: 0.0443 - accuracy: 0.9884
Learning rate for epoch 10 is 9.999999747378752e-06
235/235 [==============================] - 2s 6ms/step - loss: 0.0442 - accuracy: 0.9885 - lr: 1.0000e-05
Epoch 11/12
228/235 [============================>.] - ETA: 0s - loss: 0.0439 - accuracy: 0.9886
Learning rate for epoch 11 is 9.999999747378752e-06
235/235 [==============================] - 2s 6ms/step - loss: 0.0440 - accuracy: 0.9886 - lr: 1.0000e-05
Epoch 12/12
234/235 [============================>.] - ETA: 0s - loss: 0.0439 - accuracy: 0.9885
Learning rate for epoch 12 is 9.999999747378752e-06
235/235 [==============================] - 1s 6ms/step - loss: 0.0439 - accuracy: 0.9885 - lr: 1.0000e-05
<keras.callbacks.History at 0x7fbd901c6cd0>

Check the saved checkpoints:

# Check the checkpoint directory.
ls {checkpoint_dir}
checkpoint           ckpt_4.data-00000-of-00001
ckpt_1.data-00000-of-00001   ckpt_4.index
ckpt_1.index             ckpt_5.data-00000-of-00001
ckpt_10.data-00000-of-00001  ckpt_5.index
ckpt_10.index            ckpt_6.data-00000-of-00001
ckpt_11.data-00000-of-00001  ckpt_6.index
ckpt_11.index            ckpt_7.data-00000-of-00001
ckpt_12.data-00000-of-00001  ckpt_7.index
ckpt_12.index            ckpt_8.data-00000-of-00001
ckpt_2.data-00000-of-00001   ckpt_8.index
ckpt_2.index             ckpt_9.data-00000-of-00001
ckpt_3.data-00000-of-00001   ckpt_9.index
ckpt_3.index
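
The listing above is sorted alphabetically, so ckpt_10 appears before ckpt_2. The sketch below rebuilds the checkpoint prefixes from the checkpoint_prefix template (ModelCheckpoint fills in {epoch} with the one-based epoch number) and shows why picking the "last" name alphabetically would be wrong. The epoch_of helper is illustrative.

```python
import os

checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt_{epoch}')

# One prefix per epoch for the 12 epochs trained above.
prefixes = [checkpoint_prefix.format(epoch=epoch) for epoch in range(1, 13)]

def epoch_of(prefix):
    # Extract the epoch number after the last underscore.
    return int(prefix.rsplit('_', 1)[1])

alphabetical = sorted(prefixes)
numeric = sorted(prefixes, key=epoch_of)

print(os.path.basename(alphabetical[-1]))  # ckpt_9
print(os.path.basename(numeric[-1]))       # ckpt_12
```

tf.train.latest_checkpoint does not depend on file name order. It reads the checkpoint state file in the directory (the file named checkpoint in the listing), which records the most recently saved prefix.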

To check how well the model performs, load the latest checkpoint and call Model.evaluate on the test data:

model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))

eval_loss, eval_acc = model.evaluate(eval_dataset)

print('Eval loss: {}, Eval accuracy: {}'.format(eval_loss, eval_acc))
2022-08-31 00:10:02.944412: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:547] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
40/40 [==============================] - 2s 6ms/step - loss: 0.0533 - accuracy: 0.9822
Eval loss: 0.053287554532289505, Eval accuracy: 0.982200026512146
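
The reported loss is a sparse categorical cross-entropy computed from raw logits, because the final Dense(10) layer has no activation and the loss was configured with from_logits=True. The following pure-Python sketch shows the per-example computation under that definition. It is an illustration of the formula, not the Keras implementation, and the function name is hypothetical.

```python
import math

def sparse_categorical_crossentropy_from_logits(logits, label):
    """Return -log(softmax(logits)[label]) using a numerically stable log-sum-exp."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]

# An untrained model that is equally unsure about all 10 digits.
uniform_loss = sparse_categorical_crossentropy_from_logits([0.0] * 10, label=3)

# A confident, correct prediction for class 3.
confident_logits = [0.0] * 10
confident_logits[3] = 10.0
confident_loss = sparse_categorical_crossentropy_from_logits(confident_logits, label=3)

print(uniform_loss)
print(confident_loss)
```

A uniform guess over 10 classes gives a loss of ln(10), roughly 2.30, which is close to the 2.3001 reported for the first training batch. Confident, correct predictions drive the loss toward zero.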

To visualize the output, launch TensorBoard and view the logs:

%tensorboard --logdir=logs

ls -sh ./logs
total 4.0K
4.0K train

Export to SavedModel

Use Keras Model.save to export the graph and the variables to the platform-agnostic SavedModel format. After the model is saved, you can load it with or without Strategy.scope.

path = 'saved_model/'
model.save(path, save_format='tf')
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op while saving (showing 1 of 1). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: saved_model/assets

Now, load the model without Strategy.scope:

unreplicated_model = tf.keras.models.load_model(path)

unreplicated_model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.Adam(),
    metrics=['accuracy'])

eval_loss, eval_acc = unreplicated_model.evaluate(eval_dataset)

print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
40/40 [==============================] - 0s 3ms/step - loss: 0.0533 - accuracy: 0.9822
Eval loss: 0.053287554532289505, Eval Accuracy: 0.982200026512146

Load the model with Strategy.scope:

with strategy.scope():
  replicated_model = tf.keras.models.load_model(path)
  replicated_model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                           optimizer=tf.keras.optimizers.Adam(),
                           metrics=['accuracy'])

  eval_loss, eval_acc = replicated_model.evaluate(eval_dataset)
  print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
2022-08-31 00:10:07.553596: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:547] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
40/40 [==============================] - 4s 5ms/step - loss: 0.0533 - accuracy: 0.9822
Eval loss: 0.053287554532289505, Eval Accuracy: 0.982200026512146

Additional resources

More examples that use different distribution strategies with the Keras Model.fit API:

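  1. The Solve GLUE tasks using BERT on TPU tutorial uses tf.distribute.MirroredStrategy for training on GPUs and tf.distribute.TPUStrategy for training on TPUs.
  2. The Save and load a model using a distribution strategy tutorial demonstrates how to use the SavedModel APIs with tf.distribute.Strategy.
  3. The official TensorFlow models can be configured to run with multiple distribution strategies.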

To learn more about TensorFlow distribution strategies, refer to the following:

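  1. The Custom training with tf.distribute.Strategy tutorial shows how to use tf.distribute.MirroredStrategy for single-worker training with a custom training loop.
  2. The Multi-worker training with Keras tutorial shows how to use MultiWorkerMirroredStrategy with Model.fit.
  3. The Custom training loop with Keras and MultiWorkerMirroredStrategy tutorial shows how to use MultiWorkerMirroredStrategy with Keras and a custom training loop.
  4. The Distributed training in TensorFlow guide provides an overview of the available distribution strategies.
  5. The Better performance with tf.function guide provides information about other strategies and tools, such as the TensorFlow Profiler, that you can use to optimize the performance of your TensorFlow models.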

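Note: tf.distribute.Strategy is under active development, and TensorFlow will add more examples and tutorials in the near future. Please give it a try. Your feedback is welcome; you can submit it through issues on GitHub.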