使用 tf.distribute.Strategy 进行自定义训练

在 TensorFlow.org 上查看 在 Google Colab 上运行 在 GitHub 上查看源代码 下载该 notebook

本教程演示了如何使用 tf.distribute.Strategy 来进行自定义训练循环。 我们将在流行的 MNIST 数据集上训练一个简单的 CNN 模型。 流行的 MNIST 数据集包含了 60000 张尺寸为 28 x 28 的训练图像和 10000 张尺寸为 28 x 28 的测试图像。

我们用自定义训练循环来训练我们的模型是因为它们在训练的过程中为我们提供了灵活性和在训练过程中更好的控制。而且,使它们调试模型和训练循环的时候更容易。

# 导入 TensorFlow
import tensorflow as tf

# 帮助库
import numpy as np
import os

print(tf.__version__)
2.3.0

下载流行的 MNIST 数据集

fashion_mnist = tf.keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# 向数组添加维度 -> 新的维度 == (28, 28, 1)
# 我们这样做是因为我们模型中的第一层是卷积层
# 而且它需要一个四维的输入 (批大小, 高, 宽, 通道).
# 批大小维度稍后将添加。
train_images = train_images[..., None]
test_images = test_images[..., None]

# 获取[0,1]范围内的图像。
train_images = train_images / np.float32(255)
test_images = test_images / np.float32(255)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
32768/29515 [=================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
26427392/26421880 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
8192/5148 [===============================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz
4423680/4422102 [==============================] - 0s 0us/step

创建一个分发变量和图形的策略

tf.distribute.MirroredStrategy 策略是如何运作的?

  • 所有变量和模型图都复制在副本上。
  • 输入都均匀分布在副本中。
  • 每个副本在收到输入后计算输入的损失和梯度。
  • 通过求和,每一个副本上的梯度都能同步。
  • 同步后,每个副本上的复制的变量都可以同样更新。

注意:您可以将下面的所有代码放在一个单独单元内。 我们将它分成几个代码单元用于说明目的。

# 如果设备未在 `tf.distribute.MirroredStrategy` 的指定列表中,它会被自动检测到。
strategy = tf.distribute.MirroredStrategy()
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

print ('Number of devices: {}'.format(strategy.num_replicas_in_sync))
Number of devices: 1

设置输入流水线

将图形和变量导出成平台不可识别的 SavedModel 格式。在你的模型保存后,你可以在有或没有范围的情况下载入它。

BUFFER_SIZE = len(train_images)

BATCH_SIZE_PER_REPLICA = 64
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

EPOCHS = 10

创建数据集并分发它们:

train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).shuffle(BUFFER_SIZE).batch(GLOBAL_BATCH_SIZE) 
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE) 

train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
test_dist_dataset = strategy.experimental_distribute_dataset(test_dataset)

创建模型

使用 tf.keras.Sequential 创建一个模型。你也可以使用模型子类化 API 来完成这个。

def create_model():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu'),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Conv2D(64, 3, activation='relu'),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10, activation='softmax')
    ])

  return model
# 创建检查点目录以存储检查点。
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

定义损失函数

通常,在一台只有一个 GPU / CPU 的机器上,损失需要除去输入批量中的示例数。

那么,使用 tf.distribute.Strategy 时应该如何计算损失?

  • 举一个例子,假设您有4个 GPU,批量大小为64. 输入的一个批次分布在各个副本(4个 GPU)上,每个副本获得的输入大小为16。

  • 每个副本上的模型使用其各自的输入执行正向传递并计算损失。 现在,相较于将损耗除以其各自输入中的示例数(BATCH_SIZE_PER_REPLICA = 16),应将损失除以GLOBAL_BATCH_SIZE(64)。

为什么这样做?

  • 需要这样做是因为在每个副本上计算梯度之后,它们通过 summing 来使得在自身在各个副本之间同步。

如何在TensorFlow中执行此操作?

  • 如果您正在编写自定义训练循环,如本教程中所示,您应该将每个示例损失相加并将总和除以 GLOBAL_BATCH_SIZE : scale_loss = tf.reduce_sum(loss) * (1. / GLOBAL_BATCH_SIZE) 或者你可以使用tf.nn.compute_average_loss 来获取每个示例的损失,可选的样本权重,将GLOBAL_BATCH_SIZE作为参数,并返回缩放的损失。

  • 如果您在模型中使用正则化损失,则需要进行缩放多个副本的损失。 您可以使用tf.nn.scale_regularization_loss函数执行此操作。

  • 建议不要使用tf.reduce_mean。 这样做会将损失除以实际的每个副本中每一步都会改变的批次大小。

  • 这种缩小和缩放是在 keras中 modelcompilemodel.fit中自动完成的

  • 如果使用tf.keras.losses类(如下面这个例子所示),则需要将损失减少明确指定为“NONE”或者“SUM”。 使用 tf.distribute.Strategy 时,AUTOSUM_OVER_BATCH_SIZE 是不能使用的。 不能使用 AUTO 是因为用户应明确考虑到在分布式情况下他们想做的哪些减少是正确的。不能使用SUM_OVER_BATCH_SIZE是因为目前它只按每个副本批次大小进行划分,并按照用户的副本数进行划分,这导致了它们很容易丢失。 因此,我们要求用户要明确这些减少。

with strategy.scope():
  # 将减少设置为“无”,以便我们可以在之后进行这个减少并除以全局批量大小。
  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
      reduction=tf.keras.losses.Reduction.NONE)
  # 或者使用 loss_fn = tf.keras.losses.sparse_categorical_crossentropy
  def compute_loss(labels, predictions):
    per_example_loss = loss_object(labels, predictions)
    return tf.nn.compute_average_loss(per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)

定义衡量指标以跟踪损失和准确性

这些指标可以跟踪测试的损失,训练和测试的准确性。 您可以使用.result()随时获取累积的统计信息。

with strategy.scope():
  test_loss = tf.keras.metrics.Mean(name='test_loss')

  train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='train_accuracy')
  test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='test_accuracy')

训练循环

# 必须在`strategy.scope`下创建模型和优化器。
with strategy.scope():
  model = create_model()

  optimizer = tf.keras.optimizers.Adam()

  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
with strategy.scope():
  def train_step(inputs):
    images, labels = inputs

    with tf.GradientTape() as tape:
      predictions = model(images, training=True)
      loss = compute_loss(labels, predictions)

    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    train_accuracy.update_state(labels, predictions)
    return loss 

  def test_step(inputs):
    images, labels = inputs

    predictions = model(images, training=False)
    t_loss = loss_object(labels, predictions)

    test_loss.update_state(t_loss)
    test_accuracy.update_state(labels, predictions)
with strategy.scope():
  # `experimental_run_v2`将复制提供的计算并使用分布式输入运行它。
  @tf.function
  def distributed_train_step(dataset_inputs):
    per_replica_losses = strategy.experimental_run_v2(train_step,
                                                      args=(dataset_inputs,))
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
                           axis=None)
 
  @tf.function
  def distributed_test_step(dataset_inputs):
    return strategy.experimental_run_v2(test_step, args=(dataset_inputs,))

  for epoch in range(EPOCHS):
    # 训练循环
    total_loss = 0.0
    num_batches = 0
    for x in train_dist_dataset:
      total_loss += distributed_train_step(x)
      num_batches += 1
    train_loss = total_loss / num_batches

    # 测试循环
    for x in test_dist_dataset:
      distributed_test_step(x)

    if epoch % 2 == 0:
      checkpoint.save(checkpoint_prefix)

    template = ("Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, "
                "Test Accuracy: {}")
    print (template.format(epoch+1, train_loss,
                           train_accuracy.result()*100, test_loss.result(),
                           test_accuracy.result()*100))

    test_loss.reset_states()
    train_accuracy.reset_states()
    test_accuracy.reset_states()
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:601: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Iterator.get_next_as_optional()` instead.
WARNING:tensorflow:From <ipython-input-1-6439d0e9d271>:5: StrategyBase.experimental_run_v2 (from tensorflow.python.distribute.distribute_lib) is deprecated and will be removed in a future version.
Instructions for updating:
renamed to `run`
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
Epoch 1, Loss: 0.5272247791290283, Accuracy: 80.95500183105469, Test Loss: 0.39799919724464417, Test Accuracy: 86.08000183105469
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
Epoch 2, Loss: 0.3536641597747803, Accuracy: 87.19000244140625, Test Loss: 0.3652512729167938, Test Accuracy: 86.79999542236328
Epoch 3, Loss: 0.30651605129241943, Accuracy: 88.96333312988281, Test Loss: 0.35199666023254395, Test Accuracy: 86.76000213623047
Epoch 4, Loss: 0.2756423354148865, Accuracy: 89.99333190917969, Test Loss: 0.2974560558795929, Test Accuracy: 89.1500015258789
Epoch 5, Loss: 0.24928639829158783, Accuracy: 90.86833953857422, Test Loss: 0.28945034742355347, Test Accuracy: 89.31999969482422
Epoch 6, Loss: 0.22822219133377075, Accuracy: 91.66999816894531, Test Loss: 0.2690503001213074, Test Accuracy: 90.13999938964844
Epoch 7, Loss: 0.21215270459651947, Accuracy: 92.19833374023438, Test Loss: 0.2673594057559967, Test Accuracy: 90.37000274658203
Epoch 8, Loss: 0.19466665387153625, Accuracy: 92.86500549316406, Test Loss: 0.280720591545105, Test Accuracy: 90.36000061035156
Epoch 9, Loss: 0.1819683462381363, Accuracy: 93.4000015258789, Test Loss: 0.2655133008956909, Test Accuracy: 90.54000091552734
Epoch 10, Loss: 0.16936612129211426, Accuracy: 93.711669921875, Test Loss: 0.26561689376831055, Test Accuracy: 90.55999755859375

以上示例中需要注意的事项:

  • 我们使用for x in ...迭代构造train_dist_datasettest_dist_dataset
  • 缩放损失是distributed_train_step的返回值。 这个值会在各个副本使用tf.distribute.Strategy.reduce的时候合并,然后通过tf.distribute.Strategy.reduce叠加各个返回值来跨批次。
  • 在执行tf.distribute.Strategy.experimental_run_v2时,tf.keras.Metrics应在train_steptest_step中更新。
  • tf.distribute.Strategy.experimental_run_v2返回策略中每个本地副本的结果,并且有多种方法可以处理此结果。 您可以执行tf.distribute.Strategy.reduce来获取汇总值。 您还可以执行tf.distribute.Strategy.experimental_local_results来获取每个本地副本中结果中包含的值列表。

恢复最新的检查点并进行测试

一个模型使用了tf.distribute.Strategy的检查点可以使用策略或者不使用策略进行恢复。

eval_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='eval_accuracy')

new_model = create_model()
new_optimizer = tf.keras.optimizers.Adam()

test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE)
@tf.function
def eval_step(images, labels):
  predictions = new_model(images, training=False)
  eval_accuracy(labels, predictions)
checkpoint = tf.train.Checkpoint(optimizer=new_optimizer, model=new_model)
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

for images, labels in test_dataset:
  eval_step(images, labels)

print ('Accuracy after restoring the saved model without strategy: {}'.format(
    eval_accuracy.result()*100))
Accuracy after restoring the saved model without strategy: 90.54000091552734

迭代一个数据集的替代方法

使用迭代器

如果你想要迭代一个已经给定步骤数量而不需要整个遍历的数据集,你可以创建一个迭代器并在迭代器上调用iter和显式调用next。 您可以选择在 tf.function 内部和外部迭代数据集。 这是一个小片段,演示了使用迭代器在 tf.function 外部迭代数据集。

with strategy.scope():
  for _ in range(EPOCHS):
    total_loss = 0.0
    num_batches = 0
    train_iter = iter(train_dist_dataset)

    for _ in range(10):
      total_loss += distributed_train_step(next(train_iter))
      num_batches += 1
    average_train_loss = total_loss / num_batches

    template = ("Epoch {}, Loss: {}, Accuracy: {}")
    print (template.format(epoch+1, average_train_loss, train_accuracy.result()*100))
    train_accuracy.reset_states()
Epoch 10, Loss: 0.17099234461784363, Accuracy: 93.75
Epoch 10, Loss: 0.12641692161560059, Accuracy: 95.9375
Epoch 10, Loss: 0.11636483669281006, Accuracy: 96.09375
Epoch 10, Loss: 0.1404765546321869, Accuracy: 95.0
Epoch 10, Loss: 0.16838286817073822, Accuracy: 92.5
Epoch 10, Loss: 0.1905607134103775, Accuracy: 93.125
Epoch 10, Loss: 0.12706035375595093, Accuracy: 95.78125
Epoch 10, Loss: 0.14852401614189148, Accuracy: 93.59375
Epoch 10, Loss: 0.11990274488925934, Accuracy: 95.9375
Epoch 10, Loss: 0.1237613782286644, Accuracy: 95.9375

在 tf.function 中迭代

您还可以使用for x in ...构造在 tf.function 内部迭代整个输入train_dist_dataset,或者像上面那样创建迭代器。下面的例子演示了在 tf.function 中包装一个 epoch 并在功能内迭代train_dist_dataset

with strategy.scope():
  @tf.function
  def distributed_train_epoch(dataset):
    total_loss = 0.0
    num_batches = 0
    for x in dataset:
      per_replica_losses = strategy.experimental_run_v2(train_step,
                                                        args=(x,))
      total_loss += strategy.reduce(
        tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
      num_batches += 1
    return total_loss / tf.cast(num_batches, dtype=tf.float32)

  for epoch in range(EPOCHS):
    train_loss = distributed_train_epoch(train_dist_dataset)

    template = ("Epoch {}, Loss: {}, Accuracy: {}")
    print (template.format(epoch+1, train_loss, train_accuracy.result()*100))

    train_accuracy.reset_states()
Epoch 1, Loss: 0.1545342057943344, Accuracy: 94.34666442871094
Epoch 2, Loss: 0.14368833601474762, Accuracy: 94.76666259765625
Epoch 3, Loss: 0.13302761316299438, Accuracy: 95.22833251953125
Epoch 4, Loss: 0.12302733212709427, Accuracy: 95.51499938964844
Epoch 5, Loss: 0.11504675447940826, Accuracy: 95.7300033569336
Epoch 6, Loss: 0.10611504316329956, Accuracy: 96.02000427246094
Epoch 7, Loss: 0.09776321798563004, Accuracy: 96.3566665649414
Epoch 8, Loss: 0.0923474133014679, Accuracy: 96.54166412353516
Epoch 9, Loss: 0.08583918958902359, Accuracy: 96.85833740234375
Epoch 10, Loss: 0.0784970372915268, Accuracy: 97.12332916259766

跟踪副本中的训练的损失

注意:作为通用的规则,您应该使用tf.keras.Metrics来跟踪每个样本的值以避免它们在副本中合并。

我们 建议使用tf.metrics.Mean 来跟踪不同副本的训练损失,因为在执行过程中会进行损失缩放计算。

例如,如果您运行具有以下特点的训练作业:

  • 两个副本
  • 在每个副本上处理两个例子
  • 产生的损失值:每个副本为[2,3]和[4,5]
  • 全局批次大小 = 4

通过损失缩放,您可以通过添加损失值来计算每个副本上的每个样本的损失值,然后除以全局批量大小。 在这种情况下:(2 + 3)/ 4 = 1.25(4 + 5)/ 4 = 2.25

如果您使用 tf.metrics.Mean 来跟踪两个副本的损失,结果会有所不同。 在这个例子中,你最终得到一个total为 3.50 和count为 2 的结果,当调用result()时,你将得到total /count = 1.75。 使用tf.keras.Metrics计算损失时会通过一个等于同步副本数量的额外因子来缩放。

例子和教程

以下是一些使用自定义训练循环来分发策略的示例:

  1. 教程 使用 MirroredStrategy 来训练 MNIST 。
  2. DenseNet 使用 MirroredStrategy的例子。
  3. BERT 使用 MirroredStrategyTPUStrategy来训练的例子。 此示例对于了解如何在分发训练过程中如何载入一个检测点和定期生成检查点特别有帮助。
  4. NCF 使用 MirroredStrategy 来启用 keras_use_ctl 标记。
  5. NMT 使用 MirroredStrategy来训练的例子。

更多的例子列在 分发策略指南

下一步

在你的模型上尝试新的tf.distribute.Strategy API。