Multi-worker training with Estimator


Warning: Estimators are not recommended for new code. Estimators run v1.Session-style code, which is more difficult to write correctly and can behave unexpectedly, especially when combined with TF 2 code. Estimators do fall under our compatibility guarantees, but they will receive no fixes other than security vulnerabilities. See the migration guide for details.

Overview

Note: While you can use Estimators with the tf.distribute API, it is recommended to use Keras with tf.distribute; see Multi-worker training with Keras. Estimator training with tf.distribute.Strategy has limited support.

This tutorial demonstrates how tf.distribute.Strategy can be used for distributed multi-worker training with tf.estimator. If you write your code using tf.estimator and you are interested in scaling beyond a single machine with high performance, this tutorial is for you.

Before getting started, please read the distribution strategy guide. The multi-GPU training tutorial is also relevant, because this tutorial uses the same model.

Setup

First, set up TensorFlow and the necessary imports.

import tensorflow_datasets as tfds
import tensorflow as tf

import os, json
2023-11-07 23:13:06.920204: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-11-07 23:13:06.920259: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-11-07 23:13:06.921969: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered

Note: Starting from TF2.4, multi-worker mirrored strategy fails with Estimator if run with eager execution enabled (the default). The error in TF2.4 is TypeError: cannot pickle '_thread.lock' object. See issue #46556 for details. The workaround is to disable eager execution.

tf.compat.v1.disable_eager_execution()

Input function

This tutorial uses the MNIST dataset from TensorFlow Datasets. The code here is similar to the multi-GPU training tutorial, with one key difference: when using Estimator for multi-worker training, it is necessary to shard the dataset by the number of workers to ensure model convergence. The input data is sharded by worker index, so that each worker processes a distinct 1/num_workers portion of the dataset.

BUFFER_SIZE = 10000
BATCH_SIZE = 64

def input_fn(mode, input_context=None):
  datasets, info = tfds.load(name='mnist',
                                with_info=True,
                                as_supervised=True)
  mnist_dataset = (datasets['train'] if mode == tf.estimator.ModeKeys.TRAIN else
                   datasets['test'])

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255
    return image, label

  if input_context:
    mnist_dataset = mnist_dataset.shard(input_context.num_input_pipelines,
                                        input_context.input_pipeline_id)
  return mnist_dataset.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

Another reasonable approach to achieve convergence would be to shuffle the dataset with distinct seeds on each worker, as sketched below.
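
For example, here is a hedged sketch of that alternative (not used in the rest of this tutorial); it follows the same structure as input_fn above and uses input_context.input_pipeline_id as a per-worker seed. The function name is only illustrative.

# Alternative sketch: instead of sharding, give each worker a different
# shuffle order by deriving a distinct seed from its input pipeline index.
def input_fn_with_per_worker_seed(mode, input_context=None):
  datasets, info = tfds.load(name='mnist',
                             with_info=True,
                             as_supervised=True)
  mnist_dataset = (datasets['train'] if mode == tf.estimator.ModeKeys.TRAIN else
                   datasets['test'])

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255
    return image, label

  # Each worker shuffles with its own seed instead of reading its own shard.
  seed = input_context.input_pipeline_id if input_context else 0
  return (mnist_dataset.map(scale).cache()
          .shuffle(BUFFER_SIZE, seed=seed).batch(BATCH_SIZE))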

Multi-worker configuration

One of the key differences in this tutorial (compared to the multi-GPU training tutorial) is the multi-worker setup. The TF_CONFIG environment variable is the standard way to specify the cluster configuration to each worker that is part of the cluster.

There are two components of TF_CONFIG: cluster and task. cluster provides information about the entire cluster, namely the workers and parameter servers in the cluster. task provides information about the current task. In this example, the task type is worker and the task index is 0.

For illustration purposes, this tutorial shows how to set a TF_CONFIG with two workers on localhost. In practice, you would create multiple workers on external IP addresses and ports, and set TF_CONFIG appropriately on each worker, i.e. modify the task index.

Warning: Do not execute the following code in Colab. TensorFlow's runtime will attempt to create a gRPC server at the specified IP address and port, which will likely fail. See the Keras version of this tutorial for an example of how you can test-run multiple workers on a single machine.

os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'worker': ["localhost:12345", "localhost:23456"]
    },
    'task': {'type': 'worker', 'index': 0}
})
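
For reference, the second worker in this two-worker cluster would set the same cluster dictionary but a different task index; a minimal sketch (for illustration only, subject to the warning above):

# Hypothetical TF_CONFIG for the second worker: identical 'cluster',
# with the task index changed from 0 to 1.
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'worker': ["localhost:12345", "localhost:23456"]
    },
    'task': {'type': 'worker', 'index': 1}
})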

Define the model

Write the layers, the optimizer, and the loss function for training. This tutorial defines the model with Keras layers, similar to the multi-GPU training tutorial.

LEARNING_RATE = 1e-4
def model_fn(features, labels, mode):
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10)
  ])
  logits = model(features, training=False)

  if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = {'logits': logits}
    return tf.estimator.EstimatorSpec(mode, predictions=predictions)

  optimizer = tf.compat.v1.train.GradientDescentOptimizer(
      learning_rate=LEARNING_RATE)
  loss = tf.keras.losses.SparseCategoricalCrossentropy(
      from_logits=True, reduction=tf.keras.losses.Reduction.NONE)(labels, logits)
  loss = tf.reduce_sum(loss) * (1. / BATCH_SIZE)
  if mode == tf.estimator.ModeKeys.EVAL:
    return tf.estimator.EstimatorSpec(mode, loss=loss)

  return tf.estimator.EstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=optimizer.minimize(
          loss, tf.compat.v1.train.get_or_create_global_step()))

Note: Although the learning rate is fixed in this example, in general it may be necessary to adjust the learning rate based on the global batch size.
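
If you do choose to scale the learning rate, a minimal sketch of linear scaling is shown below; it assumes the per-replica BATCH_SIZE defined earlier and a strategy object exposing num_replicas_in_sync (created in the next section).

# Sketch: scale the learning rate linearly with the global batch size.
# `strategy` is assumed to be the tf.distribute strategy used for training.
GLOBAL_BATCH_SIZE = BATCH_SIZE * strategy.num_replicas_in_sync
scaled_learning_rate = LEARNING_RATE * (GLOBAL_BATCH_SIZE / BATCH_SIZE)
optimizer = tf.compat.v1.train.GradientDescentOptimizer(
    learning_rate=scaled_learning_rate)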

MultiWorkerMirroredStrategy

To train the model, use an instance of tf.distribute.experimental.MultiWorkerMirroredStrategy. MultiWorkerMirroredStrategy creates copies of all variables in the model's layers on each device across all workers. It uses CollectiveOps, a TensorFlow op for collective communication, to aggregate gradients and keep the variables in sync. The tf.distribute.Strategy guide has more details about this strategy.

strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_515360/349189047.py:1: _CollectiveAllReduceStrategyExperimental.__init__ (from tensorflow.python.distribute.collective_all_reduce_strategy) is deprecated and will be removed in a future version.
Instructions for updating:
use distribute.MultiWorkerMirroredStrategy instead
INFO:tensorflow:Using MirroredStrategy with devices ('/device:GPU:0', '/device:GPU:1', '/device:GPU:2', '/device:GPU:3')
INFO:tensorflow:Single-worker MultiWorkerMirroredStrategy with local_devices = ('/device:GPU:0', '/device:GPU:1', '/device:GPU:2', '/device:GPU:3'), communication = CommunicationImplementation.AUTO

Train and evaluate the model

Next, specify the distribution strategy in the RunConfig for the Estimator, then train and evaluate by invoking tf.estimator.train_and_evaluate. This tutorial distributes only the training by specifying the strategy via train_distribute. It is also possible to distribute the evaluation via eval_distribute (a minimal sketch of that follows the output below).

config = tf.estimator.RunConfig(train_distribute=strategy)

classifier = tf.estimator.Estimator(
    model_fn=model_fn, model_dir='/tmp/multiworker', config=config)
tf.estimator.train_and_evaluate(
    classifier,
    train_spec=tf.estimator.TrainSpec(input_fn=input_fn),
    eval_spec=tf.estimator.EvalSpec(input_fn=input_fn)
)
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_515360/2557501124.py:1: RunConfig.__init__ (from tensorflow_estimator.python.estimator.run_config) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Initializing RunConfig with distribution strategies.
INFO:tensorflow:Not using Distribute Coordinator.
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_515360/2557501124.py:3: Estimator.__init__ (from tensorflow_estimator.python.estimator.estimator) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Using config: {'_model_dir': '/tmp/multiworker', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': <tensorflow.python.distribute.collective_all_reduce_strategy._CollectiveAllReduceStrategyExperimental object at 0x7efcd0028040>, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_distribute_coordinator_mode': None}
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_515360/2557501124.py:7: TrainSpec.__new__ (from tensorflow_estimator.python.estimator.training) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_515360/2557501124.py:8: EvalSpec.__new__ (from tensorflow_estimator.python.estimator.training) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_515360/2557501124.py:5: train_and_evaluate (from tensorflow_estimator.python.estimator.training) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps None or save_checkpoints_secs 600.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1246: StrategyBase.configure (from tensorflow.python.distribute.distribute_lib) is deprecated and will be removed in a future version.
Instructions for updating:
use `update_config_proto` instead.
INFO:tensorflow:The `input_fn` accepts an `input_context` which will be given by DistributionStrategy
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/data/ops/dataset_ops.py:462: UserWarning: To make it possible to preserve tf.data options across serialization boundaries, their implementation has moved to be part of the TensorFlow graph. As a consequence, the options value is in general no longer known at graph construction time. Invoking this method in graph mode retains the legacy behavior of the original implementation, but note that the returned value might not reflect the actual value of the options.
  warnings.warn("To make it possible to preserve tf.data options across "
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/autograph/impl/api.py:459: EstimatorSpec.__new__ (from tensorflow_estimator.python.estimator.model_fn) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/autograph/impl/api.py:459: EstimatorSpec.__new__ (from tensorflow_estimator.python.estimator.model_fn) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1416: NanTensorHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1416: NanTensorHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1419: LoggingTensorHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1419: LoggingTensorHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/basic_session_run_hooks.py:232: SecondOrStepTimer.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/basic_session_run_hooks.py:232: SecondOrStepTimer.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1456: CheckpointSaverHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1456: CheckpointSaverHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Create CheckpointSaverHook.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:579: StepCounterHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:579: StepCounterHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:586: SummarySaverHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:586: SummarySaverHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/util.py:95: DistributedIteratorV1.initialize (from tensorflow.python.distribute.v1.input_lib) is deprecated and will be removed in a future version.
Instructions for updating:
Use the iterator's `initializer` property instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/util.py:95: DistributedIteratorV1.initialize (from tensorflow.python.distribute.v1.input_lib) is deprecated and will be removed in a future version.
Instructions for updating:
Use the iterator's `initializer` property instead.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/multiworker/model.ckpt.
INFO:tensorflow:Saving checkpoints for 0 into /tmp/multiworker/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
2023-11-07 23:13:17.215155: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorFromStringHandle' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorFromStringHandle} }
    .  Registered:  device='CPU'

2023-11-07 23:13:17.216474: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorGetNextFromShard' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorGetNextFromShard} }
    .  Registered:  device='CPU'

2023-11-07 23:13:17.225235: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorFromStringHandle' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorFromStringHandle} }
    .  Registered:  device='CPU'

2023-11-07 23:13:17.226238: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorGetNextFromShard' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorGetNextFromShard} }
    .  Registered:  device='CPU'

2023-11-07 23:13:17.249919: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorFromStringHandle' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorFromStringHandle} }
    .  Registered:  device='CPU'

2023-11-07 23:13:17.250960: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorGetNextFromShard' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorGetNextFromShard} }
    .  Registered:  device='CPU'

2023-11-07 23:13:17.258385: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorFromStringHandle' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorFromStringHandle} }
    .  Registered:  device='CPU'

2023-11-07 23:13:17.259003: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorGetNextFromShard' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorGetNextFromShard} }
    .  Registered:  device='CPU'
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1455: SessionRunArgs.__new__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1455: SessionRunArgs.__new__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1454: SessionRunContext.__init__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1454: SessionRunContext.__init__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1474: SessionRunValues.__new__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1474: SessionRunValues.__new__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:loss = 2.308822, step = 0
INFO:tensorflow:loss = 2.308822, step = 0
INFO:tensorflow:global_step/sec: 153.52
INFO:tensorflow:global_step/sec: 153.52
INFO:tensorflow:loss = 2.297161, step = 100 (0.654 sec)
INFO:tensorflow:loss = 2.297161, step = 100 (0.654 sec)
INFO:tensorflow:global_step/sec: 211.129
INFO:tensorflow:global_step/sec: 211.129
INFO:tensorflow:loss = 2.2924602, step = 200 (0.473 sec)
INFO:tensorflow:loss = 2.2924602, step = 200 (0.473 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 234...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 234...
INFO:tensorflow:Saving checkpoints for 234 into /tmp/multiworker/model.ckpt.
INFO:tensorflow:Saving checkpoints for 234 into /tmp/multiworker/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 234...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 234...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2023-11-07T23:13:22
INFO:tensorflow:Starting evaluation at 2023-11-07T23:13:22
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/evaluation.py:260: FinalOpsHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/evaluation.py:260: FinalOpsHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/multiworker/model.ckpt-234
INFO:tensorflow:Restoring parameters from /tmp/multiworker/model.ckpt-234
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [10/100]
INFO:tensorflow:Evaluation [10/100]
INFO:tensorflow:Evaluation [20/100]
INFO:tensorflow:Evaluation [20/100]
INFO:tensorflow:Evaluation [30/100]
INFO:tensorflow:Evaluation [30/100]
INFO:tensorflow:Evaluation [40/100]
INFO:tensorflow:Evaluation [40/100]
INFO:tensorflow:Evaluation [50/100]
INFO:tensorflow:Evaluation [50/100]
INFO:tensorflow:Evaluation [60/100]
INFO:tensorflow:Evaluation [60/100]
INFO:tensorflow:Evaluation [70/100]
INFO:tensorflow:Evaluation [70/100]
INFO:tensorflow:Evaluation [80/100]
INFO:tensorflow:Evaluation [80/100]
INFO:tensorflow:Evaluation [90/100]
INFO:tensorflow:Evaluation [90/100]
INFO:tensorflow:Evaluation [100/100]
INFO:tensorflow:Evaluation [100/100]
INFO:tensorflow:Inference Time : 1.34731s
INFO:tensorflow:Inference Time : 1.34731s
INFO:tensorflow:Finished evaluation at 2023-11-07-23:13:23
INFO:tensorflow:Finished evaluation at 2023-11-07-23:13:23
INFO:tensorflow:Saving dict for global step 234: global_step = 234, loss = 2.2986863
INFO:tensorflow:Saving dict for global step 234: global_step = 234, loss = 2.2986863
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 234: /tmp/multiworker/model.ckpt-234
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 234: /tmp/multiworker/model.ckpt-234
INFO:tensorflow:Loss for final step: 2.2939153.
INFO:tensorflow:Loss for final step: 2.2939153.
({'loss': 2.2986863, 'global_step': 234}, [])
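
As noted above, evaluation can also be distributed by passing the strategy via eval_distribute; a minimal sketch of that variant of the RunConfig:

# Sketch: distribute evaluation as well as training.
config = tf.estimator.RunConfig(train_distribute=strategy,
                                eval_distribute=strategy)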

Optimize training performance

You now have a model and a multi-worker capable Estimator powered by tf.distribute.Strategy. You can try the following techniques to optimize the performance of multi-worker training:

  • Increase the batch size: The batch size specified here is per-GPU. In general, the largest batch size that fits the GPU memory is advisable.

  • Cast the variables: Cast the variables to tf.float if possible. The official ResNet model includes an example of how this can be done.

  • Use collective communication: MultiWorkerMirroredStrategy provides multiple collective communication implementations:

    • RING implements ring-based collectives using gRPC as the cross-host communication layer.
    • NCCL uses NVIDIA's NCCL to implement collectives.
    • AUTO defers the choice to the runtime.

    The best choice of collective implementation depends upon the number and kind of GPUs, and the network interconnect in the cluster. To override the automatic choice, specify a valid value to the communication parameter of MultiWorkerMirroredStrategy's constructor, e.g. communication=tf.distribute.experimental.CollectiveCommunication.NCCL (see the sketch after this list).
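
For example, a minimal sketch of forcing the NCCL implementation with the experimental constructor used in this tutorial:

# Sketch: force the NCCL collective implementation instead of AUTO.
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
    communication=tf.distribute.experimental.CollectiveCommunication.NCCL)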

Visit the Performance section in the guide to learn more about other strategies and tools you can use to optimize the performance of your TensorFlow models.

Other code examples

  1. End-to-end example for multi-worker training in tensorflow/ecosystem using Kubernetes templates. This example starts with a Keras model and converts it to an Estimator using the tf.keras.estimator.model_to_estimator API (a minimal sketch of that conversion follows this list).
  2. Official models, many of which can be configured to run multiple distribution strategies.
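
A minimal sketch of that conversion (the small Keras model below is a hypothetical placeholder, and config is assumed to be a RunConfig like the one created earlier):

# Sketch: convert a compiled Keras model to an Estimator.
keras_model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
    tf.keras.layers.Dense(10)
])
keras_model.compile(
    optimizer='sgd',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
estimator = tf.keras.estimator.model_to_estimator(
    keras_model=keras_model, config=config)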