A distribution strategy for synchronous training on multiple workers.

Inherits From: Strategy


This strategy implements synchronous distributed training across multiple workers, each with potentially multiple GPUs. Similar to tf.distribute.MirroredStrategy, it creates copies of all variables in the model on each device across all workers.

It uses CollectiveOps's implementation of multi-worker all-reduce to keep variables in sync. A collective op is a single op in the TensorFlow graph which can automatically choose an all-reduce algorithm in the TensorFlow runtime according to hardware, network topology and tensor sizes.
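To make "all-reduce" concrete, the sketch below simulates ring all-reduce, one classic all-reduce algorithm, in plain Python. It only illustrates the result a collective all-reduce produces (every worker ends up holding the element-wise sum of all workers' tensors). It is not TensorFlow's CollectiveOps implementation, the helper name `ring_all_reduce` is hypothetical, and it assumes the vector length divides evenly by the number of workers.

```python
from typing import List


def ring_all_reduce(values: List[List[int]]) -> List[List[int]]:
    """Simulate a sum all-reduce over n workers using the ring algorithm.

    values[w] is worker w's vector. Returns new lists in which every worker
    holds the element-wise sum of all workers' vectors.
    """
    n = len(values)
    size = len(values[0])
    if any(len(v) != size for v in values) or size % n != 0:
        raise ValueError("vectors must share a length divisible by n")
    chunk = size // n
    buf = [list(v) for v in values]  # copy so the inputs are not mutated

    def span(c: int) -> range:
        """Indices covered by chunk c (taken modulo n)."""
        c %= n
        return range(c * chunk, (c + 1) * chunk)

    def ring_step(chunk_for, combine):
        # Every worker w sends chunk_for(w) to its right neighbour (w + 1).
        # Snapshot all messages first so sends within a step are simultaneous.
        msgs = [(w, [buf[w][i] for i in span(chunk_for(w))]) for w in range(n)]
        for w, data in msgs:
            dst = (w + 1) % n
            for i, x in zip(span(chunk_for(w)), data):
                buf[dst][i] = combine(buf[dst][i], x)

    # Phase 1, reduce-scatter: afterwards worker w holds the full sum
    # of chunk (w + 1) % n.
    for step in range(n - 1):
        ring_step(lambda w: w - step, lambda old, new: old + new)

    # Phase 2, all-gather: circulate the fully reduced chunks.
    for step in range(n - 1):
        ring_step(lambda w: w + 1 - step, lambda old, new: new)

    return buf
```

Each worker sends 2(n-1) chunks, each 1/n of the tensor, so the per-worker traffic stays roughly constant as workers are added; this is why ring-style algorithms are a common choice for large tensors.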

By default, it uses all local GPUs (or the CPU, if there are no GPUs) for single-worker training.

When the 'TF_CONFIG' environment variable is set, the strategy parses cluster_spec, task_type and task_id from 'TF_CONFIG' and becomes a multi-worker strategy that mirrors models on the GPUs of all machines in the cluster. In the current implementation, it uses all GPUs in the cluster and assumes that all workers have the same number of GPUs.
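As a sketch of what 'TF_CONFIG' can contain: it is a JSON object whose "cluster" entry is the cluster spec (here a worker-only cluster) and whose "task" entry gives the task type and index of the current process. Every worker sets the same "cluster" value but its own "index", and does so before creating the strategy. The host names, ports, and the helper `make_tf_config` are hypothetical.

```python
import json
import os


def make_tf_config(worker_addresses, task_index):
    """Build a TF_CONFIG JSON string for one worker in a worker-only cluster."""
    if not 0 <= task_index < len(worker_addresses):
        raise ValueError("task_index must index into worker_addresses")
    config = {
        "cluster": {"worker": list(worker_addresses)},  # the cluster spec
        "task": {"type": "worker", "index": task_index},  # this process's role
    }
    return json.dumps(config)


# Hypothetical hosts and ports; replace them with your cluster's addresses.
workers = ["host1:12345", "host2:23456"]
# On the first machine; the second machine would use task_index=1.
os.environ["TF_CONFIG"] = make_tf_config(workers, task_index=0)
```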

You can also pass a distribute.cluster_resolver.ClusterResolver instance when instantiating the strategy. The task_type, task_id and other settings are then parsed from the resolver instance instead of from the TF_CONFIG environment variable.

It supports both eager mode and graph mode. However, in eager mode the strategy has to set up the eager context in its constructor, so all eager-mode ops must run after the strategy object is created.

Args

communication: Optional. An enum of type distribute.experimental.CollectiveCommunication that lets the user override the choice of collective op communication. Possible values include AUTO, RING, and NCCL.
cluster_resolver: Optional. A distribute.cluster_resolver.ClusterResolver object. The default ClusterResolver is the TFConfigClusterResolver, which is instantiated from the TF_CONFIG environment variable.

Attributes

cluster_resolver: Returns the cluster resolver associated with this strategy.

As a multi-worker strategy, tf.distribute.experimental.MultiWorkerMirroredStrategy provides the associated tf.distribute.cluster_resolver.ClusterResolver. If the user provides one in __init__, that instance is returned; if not, a default TFConfigClusterResolver is provided.

extended: tf.distribute.StrategyExtended with additional methods.
num_replicas_in_sync: Returns the number of replicas over which gradients are aggregated.
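Because the current implementation uses every GPU in the cluster and assumes each worker has the same number of GPUs, the replica count is a simple product. The sketch below additionally assumes that a worker with no GPUs contributes a single CPU replica. `expected_num_replicas` is a hypothetical helper, not part of the tf.distribute API; in real code, read `strategy.num_replicas_in_sync`.

```python
def expected_num_replicas(num_workers: int, gpus_per_worker: int) -> int:
    """Replica count when every worker has the same number of GPUs.

    Assumption: a worker without GPUs is counted as one CPU replica.
    """
    if num_workers < 1 or gpus_per_worker < 0:
        raise ValueError("need at least one worker and a non-negative GPU count")
    return num_workers * max(gpus_per_worker, 1)


# Two workers with four GPUs each aggregate gradients over eight replicas.
replicas = expected_num_replicas(num_workers=2, gpus_per_worker=4)
```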




Adds an annotation indicating that the tensor will be assigned to a logical device.

import tensorflow as tf

# Initializing TPU system with 2 logical devices and 4 replicas.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
device_assignment = tf.tpu.experimental.DeviceAssignment.build(
    topology,
    computation_shape=[1, 1, 1, 2],
    num_replicas=4)
strategy = tf.distribute.TPUStrategy(
    resolver, experimental_device_assignment=device_assignment)
# `inputs` is assumed to be an existing tf.data.Dataset.
iterator = iter(inputs)

@tf.function()
def step_fn(inputs):
  output = tf.add(inputs, inputs)

  # Add operation will be executed on logical device 0.
  output = strategy.experimental_assign_to_logical_device(output, 0)
  return output

strategy.run(step_fn, args=(next(iterator),))

Args

tensor: Input tensor to annotate.
logical_device_id: ID of the logical core to which the tensor will be assigned.

Raises

ValueError: The logical device ID is not consistent with the total number of partitions specified by the device assignment.

Returns

An annotated tensor with a value identical to tensor.



Creates a tf.distribute.DistributedDataset from the input dataset.

The returned tf.distribute.DistributedDataset can be iterated over similar to how regular datasets can. NOTE: The user cannot add any more transformations to a tf.distribute.DistributedDataset.

The following is an example:

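import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

# Illustrative value; choose a global batch size for your model.
GLOBAL_BATCH_SIZE = 64

# Create a dataset batched by the global batch size
dataset = tf.data.TFRecordDataset([
  "/a/1.tfr", "/a/2.tfr", "/a/3.tfr", "/a/4.tfr"]).batch(GLOBAL_BATCH_SIZE)

# Distribute that dataset
dist_dataset = strategy.experimental_distribute_dataset(dataset)

# Placeholder per-replica computation (hypothetical).
def replica_fn(x):
  return x

# Iterate over the `tf.distribute.DistributedDataset`
for x in dist_dataset:
  # process dataset elements
  strategy.run(replica_fn, args=(x,))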

In the code snippet above, the tf.distribute.DistributedDataset dist_dataset is batched by GLOBAL_BATCH_SIZE, and we iterate through it using for x in dist_dataset. x is a tf.distribute.DistributedValues containing data for all replicas, which together form a batch of GLOBAL_BATCH_SIZE. The strategy takes care of feeding the right per-replica data in x to the right replica_fn executed on each replica.
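The relationship between one global batch and the per-replica pieces held in x can be pictured with plain lists. This sketch splits a global batch into equal contiguous slices, one per replica. It assumes the global batch size is divisible by the number of replicas, and TensorFlow's actual assignment of elements to replicas is an implementation detail that may differ; `split_global_batch` is a hypothetical name.

```python
from typing import List, Sequence


def split_global_batch(batch: Sequence, num_replicas: int) -> List[list]:
    """Split one global batch into num_replicas equal contiguous slices."""
    if num_replicas < 1 or len(batch) % num_replicas != 0:
        raise ValueError("global batch size must be divisible by num_replicas")
    per_replica = len(batch) // num_replicas
    return [list(batch[i * per_replica:(i + 1) * per_replica])
            for i in range(num_replicas)]


GLOBAL_BATCH_SIZE = 8  # illustrative value
parts = split_global_batch(list(range(GLOBAL_BATCH_SIZE)), num_replicas=4)
# parts == [[0, 1], [2, 3], [4, 5], [6, 7]]; together they are the global batch.
```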

What happens under the hood of this method when we say the instance - dataset - gets distributed? It depends on how you set the through By default, it is set to In a multi-worker setting, we will first attempt to distribute dataset by detecting whether dataset is being created out of reader datasets (e.g.,, etc.) and if so, try to shard the input files. Note that there has to be at least one input file per worker. If you have fewer input files than workers, we suggest that you disable dataset sharding across workers, by setting the to be
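File-based sharding can be approximated as a round-robin split of the input file list, the same element selection that tf.data.Dataset.shard(num_shards, index) performs (keep every num_shards-th element, starting at index). The sketch below, built around the hypothetical helper `shard_files`, reuses the four file paths from the example above and rejects the case where there are fewer files than workers. TensorFlow's own auto-sharding logic may differ in its details.

```python
from typing import List, Sequence


def shard_files(files: Sequence[str], num_workers: int, worker_index: int) -> List[str]:
    """Files that worker_index reads under round-robin file sharding."""
    if not 0 <= worker_index < num_workers:
        raise ValueError("worker_index must be in [0, num_workers)")
    if len(files) < num_workers:
        # Some worker would get no file at all; disable file sharding instead.
        raise ValueError(
            f"{len(files)} files for {num_workers} workers: "
            "need at least one input file per worker")
    return list(files[worker_index::num_workers])


files = ["/a/1.tfr", "/a/2.tfr", "/a/3.tfr", "/a/4.tfr"]
worker0 = shard_files(files, num_workers=2, worker_index=0)
worker1 = shard_files(files, num_workers=2, worker_index=1)
```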