Migrate multi-worker CPU/GPU training

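This guide demonstrates how to migrate a multi-worker distributed training workflow from TensorFlow 1 to TensorFlow 2.

Multi-worker training with CPUs/GPUs is covered for both versions: TensorFlow 1 relies on the tf.estimator APIs, while TensorFlow 2 relies on Keras together with a tf.distribute.Strategy.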

Setup

Start with some necessary imports and a simple dataset for demonstration purposes:

# The notebook uses a dataset instance for `Model.fit` with
# `ParameterServerStrategy`, which depends on symbols in TF 2.7.
# Install a utility needed for this demonstration
!pip install portpicker

import tensorflow as tf
import tensorflow.compat.v1 as tf1
features = [[1., 1.5], [2., 2.5], [3., 3.5]]
labels = [[0.3], [0.5], [0.7]]
eval_features = [[4., 4.5], [5., 5.5], [6., 6.5]]
eval_labels = [[0.8], [0.9], [1.]]

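You need the 'TF_CONFIG' environment variable to train on multiple machines in TensorFlow. Use 'TF_CONFIG' to specify the addresses of the 'cluster' and of the current 'task'. (Learn more in the Distributed training guide.)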

import json
import os

tf_config = {
    'cluster': {
        'chief': ['localhost:11111'],
        'worker': ['localhost:12345', 'localhost:23456', 'localhost:21212'],
        'ps': ['localhost:12121', 'localhost:13131'],
    },
    'task': {'type': 'chief', 'index': 0}
}

os.environ['TF_CONFIG'] = json.dumps(tf_config)
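
To check what a given process will see, the variable can be read back and summarized with the standard library alone. The sketch below is an illustration only: the helper name describe_tf_config is hypothetical, and it assumes the 'cluster'/'task' layout used in the configuration above.

```python
import json
import os


def describe_tf_config(environ=None):
    """Summarize the role that 'TF_CONFIG' assigns to this process.

    Hypothetical helper for illustration; it only parses JSON and does not
    touch TensorFlow. Returns None when 'TF_CONFIG' is not set.
    """
    environ = os.environ if environ is None else environ
    raw = environ.get('TF_CONFIG')
    if raw is None:
        return None
    config = json.loads(raw)
    cluster = config.get('cluster', {})
    task = config.get('task', {})
    task_type = task.get('type')
    task_index = task.get('index', 0)
    addresses = cluster.get(task_type, [])
    return {
        'task_type': task_type,
        'task_index': task_index,
        # Address this process is expected to serve on, if it is listed.
        'address': addresses[task_index] if task_index < len(addresses) else None,
        # Number of tasks per job, for example {'chief': 1, 'worker': 3, 'ps': 2}.
        'job_sizes': {job: len(hosts) for job, hosts in cluster.items()},
    }
```

Calling describe_tf_config() with no arguments reads os.environ; passing a plain dict instead makes it easy to inspect a configuration without exporting it first.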

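Note: Multi-worker training with the tf.estimator APIs in TensorFlow 1 requires multiple clients, which is especially tricky to set up in this Colab notebook. The notebook is therefore made runnable without a 'TF_CONFIG' environment variable, so it falls back to local training. (Learn more in the Setting up the 'TF_CONFIG' environment variable section of the Distributed training with TensorFlow guide.)

Use the del statement to remove the variable (in real-world multi-worker training in TensorFlow 1, you would not do this):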

del os.environ['TF_CONFIG']

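TensorFlow 1: Multi-worker distributed training with tf.estimator APIs

The following code snippet demonstrates the canonical workflow of multi-worker training in TF1: you use a tf.estimator.Estimator, a tf.estimator.TrainSpec, a tf.estimator.EvalSpec, and the tf.estimator.train_and_evaluate API to distribute the training: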

def _input_fn():
  return tf1.data.Dataset.from_tensor_slices((features, labels)).batch(1)

def _eval_input_fn():
  return tf1.data.Dataset.from_tensor_slices(
      (eval_features, eval_labels)).batch(1)

def _model_fn(features, labels, mode):
  logits = tf1.layers.Dense(1)(features)
  loss = tf1.losses.mean_squared_error(labels=labels, predictions=logits)
  optimizer = tf1.train.AdagradOptimizer(0.05)
  train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())
  return tf1.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

estimator = tf1.estimator.Estimator(model_fn=_model_fn)
train_spec = tf1.estimator.TrainSpec(input_fn=_input_fn)
eval_spec = tf1.estimator.EvalSpec(input_fn=_eval_input_fn)
tf1.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmp3_hegyqd
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmp3_hegyqd', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps None or save_checkpoints_secs 600.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/training_util.py:396: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/adagrad.py:138: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmp3_hegyqd/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 4.4880586, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 3...
INFO:tensorflow:Saving checkpoints for 3 into /tmpfs/tmp/tmp3_hegyqd/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 3...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-08-31T00:36:58
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmp3_hegyqd/model.ckpt-3
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Inference Time : 0.15916s
INFO:tensorflow:Finished evaluation at 2022-08-31-00:36:58
INFO:tensorflow:Saving dict for global step 3: global_step = 3, loss = 43.62651
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 3: /tmpfs/tmp/tmp3_hegyqd/model.ckpt-3
INFO:tensorflow:Loss for final step: 20.31957.
({'loss': 43.62651, 'global_step': 3}, [])

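TensorFlow 2: Multi-worker training with distribution strategies

In TensorFlow 2, distributed training across multiple workers with CPUs, GPUs, and TPUs is done via tf.distribute.Strategy.

The following example demonstrates how to use two such strategies: tf.distribute.experimental.ParameterServerStrategy and tf.distribute.MultiWorkerMirroredStrategy, both of which are designed for CPU/GPU training with multiple workers.

ParameterServerStrategy employs a coordinator ('chief'), which makes it a better fit for the environment in this Colab notebook. You will use some utilities here to set up the supporting elements needed for a runnable experience: you will create an in-process cluster, where threads simulate the parameter servers ('ps') and the workers ('worker'). For more information about parameter server training, refer to the Parameter server training with ParameterServerStrategy tutorial.

In this example, first define the 'TF_CONFIG' environment variable, then read it with a tf.distribute.cluster_resolver.TFConfigClusterResolver to provide the cluster information. If you use a cluster management system for distributed training, check whether it already provides 'TF_CONFIG' for you, in which case you do not need to set this environment variable explicitly. (Learn more in the Setting up the 'TF_CONFIG' environment variable section of the Distributed training with TensorFlow guide.)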

# Find ports that are available for the `'chief'` (the coordinator),
# `'worker'`s, and `'ps'` (parameter servers).
import portpicker

chief_port = portpicker.pick_unused_port()
worker_ports = [portpicker.pick_unused_port() for _ in range(3)]
ps_ports = [portpicker.pick_unused_port() for _ in range(2)]

# Dump the cluster information to `'TF_CONFIG'`.
tf_config = {
    'cluster': {
        'chief': ["localhost:%s" % chief_port],
        'worker': ["localhost:%s" % port for port in worker_ports],
        'ps':  ["localhost:%s" % port for port in ps_ports],
    },
    'task': {'type': 'chief', 'index': 0}
}
os.environ['TF_CONFIG'] = json.dumps(tf_config)

# Use a cluster resolver to bridge the information to the strategy created below.
cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()

Then, create a tf.distribute.Server for each worker and each parameter server, one by one:

# Workers need some inter_ops threads to work properly.
# This is only needed for this notebook to demo. Real servers
# should not need this.
worker_config = tf.compat.v1.ConfigProto()
worker_config.inter_op_parallelism_threads = 4

for i in range(3):
  tf.distribute.Server(
      cluster_resolver.cluster_spec(),
      job_name="worker",
      task_index=i,
      config=worker_config)

for i in range(2):
  tf.distribute.Server(
      cluster_resolver.cluster_spec(),
      job_name="ps",
      task_index=i)

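In real-world distributed training, instead of starting all the tf.distribute.Servers on the coordinator, you will use multiple machines, and each machine designated as a "worker" or a "ps" (parameter server) will run one tf.distribute.Server. Refer to the Clusters in the real world section of the Parameter server training tutorial for more details.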
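
A per-machine entry point for such a setup might look like the sketch below. It assumes each machine has its own 'TF_CONFIG' with the right 'task' entry. The function names role_for_task and run_task are hypothetical, and the TensorFlow import is kept inside run_task so that the role logic can be read and tested on its own.

```python
def role_for_task(task_type):
    """Map a 'TF_CONFIG' task type to what the process should do.

    Hypothetical helper: 'worker' and 'ps' tasks only host a server, while
    any other type (such as 'chief') runs the training program.
    """
    return 'server' if task_type in ('worker', 'ps') else 'coordinator'


def run_task():
    """Start a server or return the resolver, depending on this task's role."""
    import tensorflow as tf  # Imported lazily; requires TensorFlow 2.

    cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
    if role_for_task(cluster_resolver.task_type) == 'server':
        server = tf.distribute.Server(
            cluster_resolver.cluster_spec(),
            job_name=cluster_resolver.task_type,
            task_index=cluster_resolver.task_id,
            protocol=cluster_resolver.rpc_layer or 'grpc',
            start=True)
        # Block forever; the coordinator dispatches work to this server.
        server.join()
        return None
    # On the coordinator, hand the resolver to the strategy.
    return cluster_resolver
```

Running the same script on every machine keeps deployment simple: only the 'task' entry of 'TF_CONFIG' decides whether a process serves or coordinates.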

With everything ready, create the ParameterServerStrategy object:

strategy = tf.distribute.experimental.ParameterServerStrategy(cluster_resolver)
INFO:tensorflow:`tf.distribute.experimental.ParameterServerStrategy` is initialized with cluster_spec: ClusterSpec({'chief': ['localhost:43271'], 'ps': ['localhost:46705', 'localhost:45915'], 'worker': ['localhost:39667', 'localhost:38207', 'localhost:35449']})
INFO:tensorflow:ParameterServerStrategyV2 is now connecting to cluster with cluster_spec: ClusterSpec({'chief': ['localhost:43271'], 'ps': ['localhost:46705', 'localhost:45915'], 'worker': ['localhost:39667', 'localhost:38207', 'localhost:35449']})
INFO:tensorflow:ParameterServerStrategy (CentralStorageStrategy if you are using a single machine) with compute_devices = ['/job:chief/replica:0/task:0/device:GPU:0', '/job:chief/replica:0/task:0/device:GPU:1', '/job:chief/replica:0/task:0/device:GPU:2', '/job:chief/replica:0/task:0/device:GPU:3'], variable_device = '/device:CPU:0'
INFO:tensorflow:Number of GPUs on workers: 4

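After creating the strategy object, define the model, the optimizer, and other variables, and call Keras Model.compile within the Strategy.scope API to distribute the training. (Refer to the Strategy.scope API docs for more information.)

If you prefer to customize your training by, for instance, defining the forward and backward passes yourself, refer to the Training with a custom training loop section of the Parameter server training tutorial for more details.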

dataset = tf.data.Dataset.from_tensor_slices(
      (features, labels)).shuffle(10).repeat().batch(64)

eval_dataset = tf.data.Dataset.from_tensor_slices(
      (eval_features, eval_labels)).repeat().batch(1)

with strategy.scope():
  model = tf.keras.models.Sequential([tf.keras.layers.Dense(1)])
  optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.05)
  model.compile(optimizer, "mse")

model.fit(dataset, epochs=5, steps_per_epoch=10)
Epoch 1/5
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/data/ops/dataset_ops.py:461: UserWarning: To make it possible to preserve tf.data options across serialization boundaries, their implementation has moved to be part of the TensorFlow graph. As a consequence, the options value is in general no longer known at graph construction time. Invoking this method in graph mode retains the legacy behavior of the original implementation, but note that the returned value might not reflect the actual value of the options.
  warnings.warn("To make it possible to preserve tf.data options across "
INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).
2022-08-31 00:37:02.621338: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:776] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2"
op: "TensorSliceDataset"
input: "Placeholder/_0"
input: "Placeholder/_1"
attr {
  key: "Toutput_types"
  value {
    list {
      type: DT_FLOAT
      type: DT_FLOAT
    }
  }
}
attr {
  key: "_cardinality"
  value {
    i: 3
  }
}
attr {
  key: "is_files"
  value {
    b: false
  }
}
attr {
  key: "metadata"
  value {
    s: "\n\024TensorSliceDataset:6"
  }
}
attr {
  key: "output_shapes"
  value {
    list {
      shape {
        dim {
          size: 2
        }
      }
      shape {
        dim {
          size: 1
        }
      }
    }
  }
}
attr {
  key: "replicate_on_split"
  value {
    b: false
  }
}
experimental_type {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_DATASET
    args {
      type_id: TFT_PRODUCT
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_FLOAT
        }
      }
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_FLOAT
        }
      }
    }
  }
}

10/10 - 5s - loss: 4.6559 - 5s/epoch - 535ms/step
Epoch 2/5
10/10 - 0s - loss: 1.6491 - 66ms/epoch - 7ms/step
Epoch 3/5
10/10 - 0s - loss: 0.7956 - 66ms/epoch - 7ms/step
Epoch 4/5
10/10 - 0s - loss: 0.4192 - 63ms/epoch - 6ms/step
Epoch 5/5
10/10 - 0s - loss: 0.2293 - 62ms/epoch - 6ms/step
<keras.callbacks.History at 0x7fda8c33c7f0>
model.evaluate(eval_dataset, steps=10, return_dict=True)
INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).
2022-08-31 00:37:06.670870: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:776] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2"
op: "TensorSliceDataset"
input: "Placeholder/_0"
input: "Placeholder/_1"
attr {
  key: "Toutput_types"
  value {
    list {
      type: DT_FLOAT
      type: DT_FLOAT
    }
  }
}
attr {
  key: "_cardinality"
  value {
    i: 3
  }
}
attr {
  key: "is_files"
  value {
    b: false
  }
}
attr {
  key: "metadata"
  value {
    s: "\n\025TensorSliceDataset:10"
  }
}
attr {
  key: "output_shapes"
  value {
    list {
      shape {
        dim {
          size: 2
        }
      }
      shape {
        dim {
          size: 1
        }
      }
    }
  }
}
attr {
  key: "replicate_on_split"
  value {
    b: false
  }
}
experimental_type {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_DATASET
    args {
      type_id: TFT_PRODUCT
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_FLOAT
        }
      }
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_FLOAT
        }
      }
    }
  }
}

10/10 - 2s - loss: 1.2941 - 2s/epoch - 212ms/step
{'loss': 1.2941229343414307}

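Partitioners (tf.distribute.experimental.partitioners)

ParameterServerStrategy in TensorFlow 2 supports variable partitioning and offers the same partitioners as TensorFlow 1 under less confusing names: tf.compat.v1.variable_axis_size_partitioner becomes MaxSizePartitioner (keeps shards under a maximum size), tf.compat.v1.min_max_variable_partitioner becomes MinSizePartitioner (allocates a minimum size per shard), and tf.compat.v1.fixed_size_partitioner becomes FixedShardsPartitioner (allocates a fixed number of shards).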
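
As a sketch of how a partitioner is attached, the snippet below passes a tf.distribute.experimental.partitioners.MinSizePartitioner to the strategy through its variable_partitioner argument. The helper names are hypothetical, and the shard-count estimate is back-of-the-envelope reasoning about the size constraints, not TensorFlow's exact partitioning algorithm.

```python
def estimate_shard_upper_bound(total_bytes, min_shard_bytes, max_shards, first_dim):
    """Rough upper bound on shards when each shard needs min_shard_bytes.

    Illustrative only: the real partitioner may choose fewer shards.
    """
    by_size = max(1, total_bytes // min_shard_bytes)
    return max(1, min(max_shards, by_size, first_dim))


def make_partitioned_strategy(cluster_resolver, num_ps):
    """Create a ParameterServerStrategy that shards large variables."""
    import tensorflow as tf  # Imported lazily; requires TensorFlow 2.

    partitioner = tf.distribute.experimental.partitioners.MinSizePartitioner(
        min_shard_bytes=256 << 10,  # At least 256 KiB per shard.
        max_shards=num_ps)
    return tf.distribute.experimental.ParameterServerStrategy(
        cluster_resolver, variable_partitioner=partitioner)
```

Capping max_shards at the number of parameter servers is a common choice, since it lets each shard of a large variable live on a different server. The tiny Dense(1) model in this guide is far below the minimum shard size, so it would stay unpartitioned either way.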

Alternatively, you can use a MultiWorkerMirroredStrategy object:

# To clean up the `TF_CONFIG` used for `ParameterServerStrategy`.
del os.environ['TF_CONFIG']
strategy = tf.distribute.MultiWorkerMirroredStrategy()
WARNING:tensorflow:Collective ops is not configured at program startup. Some performance features may not be enabled.
INFO:tensorflow:Single-worker MultiWorkerMirroredStrategy with local_devices = ('/device:GPU:0', '/device:GPU:1', '/device:GPU:2', '/device:GPU:3'), communication = CommunicationImplementation.AUTO

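You can replace the strategy used above with a MultiWorkerMirroredStrategy object to train with this strategy.

As with the tf.estimator APIs, MultiWorkerMirroredStrategy is a multi-client strategy, so there is no easy way to run distributed training in this Colab notebook. Replacing the strategy in the code above therefore ends up running things locally. The Multi-worker training with Keras Model.fit/a custom training loop tutorials demonstrate how to run multi-worker training with the 'TF_CONFIG' variable set up, using two workers on a localhost in Colab. In practice, you would create multiple workers on external IP addresses/ports and use the 'TF_CONFIG' variable to specify the cluster configuration for each worker.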
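
One way to picture the per-worker configuration is to generate one 'TF_CONFIG' value for each entry in the worker list. In the sketch below, the host names and the helper name tf_configs_for_workers are placeholders invented for illustration; each JSON string would be exported as 'TF_CONFIG' on the corresponding machine before it starts the same training script.

```python
import json


def tf_configs_for_workers(worker_addresses):
    """Return one 'TF_CONFIG' JSON string per worker address.

    Every worker shares the same 'cluster' entry and differs only in its
    'task' index.
    """
    cluster = {'worker': list(worker_addresses)}
    return [
        json.dumps({'cluster': cluster,
                    'task': {'type': 'worker', 'index': index}})
        for index in range(len(worker_addresses))
    ]


# Placeholder addresses; replace with the real host:port of each worker.
configs = tf_configs_for_workers(['host-a.example:12345', 'host-b.example:12345'])
```

Because the 'cluster' entry is identical everywhere, a mismatch between machines is easy to spot by comparing the generated strings.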

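Next steps

To learn more about multi-worker distributed training with tf.distribute.experimental.ParameterServerStrategy and tf.distribute.MultiWorkerMirroredStrategy in TensorFlow 2, consult the resources referenced throughout this guide: the Parameter server training with ParameterServerStrategy tutorial, the Multi-worker training with Keras Model.fit/a custom training loop tutorials, and the Distributed training with TensorFlow guide.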