A distribution strategy for synchronous training on multiple workers.

Inherits From: MultiWorkerMirroredStrategy, Strategy


This strategy implements synchronous distributed training across multiple workers, each with potentially multiple GPUs. Similar to tf.distribute.MirroredStrategy, it replicates all variables and computations to each local device. The difference is that it uses a distributed collective implementation (e.g. all-reduce), so that multiple workers can work together.
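The following pure-Python simulation shows what an all-reduce computes. It uses the ring algorithm, which corresponds to the RING value of the communication argument. This is only a sketch of the data movement. TensorFlow's real collectives run over devices and the network, and the helper name ring_all_reduce is hypothetical. Each worker starts with its own vector. After a reduce-scatter phase and an all-gather phase, every worker holds the element-wise sum.

```python
def ring_all_reduce(worker_values):
  """Simulate a ring all-reduce (sum) over equal-length numeric lists.

  worker_values[i] is worker i's local vector. Returns every worker's
  vector after the all-reduce; all of them equal the element-wise sum.
  """
  n = len(worker_values)
  length = len(worker_values[0])
  # Split the index range [0, length) into n contiguous chunks.
  bounds = [(c * length // n, (c + 1) * length // n) for c in range(n)]
  buffers = [list(v) for v in worker_values]

  def ring_step(chunk_of, combine):
    # Every worker sends one chunk to its right-hand neighbour. Payloads are
    # snapshotted first, so all sends within a step happen "simultaneously".
    messages = []
    for i in range(n):
      lo, hi = bounds[chunk_of(i)]
      messages.append(((i + 1) % n, lo, hi, buffers[i][lo:hi]))
    for dst, lo, hi, payload in messages:
      current = buffers[dst][lo:hi]
      buffers[dst][lo:hi] = [combine(a, b) for a, b in zip(current, payload)]

  # Phase 1, reduce-scatter: after n - 1 steps, worker i holds the fully
  # summed chunk (i + 1) % n.
  for step in range(n - 1):
    ring_step(lambda i: (i - step) % n, lambda mine, received: mine + received)

  # Phase 2, all-gather: circulate the finished chunks so that every worker
  # ends up with all of them.
  for step in range(n - 1):
    ring_step(lambda i: (i + 1 - step) % n, lambda mine, received: received)

  return buffers
```

Each worker only ever talks to its neighbour and sends about 2(n - 1)/n of its vector in total. This is why ring all-reduce scales well with the number of workers.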

You need to launch your program on each worker and configure cluster_resolver correctly. For example, if you are using tf.distribute.cluster_resolver.TFConfigClusterResolver, each worker needs to have its corresponding task_type and task_id set in the TF_CONFIG environment variable. An example TF_CONFIG on worker-0 of a two worker cluster is:

TF_CONFIG = '{"cluster": {"worker": ["localhost:12345", "localhost:23456"]}, "task": {"type": "worker", "index": 0} }'
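Writing this JSON by hand for every worker is error-prone. A small helper can generate it instead, so that all workers share the same cluster list and only the task index differs. The helper make_tf_config is hypothetical. TensorFlow only reads the resulting JSON string from the TF_CONFIG environment variable.

```python
import json


def make_tf_config(worker_addresses, index):
  """Return the TF_CONFIG JSON string for worker `index` (hypothetical helper)."""
  return json.dumps({
      "cluster": {"worker": list(worker_addresses)},
      "task": {"type": "worker", "index": index},
  })


workers = ["localhost:12345", "localhost:23456"]
# On worker-0 you would set this before creating the strategy, e.g.:
#   os.environ["TF_CONFIG"] = make_tf_config(workers, 0)
config_for_worker_0 = make_tf_config(workers, 0)
```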

Your program runs on each worker as-is. Note that collectives require every worker to participate. Both tf.distribute and non-tf.distribute APIs may use collectives internally. For example, checkpointing and saving do, because reading a tf.Variable with tf.VariableSynchronization.ON_READ all-reduces its value. It is therefore recommended to run exactly the same program on each worker. Dispatching based on the task_type or task_id of the worker is error-prone.

cluster_resolver.num_accelerators() determines the number of GPUs the strategy uses. If it's zero, the strategy uses the CPU. All workers need to use the same number of devices, otherwise the behavior is undefined.
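The number of replicas follows from this rule. Each worker contributes one replica per GPU, or a single CPU replica when num_accelerators() reports zero. The sketch below encodes that arithmetic. It assumes every worker has the same number of GPUs, as the strategy requires, and the helper names are hypothetical. Its result should agree with strategy.num_replicas_in_sync.

```python
def count_replicas(num_workers, gpus_per_worker):
  """Replicas in sync for a homogeneous cluster (hypothetical helper)."""
  # Zero GPUs means the worker falls back to one CPU device.
  return num_workers * max(gpus_per_worker, 1)


def global_batch_size(per_replica_batch_size, num_replicas):
  """Examples consumed per synchronous step across the whole cluster."""
  return per_replica_batch_size * num_replicas


# Two workers with 4 GPUs each -> 8 replicas; 32 examples per replica -> 256.
replicas = count_replicas(2, 4)
batch = global_batch_size(32, replicas)
```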

This strategy is not intended for TPU. Use tf.distribute.TPUStrategy instead.

After setting up TF_CONFIG, using this strategy is similar to using tf.distribute.MirroredStrategy and tf.distribute.TPUStrategy.

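import numpy as np
import tensorflow as tf

# Create the strategy before making other TensorFlow calls.
strategy = tf.distribute.MultiWorkerMirroredStrategy()

with strategy.scope():
  # Variables created here are mirrored on every replica of every worker.
  model = tf.keras.Sequential([
    tf.keras.layers.Dense(2, input_shape=(5,)),
  ])
  optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

def dataset_fn(ctx):
  x = np.random.random((2, 5)).astype(np.float32)
  y = np.random.randint(2, size=(2, 1))
  dataset = tf.data.Dataset.from_tensor_slices((x, y))
  return dataset.repeat().batch(1, drop_remainder=True)
dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)

# The Dense layer outputs raw logits. The dataset repeats forever, so
# each epoch is bounded with steps_per_epoch.
model.compile(
    optimizer=optimizer,
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
model.fit(dist_dataset, epochs=2, steps_per_epoch=5)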


You can also write your own training loop:

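@tf.function
def train_step(iterator):

  def step_fn(inputs):
    features, labels = inputs
    with tf.GradientTape() as tape:
      logits = model(features, training=True)
      # The Dense layer has no activation, so its outputs are logits.
      per_example_loss = tf.keras.losses.sparse_categorical_crossentropy(
          labels, logits, from_logits=True)
      # Divide by the global batch size so that the all-reduced (summed)
      # gradients equal the gradient of the mean loss.
      loss = tf.nn.compute_average_loss(per_example_loss)

    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

  strategy.run(step_fn, args=(next(iterator),))

NUM_STEP = 10  # illustrative number of training steps
iterator = iter(dist_dataset)
for _ in range(NUM_STEP):
  train_step(iterator)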

See Multi-worker training with Keras for a detailed tutorial.


You need to save and checkpoint on all workers instead of just one. Variables with synchronization=ON_READ trigger aggregation, which is a collective, during saving, and every worker must take part in a collective. It's recommended to save to a different path on each worker to avoid race conditions. Each worker saves the same content. See the Multi-worker training with Keras tutorial for examples.
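The sketch below gives every worker its own save path while all of them still execute the same save call. It follows the pattern from the Multi-worker training with Keras tutorial. Worker 0, or a process with no task type, writes to the real path. The other workers write to a worker-specific temporary directory. The helper names are hypothetical, and the sketch assumes a cluster with only "worker" tasks, as in the TF_CONFIG example above. In a real program, task_type and task_id would come from strategy.cluster_resolver.task_type and strategy.cluster_resolver.task_id.

```python
import os


def is_chief(task_type, task_id):
  # Assumption: no dedicated "chief" task, so worker 0 plays that role.
  return task_type is None or (task_type == "worker" and task_id == 0)


def write_filepath(filepath, task_type, task_id):
  """Per-worker save path (hypothetical helper)."""
  dirpath, base = os.path.split(filepath)
  if not is_chief(task_type, task_id):
    # Non-chief workers still save, because saving runs collectives, but
    # they write to a private directory so they cannot clobber the chief.
    dirpath = os.path.join(dirpath, "workertemp_" + str(task_id))
  return os.path.join(dirpath, base)
```

Because the program stays identical on every worker, this does not conflict with the advice against dispatching on task_type. In the tutorial's pattern, non-chief workers then delete their temporary directory after saving.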

Known Issues

Args

communication: Optional tf.distribute.experimental.CommunicationImplementation. This is a hint on the preferred collective communication implementation. Possible values include AUTO, RING, and NCCL.
cluster_resolver: Optional tf.distribute.cluster_resolver.ClusterResolver. If None, tf.distribute.cluster_resolver.TFConfigClusterResolver is used.
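The following is a minimal construction sketch for the arguments above. It assumes a TensorFlow 2.x release in which this deprecated experimental class is still exported. Passing TFConfigClusterResolver explicitly is equivalent to passing None. With no TF_CONFIG set, the resolver describes a single local worker, so the snippet also runs in one process. The choice of RING is only illustrative.

```python
import tensorflow as tf

# Create the strategy before making other TensorFlow calls.
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
    communication=tf.distribute.experimental.CommunicationImplementation.RING,
    cluster_resolver=tf.distribute.cluster_resolver.TFConfigClusterResolver())

# One replica per local GPU, or a single CPU replica if there are none.
print(strategy.num_replicas_in_sync)
```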

Attributes

cluster_resolver: Returns the cluster resolver associated with this strategy.

As a multi-worker strategy, tf.distribute.MultiWorkerMirroredStrategy provides the associated tf.distribute.cluster_resolver.ClusterResolver. If the user provides one in __init__, that instance is returned; if the user does not, a default TFConfigClusterResolver is provided.

extended: tf.distribute.StrategyExtended with additional methods.
num_replicas_in_sync: Returns the number of replicas over which gradients are aggregated.



Methods

distribute_datasets_from_function

Distributes tf.data.Dataset instances created by calls to dataset_fn.

The dataset_fn argument is an input function that takes a tf.distribute.InputContext argument and returns a tf.data.Dataset instance. The dataset returned by dataset_fn is expected to be already sharded and batched by the per-replica batch size (i.e. the global batch size divided by the number of replicas in sync). tf.distribute.Strategy.distribute_datasets_from_function does not batch or shard the tf.data.Dataset instance returned from the input function. dataset_fn is called on the CPU device of each worker. Each call produces a dataset from which every replica on that worker dequeues one batch of inputs per step (i.e. if a worker has two replicas, two batches are dequeued from the dataset every step).

This method can be used for several purposes. First, it allows you to specify your own batching and sharding logic. (In contrast, tf.distribute.Strategy.experimental_distribute_dataset does batching and sharding for you.) For example, where experimental_distribute_dataset is unable to shard the input files, this method can be used to shard the dataset manually, avoiding the slow fallback behavior in experimental_distribute_dataset. In cases where the dataset is infinite, this sharding can be done by creating dataset replicas that differ only in their random seed.
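The seed-based approach for infinite datasets might look like the following sketch. Every input pipeline starts from the same data and repeats it forever, and pipelines differ only in their shuffle seed, so no explicit shard() call is needed. The synthetic range(100) data, GLOBAL_BATCH_SIZE, BASE_SEED, and the function name are all hypothetical placeholders.

```python
import tensorflow as tf

GLOBAL_BATCH_SIZE = 8  # hypothetical global batch size
BASE_SEED = 42         # hypothetical base seed shared by all workers


def seeded_dataset_fn(input_context):
  # Per-replica batch = global batch / num_replicas_in_sync.
  batch_size = input_context.get_per_replica_batch_size(GLOBAL_BATCH_SIZE)
  # Same (synthetic) source on every pipeline, repeated forever...
  dataset = tf.data.Dataset.range(100).repeat()
  # ...with a pipeline-specific seed, so each pipeline sees a different order.
  seed = BASE_SEED + input_context.input_pipeline_id
  dataset = dataset.shuffle(buffer_size=100, seed=seed)
  return dataset.batch(batch_size, drop_remainder=True)
```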

The dataset_fn should take a tf.distribute.InputContext instance, through which information about batching and input replication can be accessed.
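The following sketch shows a dataset_fn that does its own sharding and batching from the InputContext fields. It gives each input pipeline (usually one per worker) a disjoint shard and batches by the per-replica batch size. The data source, GLOBAL_BATCH_SIZE, and the function name are hypothetical. In practice you would pass the function to strategy.distribute_datasets_from_function. Here it is called directly with a hand-built InputContext that mimics worker 1 of 2 in a cluster with 4 replicas in sync.

```python
import tensorflow as tf

GLOBAL_BATCH_SIZE = 8  # hypothetical global batch size


def sharded_dataset_fn(input_context):
  # get_per_replica_batch_size raises ValueError if the global batch size
  # is not divisible by num_replicas_in_sync.
  batch_size = input_context.get_per_replica_batch_size(GLOBAL_BATCH_SIZE)
  dataset = tf.data.Dataset.range(64)
  # Keep every num_input_pipelines-th element, starting at this pipeline's id.
  dataset = dataset.shard(input_context.num_input_pipelines,
                          input_context.input_pipeline_id)
  return dataset.batch(batch_size, drop_remainder=True)


ctx = tf.distribute.InputContext(
    num_input_pipelines=2, input_pipeline_id=1, num_replicas_in_sync=4)
batches = [b.numpy().tolist() for b in sharded_dataset_fn(ctx)]
```

With 4 replicas in sync, the per-replica batch size is 8 / 4 = 2. Pipeline 1 receives only the odd elements of range(64), which gives 16 batches of two.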

You can use the element_spec property of the tf.distribute.DistributedDataset returned by this API to query the tf.TypeSpec of the elements returned by the iterator. This can be used to set the input_signature argument of a tf.function. See tf.distribute.DistributedDataset.element_spec for an example.
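The following sketch passes element_spec to input_signature. It runs the strategy as a single local worker, which is what happens when TF_CONFIG is unset. It assumes the global batch size of 4 divides evenly by the number of local replicas. The doubling step and all names are purely illustrative.

```python
import tensorflow as tf

# Create the strategy before making other TensorFlow calls.
strategy = tf.distribute.MultiWorkerMirroredStrategy()


def dataset_fn(input_context):
  batch_size = input_context.get_per_replica_batch_size(4)
  ds = tf.data.Dataset.from_tensor_slices(tf.range(16, dtype=tf.float32))
  return ds.batch(batch_size, drop_remainder=True)


dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)


# Fixing the signature avoids retracing when the per-replica structure is known.
@tf.function(input_signature=[dist_dataset.element_spec])
def double_step(per_replica_inputs):
  return strategy.run(lambda x: 2.0 * x, args=(per_replica_inputs,))


results = []
for batch in dist_dataset:
  per_replica = double_step(batch)
  # Concatenate this worker's per-replica outputs into one tensor.
  results.append(tf.concat(strategy.experimental_local_results(per_replica), axis=0))
```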

For more on the usage and properties of this method, refer to the tutorial on distributed input. For handling of the last partial batch, see the corresponding section of that tutorial.

dataset_fn A function taking a tf.distribute.InputContext instance and