

A distribution strategy for synchronous training on multiple workers.

Inherits From: MultiWorkerMirroredStrategy, Strategy

This strategy implements synchronous distributed training across multiple workers, each with potentially multiple GPUs. Similar to tf.distribute.MirroredStrategy, it replicates all variables and computations to each local device. The difference is that it uses a distributed collective implementation (e.g. all-reduce), so that multiple workers can work together.

You need to launch your program on each worker and configure cluster_resolver correctly. For example, if you are using tf.distribute.cluster_resolver.TFConfigClusterResolver, each worker needs to have its corresponding task_type and task_id set in the TF_CONFIG environment variable. An example TF_CONFIG on worker-0 of a two-worker cluster is:

TF_CONFIG = '{"cluster": {"worker": ["localhost:12345", "localhost:23456"]}, "task": {"type": "worker", "index": 0} }'
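As a minimal sketch of how such a value could be produced in Python, the helper below builds the JSON string for a given worker and exports it. The helper name make_tf_config and the WORKERS list are hypothetical and not part of TensorFlow; the addresses mirror the example above. It assumes the variable is set before the strategy is created, since that is when the default resolver is expected to read it.

```python
import json
import os

# Hypothetical two-worker cluster; addresses mirror the example above.
WORKERS = ["localhost:12345", "localhost:23456"]


def make_tf_config(task_index, workers=WORKERS):
    """Return the TF_CONFIG JSON string for the worker at task_index.

    Hypothetical helper for illustration; not a TensorFlow API.
    """
    if not 0 <= task_index < len(workers):
        raise ValueError(f"task_index {task_index} is out of range")
    return json.dumps({
        # The cluster spec must be identical on every worker.
        "cluster": {"worker": list(workers)},
        # Only the task index differs from worker to worker.
        "task": {"type": "worker", "index": task_index},
    })


# On worker-0: export the value before creating the strategy (assumption).
os.environ["TF_CONFIG"] = make_tf_config(0)
```

On worker-1 the same program would call make_tf_config(1), so only the task index changes while the cluster spec stays the same. In practice the variable is often set by the launcher or cluster manager rather than by the training script itself.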

Your program runs on each worker as-is. Note that collectives require every worker to participate. Both tf.distribute and non-tf.distribute APIs may use collectives internally. Checkpointing and saving are examples, because reading a tf.Variable with tf.VariableSynchronization.ON_READ all-reduces the value. It is therefore recommended to run exactly the same program on each worker. Dispatching based on the task_type or task_id of the worker is error-prone.

cluster_resolver.num_accelerators() determines the number of GPUs the strategy uses. If it is zero, the strategy uses the CPU. All workers need to use the same number of devices; otherwise the behavior is undefined.

This strategy is not intended for TPU. Use tf.distribute.TPUStrategy instead.

After setting up TF_CONFIG, using this strategy is similar to using tf.distribute.MirroredStrategy or tf.distribute.TPUStrategy.

strategy = tf.distribute.MultiWorkerMirroredStrategy()

with strategy.scope():
  model = tf.keras.Sequential([