Distributed training with Keras

Overview

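The tf.distribute.Strategy API provides an abstraction for distributing training across multiple processing units. Its goal is to let users enable distributed training with their existing models and training code, with minimal changes.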

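This tutorial uses tf.distribute.MirroredStrategy, which performs in-graph replication with synchronous training on multiple GPUs in a single machine. Essentially, it copies all of the model's variables to each processor, then uses all-reduce to combine the gradients from all processors and applies the combined value to every copy of the model.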
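To make the synchronous, all-reduce style of training concrete, here is a minimal pure-Python simulation for a one-parameter linear model. It is a toy model of the idea, not how MirroredStrategy is implemented; the helper names `local_gradient`, `all_reduce_mean` and `train_step`, and the choice of averaging as the reduction, are assumptions made for illustration. Each replica computes a gradient on its own shard of the batch, the gradients are averaged, and every replica applies the same update, so the mirrored copies of the variable never diverge.

```python
from typing import List, Tuple

# Toy model y = w * x with squared-error loss; each replica holds its own copy of w.
Example = Tuple[float, float]


def local_gradient(w: float, shard: List[Example]) -> float:
    """Mean gradient of 0.5 * (w*x - y)**2 over one replica's shard."""
    return sum((w * x - y) * x for x, y in shard) / len(shard)


def all_reduce_mean(values: List[float]) -> List[float]:
    """Hypothetical all-reduce: every replica receives the mean of all inputs."""
    mean = sum(values) / len(values)
    return [mean] * len(values)


def train_step(replica_weights: List[float], global_batch: List[Example],
               lr: float) -> List[float]:
    n = len(replica_weights)
    # Split the global batch across replicas (data parallelism).
    # Assumes len(global_batch) is divisible by n so every shard has equal size.
    shards = [global_batch[i::n] for i in range(n)]
    grads = [local_gradient(w, s) for w, s in zip(replica_weights, shards)]
    reduced = all_reduce_mean(grads)
    # Every replica applies the same reduced gradient, so the copies stay identical.
    return [w - lr * g for w, g in zip(replica_weights, reduced)]


batch = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]  # data generated by y = 2x
weights = [0.0, 0.0]  # two "replicas" starting from the same initial value
for _ in range(50):
    weights = train_step(weights, batch, lr=0.05)
print(weights)
```

With equal-sized shards, averaging the per-replica mean gradients gives the same gradient as one device processing the whole batch, which is why the result does not depend on the number of replicas.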

MirroredStrategy is one of several distribution strategies available in TensorFlow. You can read about more strategies in the distribution strategy guide.

Keras API

This example uses the tf.keras API to build and train the model. For training models with custom training loops, see the tf.distribute.Strategy with training loops tutorial.

Import dependencies

from __future__ import absolute_import, division, print_function, unicode_literals

# Import TensorFlow and TensorFlow Datasets
import tensorflow_datasets as tfds
import tensorflow as tf
tfds.disable_progress_bar()

import os

Download the dataset

Download the MNIST dataset and load it from TensorFlow Datasets. This returns a dataset in tf.data format.

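Setting with_info to True includes the metadata for the entire dataset, which is saved in info. Among other things, this metadata object includes the number of training and test examples.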

datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)

mnist_train, mnist_test = datasets['train'], datasets['test']
Downloading and preparing dataset mnist (11.06 MiB) to /home/kbuilder/tensorflow_datasets/mnist/1.0.0...

/usr/lib/python3/dist-packages/urllib3/connectionpool.py:860: InsecureRequestWarning: Unverified HTTPS request is being made. Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/latest/advanced-usage.html#ssl-warnings
  InsecureRequestWarning)

WARNING:tensorflow:From /home/kbuilder/.local/lib/python3.6/site-packages/tensorflow_datasets/core/file_format_adapter.py:209: tf_record_iterator (from tensorflow.python.lib.io.tf_record) is deprecated and will be removed in a future version.
Instructions for updating:
Use eager execution and: 
`tf.data.TFRecordDataset(path)`

Dataset mnist downloaded and prepared to /home/kbuilder/tensorflow_datasets/mnist/1.0.0. Subsequent calls will reuse this data.

Define the distribution strategy

Create a MirroredStrategy object. It handles the distribution and provides a context manager (tf.distribute.MirroredStrategy.scope) in which to build your model.

strategy = tf.distribute.MirroredStrategy()
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
Number of devices: 1

Set up the input pipeline

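When training a model on multiple GPUs, you can make effective use of the extra computing power by increasing the batch size. In general, use the largest batch size that fits in GPU memory, and tune the learning rate accordingly.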

# You can also use info.splits.total_num_examples to get the total
# number of examples in the dataset.

num_train_examples = info.splits['train'].num_examples
num_test_examples = info.splits['test'].num_examples

BUFFER_SIZE = 10000

BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
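The arithmetic behind these constants can be checked without TensorFlow. This sketch assumes MNIST's standard split of 60,000 training and 10,000 test examples (the values `info.splits` reports) and that the last partial batch is kept, which is the default for `Dataset.batch` (`drop_remainder=False`). With one replica it gives 938 training steps and 157 evaluation steps, matching the step counts in the logs of this tutorial. The linear learning-rate scaling is only a common heuristic and an assumption here; `MirroredStrategy` does not rescale the learning rate for you.

```python
import math

NUM_TRAIN_EXAMPLES = 60_000  # MNIST training split
NUM_TEST_EXAMPLES = 10_000   # MNIST test split
BATCH_SIZE_PER_REPLICA = 64
BASE_LEARNING_RATE = 1e-3    # single-replica baseline (Adam's default learning rate)


def global_batch_size(num_replicas: int) -> int:
    # Each replica processes BATCH_SIZE_PER_REPLICA examples per step.
    return BATCH_SIZE_PER_REPLICA * num_replicas


def steps_per_epoch(num_examples: int, batch_size: int) -> int:
    # batch() keeps the final partial batch unless drop_remainder=True.
    return math.ceil(num_examples / batch_size)


def scaled_learning_rate(num_replicas: int) -> float:
    # Assumed heuristic: scale the learning rate linearly with the global batch size.
    return BASE_LEARNING_RATE * num_replicas


for replicas in (1, 2, 4, 8):
    bs = global_batch_size(replicas)
    print(f"{replicas} replica(s): global batch {bs}, "
          f"{steps_per_epoch(NUM_TRAIN_EXAMPLES, bs)} train steps/epoch, "
          f"lr {scaled_learning_rate(replicas):g}")
```

Doubling the number of replicas halves the number of steps per epoch, which is where the speed-up of synchronous data parallelism comes from.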

Pixel values in the range 0-255 must be normalized to the 0-1 range. Define this scaling in a function.

def scale(image, label):
  image = tf.cast(image, tf.float32)
  image /= 255

  return image, label

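Apply this function to the training and test data, shuffle the training data, and batch it. Note that we also keep an in-memory cache of the training data to improve performance.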

train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)
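Because `BUFFER_SIZE` (10,000) is smaller than the 60,000 training examples, this shuffle is approximate. According to the `tf.data.Dataset.shuffle` documentation, the dataset fills a buffer with `buffer_size` elements, randomly samples from it, and replaces each selected element with the next input element; a perfect shuffle needs a buffer at least as large as the dataset. The pure-Python sketch below imitates that behaviour (`buffered_shuffle` is a hypothetical helper, not part of tf.data) and shows one consequence: an element can be emitted at most `buffer_size - 1` positions earlier than its original position.

```python
import random
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def buffered_shuffle(items: Iterable[T], buffer_size: int, seed: int = 0) -> Iterator[T]:
    """Toy model of buffer-based shuffling (not the tf.data implementation)."""
    rng = random.Random(seed)
    buffer: List[T] = []
    for item in items:
        buffer.append(item)
        if len(buffer) == buffer_size:
            # Emit a random element from the full buffer; its slot is
            # refilled by the next input element.
            idx = rng.randrange(len(buffer))
            buffer[idx], buffer[-1] = buffer[-1], buffer[idx]
            yield buffer.pop()
    # Input exhausted: drain whatever is left in random order.
    rng.shuffle(buffer)
    yield from buffer


data = list(range(20))
print(list(buffered_shuffle(data, buffer_size=5)))          # local mixing only
print(list(buffered_shuffle(data, buffer_size=len(data))))  # full shuffle
```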

Create the model

Create and compile the Keras model in the context of strategy.scope.

with strategy.scope():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10, activation='softmax')
  ])

  model.compile(loss='sparse_categorical_crossentropy',
                optimizer=tf.keras.optimizers.Adam(),
                metrics=['accuracy'])
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
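Since MirroredStrategy keeps a full copy of every variable on each replica, it helps to know how large this model is. The worked arithmetic below counts trainable parameters layer by layer for the model above, assuming the Keras defaults of `'valid'` padding for `Conv2D` and a 2x2 window for `MaxPooling2D`; `model.summary()` should report the same totals. The byte estimate is a rough assumption: float32 weights only, ignoring optimizer slots (Adam keeps two extra slot variables per weight).

```python
def conv2d_params(kernel: int, in_channels: int, filters: int) -> int:
    # kernel * kernel * in_channels weights per filter, plus one bias per filter
    return kernel * kernel * in_channels * filters + filters


def dense_params(in_features: int, units: int) -> int:
    # Full weight matrix plus one bias per unit
    return in_features * units + units


# Shapes for a 28x28x1 input image
conv_out = 28 - 3 + 1            # 'valid' 3x3 convolution -> 26x26x32
pool_out = conv_out // 2         # 2x2 max pooling -> 13x13x32
flat = pool_out * pool_out * 32  # Flatten -> 5408 features

layers = {
    "conv2d": conv2d_params(3, 1, 32),
    "dense_64": dense_params(flat, 64),
    "dense_10": dense_params(64, 10),
}
total = sum(layers.values())
bytes_per_replica = total * 4  # float32 weights only (assumption)

for name, count in layers.items():
    print(f"{name:>8}: {count:,}")
print(f"   total: {total:,} parameters, ~{bytes_per_replica / 1e6:.2f} MB per replica")
```

Almost all of the parameters sit in the first `Dense` layer, so that layer dominates both the per-replica memory and the size of the gradients that all-reduce has to combine each step.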

Define the callbacks

The callbacks used here are:

  • TensorBoard: writes a log for TensorBoard, which lets you visualize the graphs.
  • Model Checkpoint: saves the model after every epoch.
  • Learning Rate Scheduler: lets you schedule the learning rate to change at the start of every epoch.

For illustration purposes, add a print callback to display the learning rate in the notebook.

# Define the checkpoint directory to store the checkpoints

checkpoint_dir = './training_checkpoints'
# Name of the checkpoint files
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")

# Function for decaying the learning rate.
# You can define any decay function you need.
def decay(epoch):
  if epoch < 3:
    return 1e-3
  elif epoch < 7:
    return 1e-4
  else:
    return 1e-5

# Callback for printing the LR at the end of each epoch.
class PrintLR(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    # self.model is set by Keras to the model being trained.
    print('\nLearning rate for epoch {} is {}'.format(epoch + 1,
                                                      self.model.optimizer.lr.numpy()))

callbacks = [
    tf.keras.callbacks.TensorBoard(log_dir='./logs'),
    tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,
                                       save_weights_only=True),
    tf.keras.callbacks.LearningRateScheduler(decay),
    PrintLR()
]
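Keras passes a zero-based epoch index to the schedule function, while `PrintLR` prints `epoch + 1`. That is why the training log reports 1e-4 from "epoch 4" onward even though `decay` switches at `epoch < 3`. The pure-Python sketch below reproduces the schedule for the 12 epochs of the `fit` call and rounds each rate to float32 with `struct`; the assumption is that the optimizer stores its learning rate as a float32 variable, which would explain printed values such as 0.0010000000474974513 instead of 0.001.

```python
import struct


def decay(epoch: int) -> float:
    # Same schedule as the tutorial; `epoch` is the zero-based index Keras passes in.
    if epoch < 3:
        return 1e-3
    elif epoch < 7:
        return 1e-4
    else:
        return 1e-5


def as_float32(x: float) -> float:
    # Round-trip through a 4-byte IEEE 754 float to mimic float32 storage (assumption).
    return struct.unpack("f", struct.pack("f", x))[0]


EPOCHS = 12
# Keyed by the 1-based epoch number, as PrintLR displays it.
schedule = {epoch + 1: decay(epoch) for epoch in range(EPOCHS)}

for printed_epoch, lr in schedule.items():
    print(f"Learning rate for epoch {printed_epoch} is {as_float32(lr)}")
```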

Train and evaluate

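Now train the model in the usual way: call fit on the model and pass in the dataset created at the beginning of the tutorial. This step is the same whether or not you are distributing the training.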

model.fit(train_dataset, epochs=12, callbacks=callbacks)
Epoch 1/12
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

    938/Unknown - 10s 11ms/step - loss: 0.2104 - accuracy: 0.9394
Learning rate for epoch 1 is 0.0010000000474974513
938/938 [==============================] - 10s 11ms/step - loss: 0.2104 - accuracy: 0.9394
Epoch 2/12
923/938 [============================>.] - ETA: 0s - loss: 0.0715 - accuracy: 0.9786
Learning rate for epoch 2 is 0.0010000000474974513
938/938 [==============================] - 2s 2ms/step - loss: 0.0714 - accuracy: 0.9786
Epoch 3/12
917/938 [============================>.] - ETA: 0s - loss: 0.0501 - accuracy: 0.9847
Learning rate for epoch 3 is 0.0010000000474974513
938/938 [==============================] - 2s 3ms/step - loss: 0.0502 - accuracy: 0.9847
Epoch 4/12
921/938 [============================>.] - ETA: 0s - loss: 0.0276 - accuracy: 0.9925
Learning rate for epoch 4 is 9.999999747378752e-05
938/938 [==============================] - 2s 2ms/step - loss: 0.0276 - accuracy: 0.9925
Epoch 5/12
924/938 [============================>.] - ETA: 0s - loss: 0.0243 - accuracy: 0.9934
Learning rate for epoch 5 is 9.999999747378752e-05
938/938 [==============================] - 2s 2ms/step - loss: 0.0245 - accuracy: 0.9934
Epoch 6/12
930/938 [============================>.] - ETA: 0s - loss: 0.0226 - accuracy: 0.9941
Learning rate for epoch 6 is 9.999999747378752e-05
938/938 [==============================] - 2s 2ms/step - loss: 0.0226 - accuracy: 0.9941
Epoch 7/12
925/938 [============================>.] - ETA: 0s - loss: 0.0209 - accuracy: 0.9947
Learning rate for epoch 7 is 9.999999747378752e-05
938/938 [==============================] - 2s 2ms/step - loss: 0.0209 - accuracy: 0.9947
Epoch 8/12
926/938 [============================>.] - ETA: 0s - loss: 0.0183 - accuracy: 0.9958
Learning rate for epoch 8 is 9.999999747378752e-06
938/938 [==============================] - 2s 2ms/step - loss: 0.0183 - accuracy: 0.9958
Epoch 9/12
923/938 [============================>.] - ETA: 0s - loss: 0.0181 - accuracy: 0.9959
Learning rate for epoch 9 is 9.999999747378752e-06
938/938 [==============================] - 2s 2ms/step - loss: 0.0181 - accuracy: 0.9960
Epoch 10/12
930/938 [============================>.] - ETA: 0s - loss: 0.0179 - accuracy: 0.9960
Learning rate for epoch 10 is 9.999999747378752e-06
938/938 [==============================] - 2s 3ms/step - loss: 0.0179 - accuracy: 0.9960
Epoch 11/12
925/938 [============================>.] - ETA: 0s - loss: 0.0177 - accuracy: 0.9960
Learning rate for epoch 11 is 9.999999747378752e-06
938/938 [==============================] - 2s 2ms/step - loss: 0.0177 - accuracy: 0.9961
Epoch 12/12
918/938 [============================>.] - ETA: 0s - loss: 0.0174 - accuracy: 0.9960
Learning rate for epoch 12 is 9.999999747378752e-06
938/938 [==============================] - 2s 2ms/step - loss: 0.0175 - accuracy: 0.9960

<tensorflow.python.keras.callbacks.History at 0x7f9ba25bdc50>

As shown below, a checkpoint is saved after every epoch.

# Check the checkpoint directory
!ls {checkpoint_dir}
checkpoint           ckpt_4.data-00000-of-00002
ckpt_1.data-00000-of-00002   ckpt_4.data-00001-of-00002
ckpt_1.data-00001-of-00002   ckpt_4.index
ckpt_1.index             ckpt_5.data-00000-of-00002
ckpt_10.data-00000-of-00002  ckpt_5.data-00001-of-00002
ckpt_10.data-00001-of-00002  ckpt_5.index
ckpt_10.index            ckpt_6.data-00000-of-00002
ckpt_11.data-00000-of-00002  ckpt_6.data-00001-of-00002
ckpt_11.data-00001-of-00002  ckpt_6.index
ckpt_11.index            ckpt_7.data-00000-of-00002
ckpt_12.data-00000-of-00002  ckpt_7.data-00001-of-00002
ckpt_12.data-00001-of-00002  ckpt_7.index
ckpt_12.index            ckpt_8.data-00000-of-00002
ckpt_2.data-00000-of-00002   ckpt_8.data-00001-of-00002
ckpt_2.data-00001-of-00002   ckpt_8.index
ckpt_2.index             ckpt_9.data-00000-of-00002
ckpt_3.data-00000-of-00002   ckpt_9.data-00001-of-00002
ckpt_3.data-00001-of-00002   ckpt_9.index
ckpt_3.index
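Note that `ls` sorts names lexicographically, so `ckpt_10` to `ckpt_12` appear before `ckpt_2`. `tf.train.latest_checkpoint` does not rely on this ordering; it reads the `checkpoint` state file in the directory. If you ever pick a checkpoint by filename yourself, sort on the epoch number instead. The sketch below uses a hypothetical `epoch_of` helper and the `ckpt_{epoch}` naming from `checkpoint_prefix` to show the difference on the index files in this listing.

```python
import re
from typing import List

# Index files from the listing above, in the order `ls` printed them.
index_files: List[str] = [f"ckpt_{n}.index" for n in (1, 10, 11, 12, 2, 3, 4, 5, 6, 7, 8, 9)]


def epoch_of(filename: str) -> int:
    """Hypothetical helper: extract the epoch number from 'ckpt_{epoch}.index'."""
    match = re.fullmatch(r"ckpt_(\d+)\.index", filename)
    if match is None:
        raise ValueError(f"unexpected checkpoint name: {filename}")
    return int(match.group(1))


lexicographic_latest = max(index_files)          # plain string comparison
numeric_latest = max(index_files, key=epoch_of)  # compares epoch numbers

print("lexicographic:", lexicographic_latest)
print("numeric:", numeric_latest)
```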

To see how the model performs, load the latest checkpoint and call evaluate on the test dataset.

model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))

eval_loss, eval_acc = model.evaluate(eval_dataset)

print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

157/157 [==============================] - 2s 15ms/step - loss: 0.0390 - accuracy: 0.9865
Eval loss: 0.039015309934083156, Eval Accuracy: 0.9865000247955322

To see the output, you can download the TensorBoard logs and view them from a terminal.

$ tensorboard --logdir=path/to/log-directory
!ls -sh ./logs
total 4.0K
4.0K train

Export to SavedModel

Export the graph and the variables to the platform-agnostic SavedModel format. Once the model is saved, you can load it with or without the scope.

path = 'saved_model/'
tf.keras.experimental.export_saved_model(model, path)
WARNING:tensorflow:From <ipython-input-19-7f22af6799f5>:1: export_saved_model (from tensorflow.python.keras.saving.saved_model_experimental) is deprecated and will be removed in a future version.
Instructions for updating:
Please use `model.save(..., save_format="tf")` or `tf.keras.models.save_model(..., save_format="tf")`.

WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1630: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.

WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/saved_model/signature_def_utils_impl.py:253: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.

INFO:tensorflow:Signatures INCLUDED in export for Classify: None

INFO:tensorflow:Signatures INCLUDED in export for Regress: None

INFO:tensorflow:Signatures INCLUDED in export for Predict: None

INFO:tensorflow:Signatures INCLUDED in export for Train: ['train']

INFO:tensorflow:Signatures INCLUDED in export for Eval: None

WARNING:tensorflow:Export includes no default signature!

INFO:tensorflow:No assets to save.

INFO:tensorflow:No assets to write.

INFO:tensorflow:Signatures INCLUDED in export for Classify: None

INFO:tensorflow:Signatures INCLUDED in export for Regress: None

INFO:tensorflow:Signatures INCLUDED in export for Predict: None

INFO:tensorflow:Signatures INCLUDED in export for Train: None

INFO:tensorflow:Signatures INCLUDED in export for Eval: ['eval']

WARNING:tensorflow:Export includes no default signature!

INFO:tensorflow:No assets to save.

INFO:tensorflow:No assets to write.

INFO:tensorflow:Signatures INCLUDED in export for Classify: None

INFO:tensorflow:Signatures INCLUDED in export for Regress: None

INFO:tensorflow:Signatures INCLUDED in export for Predict: ['serving_default']

INFO:tensorflow:Signatures INCLUDED in export for Train: None

INFO:tensorflow:Signatures INCLUDED in export for Eval: None

INFO:tensorflow:No assets to save.

INFO:tensorflow:No assets to write.

INFO:tensorflow:SavedModel written to: saved_model/saved_model.pb

Load the model without strategy.scope.

unreplicated_model = tf.keras.experimental.load_from_saved_model(path)

unreplicated_model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer=tf.keras.optimizers.Adam(),
    metrics=['accuracy'])

eval_loss, eval_acc = unreplicated_model.evaluate(eval_dataset)

print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
WARNING:tensorflow:From <ipython-input-20-b1ff61ed84c5>:1: load_from_saved_model (from tensorflow.python.keras.saving.saved_model_experimental) is deprecated and will be removed in a future version.
Instructions for updating:
The experimental save and load functions have been  deprecated. Please switch to `tf.keras.models.load_model`.

157/157 [==============================] - 1s 8ms/step - loss: 0.0390 - accuracy: 0.9865
Eval loss: 0.039015309934083156, Eval Accuracy: 0.9865000247955322

Load the model with strategy.scope.

with strategy.scope():
  replicated_model = tf.keras.experimental.load_from_saved_model(path)
  replicated_model.compile(loss='sparse_categorical_crossentropy',
                           optimizer=tf.keras.optimizers.Adam(),
                           metrics=['accuracy'])

  eval_loss, eval_acc = replicated_model.evaluate(eval_dataset)
  print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

157/157 [==============================] - 2s 15ms/step - loss: 0.0390 - accuracy: 0.9865
Eval loss: 0.039015309934083156, Eval Accuracy: 0.9865000247955322

Examples and tutorials

Here are some examples of using distribution strategies with Keras fit/compile:

  1. A Transformer example trained with tf.distribute.MirroredStrategy.
  2. An NCF example trained with tf.distribute.MirroredStrategy.

More examples are listed in the distribution strategy guide.

Next steps

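Note: tf.distribute.Strategy is under active development, and more examples and tutorials will be added in the near future. Please give it a try. We welcome your feedback via issues on GitHub.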