Migrate multi-worker CPU/GPU training


This guide demonstrates how to migrate your multi-worker distributed training workflow from TensorFlow 1 to TensorFlow 2.

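To perform multi-worker training with CPUs/GPUs:
- In TensorFlow 1, you traditionally use the tf.estimator.train_and_evaluate and tf.estimator.Estimator APIs.
- In TensorFlow 2, use the Keras APIs to write the model, the loss function, the optimizer, and metrics. Then, distribute the training across multiple workers with the Keras Model.fit API or a custom training loop (with tf.GradientTape), using tf.distribute.experimental.ParameterServerStrategy or tf.distribute.MultiWorkerMirroredStrategy.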

Setup

Start with some necessary imports and a simple dataset for demonstration purposes:

# This notebook passes a `tf.data.Dataset` instance to `Model.fit` with
# `ParameterServerStrategy`, which requires TensorFlow 2.7 or later.
# Install `portpicker`, a utility used below to find unused ports.
!pip install portpicker

import tensorflow as tf
import tensorflow.compat.v1 as tf1
features = [[1., 1.5], [2., 2.5], [3., 3.5]]
labels = [[0.3], [0.5], [0.7]]
eval_features = [[4., 4.5], [5., 5.5], [6., 6.5]]
eval_labels = [[0.8], [0.9], [1.]]

You will need the 'TF_CONFIG' environment variable for training on multiple machines in TensorFlow. Use 'TF_CONFIG' to specify the addresses of the machines in the 'cluster' and the role of the current 'task'. (Learn more in the Distributed training guide.)

import json
import os

tf_config = {
    'cluster': {
        'chief': ['localhost:11111'],
        'worker': ['localhost:12345', 'localhost:23456', 'localhost:21212'],
        'ps': ['localhost:12121', 'localhost:13131'],
    },
    'task': {'type': 'chief', 'index': 0}
}

os.environ['TF_CONFIG'] = json.dumps(tf_config)
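To make the structure of 'TF_CONFIG' concrete, here is a minimal pure-Python sketch that parses the JSON string and reports which task the current process is and how many tasks each job has. The helper name `describe_task` is hypothetical; in practice, TensorFlow reads 'TF_CONFIG' for you through tf.distribute.cluster_resolver.TFConfigClusterResolver.

```python
import json


def describe_task(tf_config_json):
    """Return (task_type, task_index, job_sizes) parsed from a 'TF_CONFIG' string.

    Hypothetical helper for illustration only.
    """
    config = json.loads(tf_config_json)
    task = config.get("task", {})
    # Count the addresses (one per task) listed under each job.
    job_sizes = {job: len(addresses) for job, addresses in config["cluster"].items()}
    return task.get("type"), task.get("index"), job_sizes


# Same cluster layout as the 'TF_CONFIG' above.
example = {
    "cluster": {
        "chief": ["localhost:11111"],
        "worker": ["localhost:12345", "localhost:23456", "localhost:21212"],
        "ps": ["localhost:12121", "localhost:13131"],
    },
    "task": {"type": "chief", "index": 0},
}
print(describe_task(json.dumps(example)))
# ('chief', 0, {'chief': 1, 'worker': 3, 'ps': 2})
```

Every process in the cluster sees the same 'cluster' section; only the 'task' entry tells a process which role it plays.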

Remove the variable with the del statement so that the TensorFlow 1 example below runs locally in this notebook (in real-world multi-worker training in TensorFlow 1, you won't have to do this):

del os.environ['TF_CONFIG']

TensorFlow 1: Multi-worker distributed training with tf.estimator APIs

The following code snippet demonstrates the canonical workflow of multi-worker training in TF1: you will use a tf.estimator.Estimator, a tf.estimator.TrainSpec, a tf.estimator.EvalSpec, and the tf.estimator.train_and_evaluate API to distribute the training:

def _input_fn():
  return tf1.data.Dataset.from_tensor_slices((features, labels)).batch(1)

def _eval_input_fn():
  return tf1.data.Dataset.from_tensor_slices(
      (eval_features, eval_labels)).batch(1)

def _model_fn(features, labels, mode):
  logits = tf1.layers.Dense(1)(features)
  loss = tf1.losses.mean_squared_error(labels=labels, predictions=logits)
  optimizer = tf1.train.AdagradOptimizer(0.05)
  train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())
  return tf1.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

estimator = tf1.estimator.Estimator(model_fn=_model_fn)
train_spec = tf1.estimator.TrainSpec(input_fn=_input_fn)
eval_spec = tf1.estimator.EvalSpec(input_fn=_eval_input_fn)
tf1.estimator.train_and_evaluate(estimator, train_spec, eval_spec)

TensorFlow 2: Multi-worker training with distribution strategies

In TensorFlow 2, distributed training across multiple workers with CPUs, GPUs, and TPUs is done with tf.distribute.Strategy objects.

The following example demonstrates how to use two such strategies: tf.distribute.experimental.ParameterServerStrategy and tf.distribute.MultiWorkerMirroredStrategy, both of which are designed for CPU/GPU training with multiple workers.

ParameterServerStrategy employs a coordinator ('chief'), which makes it easier to run in the environment of this Colab notebook. To make the example runnable here, you will use some utilities to create an in-process cluster, where threads simulate the parameter servers ('ps') and workers ('worker'). For more information about parameter server training, refer to the Parameter server training with ParameterServerStrategy tutorial.

In this example, first define the 'TF_CONFIG' environment variable, and then use a tf.distribute.cluster_resolver.TFConfigClusterResolver to read it and provide the cluster information to the strategy. If you are using a cluster management system for your distributed training, check whether it already sets 'TF_CONFIG' for you, in which case you don't need to set this environment variable explicitly. (Learn more in the Setting up the 'TF_CONFIG' environment variable section of the Distributed training with TensorFlow guide.)

# Find ports that are available for the `'chief'` (the coordinator),
# `'worker'`s, and `'ps'` (parameter servers).
import portpicker

chief_port = portpicker.pick_unused_port()
worker_ports = [portpicker.pick_unused_port() for _ in range(3)]
ps_ports = [portpicker.pick_unused_port() for _ in range(2)]

# Dump the cluster information to `'TF_CONFIG'`.
tf_config = {
    'cluster': {
        'chief': ["localhost:%s" % chief_port],
        'worker': ["localhost:%s" % port for port in worker_ports],
        'ps':  ["localhost:%s" % port for port in ps_ports],
    },
    'task': {'type': 'chief', 'index': 0}
}
os.environ['TF_CONFIG'] = json.dumps(tf_config)

# Use a cluster resolver to bridge the information to the strategy created below.
cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
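On a real cluster, every machine receives a 'TF_CONFIG' with the same 'cluster' section and its own 'task' entry. The sketch below generates one such JSON string per task from a plain Python dictionary, which you could then hand to each process (for example, through your cluster manager). The helper name `tf_configs_for_cluster` is hypothetical, and the addresses are placeholders.

```python
import json


def tf_configs_for_cluster(cluster):
    """Map (task_type, index) -> 'TF_CONFIG' JSON string for every task.

    Hypothetical helper: all tasks share the same 'cluster' section and
    differ only in their 'task' entry.
    """
    configs = {}
    for task_type, addresses in cluster.items():
        for index in range(len(addresses)):
            configs[(task_type, index)] = json.dumps({
                "cluster": cluster,
                "task": {"type": task_type, "index": index},
            })
    return configs


# Placeholder addresses for a 1-chief, 3-worker, 2-ps cluster.
cluster = {
    "chief": ["localhost:20000"],
    "worker": ["localhost:20001", "localhost:20002", "localhost:20003"],
    "ps": ["localhost:20004", "localhost:20005"],
}
configs = tf_configs_for_cluster(cluster)
print(len(configs))  # 6
print(configs[("ps", 1)])  # the 'TF_CONFIG' for the second parameter server
```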

Then, create the tf.distribute.Server objects for the workers and parameter servers, one by one:

# Workers need some inter-op threads to work properly in this
# in-process demo. Real servers should not need this setting.
worker_config = tf.compat.v1.ConfigProto()
worker_config.inter_op_parallelism_threads = 4

for i in range(3):
  tf.distribute.Server(
      cluster_resolver.cluster_spec(),
      job_name="worker",
      task_index=i,
      config=worker_config)

for i in range(2):
  tf.distribute.Server(
      cluster_resolver.cluster_spec(),
      job_name="ps",
      task_index=i)

In real-world distributed training, instead of starting all the tf.distribute.Server objects on the coordinator, you will use multiple machines: each machine designated as a "worker" or a "ps" (parameter server) runs its own tf.distribute.Server. Refer to the Clusters in the real world section of the Parameter server training tutorial for more details.
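A minimal sketch of such a per-machine entry point, assuming each "worker" and "ps" machine has its own 'TF_CONFIG' set: the process reads its role and, if it is a worker or parameter server, starts a tf.distribute.Server and blocks on it, while the coordinator returns and goes on to run the training program. The helper names `server_role` and `maybe_run_server` are hypothetical; the TensorFlow calls follow the pattern shown in the Parameter server training tutorial.

```python
import json
import os


def server_role(tf_config_json):
    """Return (job_name, task_index) if this task should host a server, else None.

    Hypothetical helper: 'worker' and 'ps' tasks host a server, while the
    'chief' (coordinator) runs the training program instead.
    """
    task = json.loads(tf_config_json).get("task", {})
    if task.get("type") in ("worker", "ps"):
        return task["type"], task["index"]
    return None


def maybe_run_server():
    """Start a blocking tf.distribute.Server if this machine is a worker or ps."""
    role = server_role(os.environ.get("TF_CONFIG", "{}"))
    if role is None:
        return False  # The coordinator continues with the training code.
    import tensorflow as tf  # Imported lazily: only server tasks need it here.

    job_name, task_index = role
    cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
    server = tf.distribute.Server(
        cluster_resolver.cluster_spec(),
        job_name=job_name,
        task_index=task_index,
        protocol=cluster_resolver.rpc_layer or "grpc",
        start=True)
    server.join()  # Blocks until the server shuts down.
    return True
```

Running the same script on every machine keeps deployment simple: the only per-machine difference is the 'TF_CONFIG' value.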

With everything ready, create the ParameterServerStrategy object:

strategy = tf.distribute.experimental.ParameterServerStrategy(cluster_resolver)

Once you have created a strategy object, define the model, the optimizer, and other variables, and call Keras Model.compile inside the Strategy.scope context, so that these variables are created under the strategy and the training is distributed. (Refer to the Strategy.scope API docs for more information.)

If you prefer to customize your training, for instance by defining the forward and backward passes yourself, refer to the Training with a custom training loop section of the Parameter server training tutorial for more details.

dataset = tf.data.Dataset.from_tensor_slices(
      (features, labels)).shuffle(10).repeat().batch(64)

eval_dataset = tf.data.Dataset.from_tensor_slices(
      (eval_features, eval_labels)).repeat().batch(1)

with strategy.scope():
  model = tf.keras.models.Sequential([tf.keras.layers.Dense(1)])
  optimizer = tf.keras.optimizers.legacy.Adagrad(learning_rate=0.05)
  model.compile(optimizer, "mse")

model.fit(dataset, epochs=5, steps_per_epoch=10)
model.evaluate(eval_dataset, steps=10, return_dict=True)

Partitioners (tf.distribute.experimental.partitioners)

ParameterServerStrategy in TensorFlow 2 supports variable partitioning and offers the same partitioners as TensorFlow 1, with less confusing names:
- tf.compat.v1.variable_axis_size_partitioner -> tf.distribute.experimental.partitioners.MaxSizePartitioner: a partitioner that keeps shards under a maximum size.
- tf.compat.v1.min_max_variable_partitioner -> tf.distribute.experimental.partitioners.MinSizePartitioner: a partitioner that allocates a minimum size per shard.
- tf.compat.v1.fixed_size_partitioner -> tf.distribute.experimental.partitioners.FixedShardsPartitioner: a partitioner that allocates a fixed number of shards.
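To illustrate how the three policies differ, the following pure-Python sketch computes a number of shards along the first axis of a variable. It is a simplified model of the policies listed above, not TensorFlow's implementation: the real partitioners work in whole slices along axis 0, handle string dtypes, and may round differently. All function names are hypothetical.

```python
import math


def fixed_shards(dim0, num_shards):
    """Approximates FixedShardsPartitioner: a fixed number of shards."""
    # A variable cannot be split into more shards than it has rows.
    return min(num_shards, dim0)


def max_size_shards(dim0, total_bytes, max_shard_bytes, max_shards=None):
    """Approximates MaxSizePartitioner: enough shards to stay under a size cap."""
    shards = math.ceil(total_bytes / max_shard_bytes)
    if max_shards is not None:
        shards = min(shards, max_shards)
    return max(1, min(shards, dim0))


def min_size_shards(dim0, total_bytes, min_shard_bytes, max_shards=1):
    """Approximates MinSizePartitioner: as many shards as the minimum size allows."""
    shards = total_bytes // min_shard_bytes
    return max(1, min(shards, max_shards, dim0))


# A float32 variable of shape (1000, 64) occupies 1000 * 64 * 4 = 256000 bytes.
rows, total = 1000, 1000 * 64 * 4
print(fixed_shards(rows, 4))                                  # 4
print(max_size_shards(rows, total, 100_000))                  # 3
print(min_size_shards(rows, total, 100_000, max_shards=10))   # 2
```

In TensorFlow 2 itself, a partitioner object is passed to the strategy through the `variable_partitioner` argument, for example `tf.distribute.experimental.ParameterServerStrategy(cluster_resolver, variable_partitioner=tf.distribute.experimental.partitioners.MinSizePartitioner(min_shard_bytes=256 << 10, max_shards=2))`.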

Alternatively, you can use a MultiWorkerMirroredStrategy object:

# To clean up the `TF_CONFIG` used for `ParameterServerStrategy`.
del os.environ['TF_CONFIG']
strategy = tf.distribute.MultiWorkerMirroredStrategy()

You can replace the strategy used above with a MultiWorkerMirroredStrategy object to perform training with this strategy.

As with the tf.estimator APIs, MultiWorkerMirroredStrategy is a multi-client strategy, so there is no easy way to use it for distributed training in this Colab notebook. Replacing the strategy in the code above with this one therefore runs the training locally, on a single worker. The Multi-worker training with Keras (Model.fit) and custom training loop tutorials demonstrate how to run multi-worker training with 'TF_CONFIG' set up, using two workers on localhost in Colab. In practice, you would create multiple workers on external IP addresses/ports and use 'TF_CONFIG' to specify the cluster configuration for each worker.
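With MultiWorkerMirroredStrategy, every worker runs the same training program with its own 'TF_CONFIG'. The plain-Python sketch below shows two bookkeeping details that commonly come up: scaling the per-worker batch size by the number of workers to get the global batch size, and deciding which worker acts as the chief (for example, to write checkpoints). The helper names are hypothetical; the convention that the worker with index 0 acts as the chief when no 'chief' job is listed follows the Multi-worker training with Keras tutorial.

```python
import json


def is_chief(task_type, task_id):
    """Hypothetical helper: decide whether this worker takes on chief duties.

    With no explicit 'chief' job, worker 0 acts as the chief. A missing
    task type means single-machine (local) training.
    """
    return (task_type is None or task_type == "chief"
            or (task_type == "worker" and task_id == 0))


def global_batch_size(tf_config_json, per_worker_batch_size):
    """Scale the per-worker batch size by the number of workers in the cluster."""
    num_workers = len(json.loads(tf_config_json)["cluster"]["worker"])
    return per_worker_batch_size * num_workers


# Two workers on localhost (placeholder ports); this is the second worker.
tf_config_worker_1 = json.dumps({
    "cluster": {"worker": ["localhost:12345", "localhost:23456"]},
    "task": {"type": "worker", "index": 1},
})
print(global_batch_size(tf_config_worker_1, per_worker_batch_size=64))  # 128
print(is_chief("worker", 1))  # False
```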

Next steps

To learn more about multi-worker distributed training with tf.distribute.experimental.ParameterServerStrategy and tf.distribute.MultiWorkerMirroredStrategy in TensorFlow 2, consider the following resources: