A state & compute distribution policy on a list of devices.

See the guide for an overview and examples. See tf.distribute.StrategyExtended and tf.distribute for a glossary of concepts mentioned on this page, such as "per-replica", "replica", and "reduce".

In short:

A custom training loop can be as simple as:

with my_strategy.scope():
  def distribute_train_epoch(dataset):
    def replica_fn(input):
      # process input and return result
      return result

    total_result = 0
    for x in dataset:
      per_replica_result = my_strategy.run(replica_fn, args=(x,))
      total_result += my_strategy.reduce(tf.distribute.ReduceOp.SUM,
                                         per_replica_result, axis=None)
    return total_result

  dist_dataset = my_strategy.experimental_distribute_dataset(dataset)
  for _ in range(EPOCHS):
    train_result = distribute_train_epoch(dist_dataset)

This takes an ordinary dataset and replica_fn and runs it distributed using a particular tf.distribute.Strategy named my_strategy above. Any variables created in replica_fn are created using my_strategy's policy, and library functions called by replica_fn can use the get_replica_context() API to implement distributed-specific behavior.

You can use the reduce API to aggregate results across replicas and use this as a return value from one iteration over a tf.distribute.DistributedDataset. Or you can use tf.keras.metrics (such as loss, accuracy, etc.) to accumulate metrics across steps in a given epoch.
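To make the aggregation concrete, the following plain-Python sketch mimics the arithmetic of a SUM reduce followed by accumulation over the steps of one epoch. It makes no TensorFlow calls: the helper `reduce_sum` and the hard-coded per-replica values are hypothetical stand-ins for `my_strategy.reduce` and for the outputs of `replica_fn` on an assumed two-replica strategy.

```python
# Plain-Python illustration only; no TensorFlow calls are made.
# Each tuple stands in for the per-replica results of one step on a
# hypothetical 2-replica strategy.
steps = [
    (1.0, 2.0),  # step 0: replica 0 returned 1.0, replica 1 returned 2.0
    (3.0, 4.0),  # step 1
    (5.0, 6.0),  # step 2
]


def reduce_sum(per_replica_values):
    """Hypothetical stand-in for a SUM reduce: collapses one step to one value."""
    return sum(per_replica_values)


total_result = 0.0
for per_replica_result in steps:
    # Mirrors `total_result += my_strategy.reduce(SUM, per_replica_result, ...)`.
    total_result += reduce_sum(per_replica_result)

print(total_result)
```

The reduce collapses the replica dimension within a step, and the running total accumulates across steps. A tf.keras.metrics object would typically take over the second role.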

See the custom training loop tutorial for a more detailed example.

cluster_resolver Returns the cluster resolver associated with this strategy.

In general, a multi-worker tf.distribute strategy such as tf.distribute.experimental.MultiWorkerMirroredStrategy or tf.distribute.experimental.TPUStrategy has an associated tf.distribute.cluster_resolver.ClusterResolver, and this property returns that instance.

Strategies that intend to have an associated tf.distribute.cluster_resolver.ClusterResolver must set the relevant attribute, or override this property; otherwise, None is returned by default. Those strategies should also provide information regarding what is returned by this property.

Single-worker strategies usually do not have a tf.distribute.cluster_resolver.ClusterResolver, and in those cases this property will return None.
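The "set an attribute or override the property, otherwise get None" contract described above can be pictured with a small plain-Python sketch. All class names and the `_cluster_resolver` attribute name here are hypothetical and chosen for illustration. This is not a description of how TensorFlow stores the resolver internally.

```python
# Hypothetical classes; they only illustrate the default-to-None pattern.
class BaseStrategy:
    @property
    def cluster_resolver(self):
        # Return the resolver if a subclass stored one, otherwise None.
        return getattr(self, "_cluster_resolver", None)


class SingleWorkerStrategy(BaseStrategy):
    """Never sets a resolver, so the property yields None."""


class MultiWorkerStrategy(BaseStrategy):
    def __init__(self, resolver):
        # Setting the attribute is enough; the property is inherited.
        self._cluster_resolver = resolver


fake_resolver = object()  # placeholder for a real ClusterResolver instance
print(SingleWorkerStrategy().cluster_resolver)
print(MultiWorkerStrategy(fake_resolver).cluster_resolver is fake_resolver)
```

Because the property can return None, callers should check for None before reading fields such as the task type.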

The tf.distribute.cluster_resolver.ClusterResolver may be useful when the user needs to access information such as the cluster spec, task type, or task id. For example:

import json
import os

os.environ['TF_CONFIG'] = json.dumps({
  'cluster': {
    'worker': ["localhost:12345", "localhost:23456"],
    'ps': ["localhost:34567"]
  },
  'task': {'type': 'worker', 'index': 0}
})

# This implicitly uses TF_CONFIG for the cluster and current task info.
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()


if strategy.cluster_resolver.task_type == 'worker':
  # Perform something that's only applicable on workers. Since we set this
  # as a worker above, this block will run on this particular instance.
  pass
elif strategy.cluster_resolver.task_type == 'ps':
  # Perform something that's only applicable on parameter servers. Since we
  # set this as a worker above, this block will not run on this particular
  # instance.
  pass

For more information, please see tf.distribute.cluster_resolver.ClusterResolver's API docstring.
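For orientation, the cluster spec, task type, and task id mentioned above all originate in the TF_CONFIG JSON. The sketch below reads them with the standard library only. It is not how a ClusterResolver is implemented, and the variable names (`cluster_spec`, `my_address`, `num_workers`) are illustrative assumptions.

```python
import json
import os

# Same TF_CONFIG layout as in the example above.
os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {
        "worker": ["localhost:12345", "localhost:23456"],
        "ps": ["localhost:34567"],
    },
    "task": {"type": "worker", "index": 0},
})

tf_config = json.loads(os.environ["TF_CONFIG"])
cluster_spec = tf_config["cluster"]    # job name -> list of addresses
task_type = tf_config["task"]["type"]  # job this process belongs to
task_id = tf_config["task"]["index"]   # position within that job

# Address of the current task and the number of workers in the cluster.
my_address = cluster_spec[task_type][task_id]
num_workers = len(cluster_spec.get("worker", []))

print(task_type, task_id, my_address, num_workers)
```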

extended tf.distribute.StrategyExtended with additional methods.
num_replicas_in_sync Returns number of replicas over which gradients are aggregated.




Adds an annotation that the tensor will be assigned to a logical device.

# Initializing TPU system with 2 logical devices and 4 replicas.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
device_assignment = tf.tpu.experimental.DeviceAssignment.build(
    topology,
    computation_shape=[1, 1, 1, 2],
    num_replicas=4)
strategy = tf.distribute.TPUStrategy(
    resolver, experimental_device_assignment=device_assignment)
iterator = iter(inputs)

def step_fn(inputs):
  output = tf.add(inputs, inputs)

  # Add operation will be executed on logical device 0.
  output = strategy.experimental_assign_to_logical_device(output, 0)
  return output

strategy.run(step_fn, args=(next(iterator),))

tensor Input tensor to annotate.
logical_device_id Id of the logical core to which the tensor will be assigned.

ValueError The logical device id presented is not consistent with total number of partitions specified by the device assignment.

Annotated tensor with a value identical to tensor.
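The ValueError condition can be sketched as a simple range check. The helper below is hypothetical and makes no TensorFlow calls. It assumes that the number of logical devices per replica equals the product of the entries of `computation_shape`, which matches the "2 logical devices" comment for `[1, 1, 1, 2]` in the example above.

```python
import math


def check_logical_device_id(logical_device_id, computation_shape):
    """Hypothetical helper: validate an id against a computation shape.

    Assumes the number of logical devices per replica is the product of
    the entries of `computation_shape`.
    """
    num_logical_devices = math.prod(computation_shape)
    if not 0 <= logical_device_id < num_logical_devices:
        raise ValueError(
            f"logical_device_id {logical_device_id} is not in "
            f"[0, {num_logical_devices})")
    return logical_device_id


# With computation_shape=[1, 1, 1, 2], ids 0 and 1 pass the check.
check_logical_device_id(0, [1, 1, 1, 2])
check_logical_device_id(1, [1, 1, 1, 2])

try:
    check_logical_device_id(2, [1, 1, 1, 2])
except ValueError as err:
    print("rejected:", err)
```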



Creates tf.distribute.DistributedDataset from a tf.data.Dataset.

The returned tf.distribute.DistributedDataset can be iterated over in the same way as a regular dataset. NOTE: The user cannot add any more transformations to a tf.distribute.DistributedDataset.

The following is an example:

strategy = tf.distribute.MirroredStrategy()

# Create a dataset
dataset = tf.data.TFRecordDataset([
  "/a/1.tfr", "/a/2.tfr", "/a/3.tfr", "/a/4.tfr"])

# Distribute that dataset
dist_dataset = strategy.experimental_distribute_dataset(dataset)

# Iterate ov