

A one-machine strategy that puts all variables on a single device.

Inherits From: Strategy


Variables are assigned to the local CPU or, if there is exactly one GPU, to that GPU. If there is more than one GPU, compute operations (other than variable update operations) are replicated across all GPUs.
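The placement rule above can be modelled in a few lines of plain Python. This is a minimal sketch that does not call TensorFlow: `default_placement` and the device strings are hypothetical, and it assumes that with no GPU both variables and compute fall back to the local CPU.

```python
# Hypothetical model of the documented default placement; it mirrors the prose
# above and does not call TensorFlow.
from typing import List, Tuple


def default_placement(gpus: List[str], cpu: str = "/cpu:0") -> Tuple[str, List[str]]:
    """Return (variable_device, compute_devices) for a list of local GPUs."""
    if len(gpus) == 1:
        # Exactly one GPU: variables and compute both live on it.
        return gpus[0], list(gpus)
    if not gpus:
        # Assumption: with no GPU, everything runs on the local CPU.
        return cpu, [cpu]
    # Several GPUs: variables stay on the CPU, compute is replicated per GPU.
    return cpu, list(gpus)


print(default_placement(["/gpu:0", "/gpu:1"]))  # ('/cpu:0', ['/gpu:0', '/gpu:1'])
```

The multi-GPU case is the one that distinguishes this strategy from mirroring: there is a single copy of each variable, while every GPU runs the compute.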

For example:

import tensorflow as tf

strategy = tf.distribute.experimental.CentralStorageStrategy()
# Create a dataset
ds = tf.data.Dataset.range(5).batch(2)
# Distribute that dataset
dist_dataset = strategy.experimental_distribute_dataset(ds)

with strategy.scope():
  @tf.function
  def train_step(val):
    return val + 1

  # Iterate over the distributed dataset
  for x in dist_dataset:
    # Process the dataset elements on each replica
    strategy.run(train_step, args=(x,))

Attributes:
cluster_resolver: Returns the cluster resolver associated with this strategy.

In general, when using a multi-worker tf.distribute strategy such as tf.distribute.experimental.MultiWorkerMirroredStrategy or tf.distribute.TPUStrategy, there is a tf.distribute.cluster_resolver.ClusterResolver associated with the strategy, and this property returns that instance.

Strategies that intend to have an associated tf.distribute.cluster_resolver.ClusterResolver must set the relevant attribute, or override this property; otherwise, None is returned by default. Those strategies should also provide information regarding what is returned by this property.

Single-worker strategies usually do not have a tf.distribute.cluster_resolver.ClusterResolver, and in those cases this property will return None.

The tf.distribute.cluster_resolver.ClusterResolver may be useful when the user needs to access information such as the cluster spec, task type or task id. For example,

import json
import os

import tensorflow as tf

os.environ['TF_CONFIG'] = json.dumps({
  'cluster': {
    'worker': ["localhost:12345", "localhost:23456"],
    'ps': ["localhost:34567"]
  },
  'task': {'type': 'worker', 'index': 0}
})

# This implicitly uses TF_CONFIG for the cluster and current task info.
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

if strategy.cluster_resolver.task_type == 'worker':
  # Perform something that's only applicable on workers. Since we set this
  # as a worker above, this block will run on this particular instance.
  pass
elif strategy.cluster_resolver.task_type == 'ps':
  # Perform something that's only applicable on parameter servers. Since we
  # set this as a worker above, this block will not run on this particular
  # instance.
  pass

For more information, please see tf.distribute.cluster_resolver.ClusterResolver's API docstring.
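To make the TF_CONFIG layout concrete, the sketch below reads the same JSON with only the standard library and extracts the cluster spec, task type and task id. `read_tf_config` is a hypothetical helper for illustration; it is not how TensorFlow's resolvers are implemented, and it ignores any TF_CONFIG fields other than `cluster` and `task`.

```python
import json
import os


def read_tf_config(env=os.environ):
    """Hypothetical helper: pull cluster and task info out of TF_CONFIG."""
    config = json.loads(env.get("TF_CONFIG", "{}"))
    cluster = config.get("cluster", {})
    task = config.get("task", {})
    # The same kind of information a ClusterResolver exposes:
    # the cluster spec, the task type and the task id.
    return cluster, task.get("type"), task.get("index")


example_env = {"TF_CONFIG": json.dumps({
    "cluster": {"worker": ["localhost:12345", "localhost:23456"],
                "ps": ["localhost:34567"]},
    "task": {"type": "worker", "index": 0},
})}
cluster, task_type, task_id = read_tf_config(example_env)
if task_type == "worker":
    print(f"worker {task_id} of {len(cluster['worker'])}")  # worker 0 of 2
```

An unset TF_CONFIG yields an empty cluster and no task, which loosely matches the single-worker case where `cluster_resolver` is None.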

extended: tf.distribute.StrategyExtended with additional methods.
num_replicas_in_sync: Returns the number of replicas over which gradients are aggregated.
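Why `num_replicas_in_sync` matters for aggregation can be shown with exact arithmetic. This plain-Python sketch assumes SUM aggregation across replicas, with each replica scaling its contribution by the global batch size (per-replica batch size times `num_replicas_in_sync`), the pattern that `tf.nn.compute_average_loss` is meant to support; the gradient values are made up.

```python
# Pure-Python illustration (no TensorFlow): scale by the global batch size on
# each replica, SUM across replicas, and you get the mean over the global batch.
from fractions import Fraction

per_replica_grads = [
    [Fraction(1), Fraction(2)],  # replica 0: made-up per-example gradients
    [Fraction(3), Fraction(6)],  # replica 1
]
num_replicas_in_sync = len(per_replica_grads)
per_replica_batch = len(per_replica_grads[0])
global_batch = per_replica_batch * num_replicas_in_sync

# Each replica divides by the *global* batch size, not its local batch size.
scaled = [sum(g) / global_batch for g in per_replica_grads]
aggregated = sum(scaled)  # SUM over the num_replicas_in_sync replicas

all_examples = [g for grads in per_replica_grads for g in grads]
assert aggregated == sum(all_examples) / len(all_examples)
print(aggregated)  # 3
```

Dividing by the local batch size instead would make the summed result `num_replicas_in_sync` times too large.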



distribute_datasets_from_function(dataset_fn, options=None)

Distributes tf.data.Dataset instances created by calls to dataset_fn.

The dataset_fn argument that users pass in is an input function that takes a tf.distribute.InputContext argument and returns a tf.data.Dataset instance. The dataset returned by dataset_fn is expected to be already batched by the per-replica batch size (i.e. the global batch size divided by the number of replicas in sync) and sharded; tf.distribute.Strategy.distribute_datasets_from_function does not batch or shard the tf.data.Dataset instance returned from the input function. dataset_fn is called on the CPU device of each worker, and each call generates a dataset from which every replica on that worker dequeues one batch of inputs (i.e. if a worker has two replicas, two batches are dequeued from the dataset every step).
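The batching arithmetic and the "one batch per local replica per step" rule can be simulated without TensorFlow. In this sketch `run_steps` is a hypothetical helper, and the setup (4 replicas in sync spread over 2 workers, a global batch of 8, and made-up data already sharded to this worker) is an assumption chosen for illustration.

```python
def run_steps(batches, replicas_on_worker, num_steps):
    """Hypothetical simulation: every step, each local replica dequeues one batch."""
    it = iter(batches)
    return [[next(it) for _ in range(replicas_on_worker)] for _ in range(num_steps)]


global_batch_size, num_replicas_in_sync = 8, 4  # e.g. 2 workers x 2 replicas
per_replica = global_batch_size // num_replicas_in_sync
data = list(range(8))  # made-up elements already sharded to this worker
# dataset_fn is expected to batch by the per-replica size itself.
batches = [data[i:i + per_replica] for i in range(0, len(data), per_replica)]
steps = run_steps(batches, replicas_on_worker=2, num_steps=2)
print(steps)  # [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]
```

Each step, this worker consumes per_replica * replicas_on_worker = 4 elements, which is its share of the global batch of 8.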

This method can be used for several purposes. First, it allows you to specify your own batching and sharding logic. (In contrast, tf.distribute.Strategy.experimental_distribute_dataset does batching and sharding for you.) For example, where experimental_distribute_dataset is unable to shard the input files, this method might be used to manually shard the dataset, avoiding the slow fallback behavior in experimental_distribute_dataset. In cases where the dataset is infinite, this sharding can be done by creating dataset replicas that differ only in their random seed.
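The seed-based sharding of an infinite dataset can be pictured with Python generators: every input pipeline runs the same code and only the seed changes. This is a plain-Python analogue, not tf.data; `infinite_pipeline`, `base_seed` and the "base seed plus pipeline id" scheme are assumptions for illustration.

```python
import itertools
import random


def infinite_pipeline(input_pipeline_id: int, base_seed: int = 1234):
    """Hypothetical infinite input stream; pipelines differ only in their seed."""
    rng = random.Random(base_seed + input_pipeline_id)
    while True:
        yield rng.random()


first = list(itertools.islice(infinite_pipeline(0), 3))
second = list(itertools.islice(infinite_pipeline(1), 3))
again = list(itertools.islice(infinite_pipeline(0), 3))
print(first == again, first == second)  # True False
```

Because each pipeline is deterministic given its seed, runs are reproducible while different workers still see different data.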

The dataset_fn should take a tf.distribute.InputContext instance, through which information about batching and input replication can be accessed.
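The shape of a typical dataset_fn (shard by input pipeline, then batch by the per-replica size) can be sketched over a Python list. `FakeInputContext` is a hypothetical stand-in that only borrows attribute and method names from tf.distribute.InputContext; it is not the real class, and the list slicing merely mimics the "keep every n-th element" behaviour of `Dataset.shard(num_shards, index)`.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class FakeInputContext:
    """Hypothetical stand-in using tf.distribute.InputContext attribute names."""
    num_input_pipelines: int
    input_pipeline_id: int
    num_replicas_in_sync: int

    def get_per_replica_batch_size(self, global_batch_size: int) -> int:
        if global_batch_size % self.num_replicas_in_sync:
            raise ValueError("global batch size must divide evenly across replicas")
        return global_batch_size // self.num_replicas_in_sync


def dataset_fn(ctx: FakeInputContext, data: List[int], global_batch_size: int):
    # Shard first: keep elements whose index % num_input_pipelines == pipeline id.
    shard = data[ctx.input_pipeline_id::ctx.num_input_pipelines]
    # Then batch by the per-replica (not global) batch size.
    size = ctx.get_per_replica_batch_size(global_batch_size)
    return [shard[i:i + size] for i in range(0, len(shard), size)]


ctx = FakeInputContext(num_input_pipelines=2, input_pipeline_id=1, num_replicas_in_sync=4)
print(dataset_fn(ctx, list(range(12)), global_batch_size=8))  # [[1, 3], [5, 7], [9, 11]]
```

Sharding before batching keeps each pipeline's batches disjoint from the other pipelines' batches.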

You can use the element_spec property of the tf.distribute.DistributedDataset returned by this API to query the tf.TypeSpec of the elements returned by the iterator. This can be used to set the input_signature property of a tf.function. See tf.distribute.DistributedDataset.element_spec for an example.

For more on the usage and properties of this method, refer to the tutorial on distributed input. If you are interested in how the last partial batch is handled, read the corresponding section of that tutorial.

Args:
dataset_fn: A function taking a tf.distribute.InputContext instance and returning a tf.data.Dataset.
options: tf.distribute.InputOptions used to control options on how this dataset is distributed.

Returns:
A tf.distribute.DistributedDataset.


experimental_distribute_dataset(dataset, options=None)

Distributes a tf.data.Dataset instance provided via dataset.

The returned dataset is a wrapped strategy dataset that creates a multi-device iterator under the hood. It prefetches the input data to the specified devices on the worker. The returned distributed dataset can be iterated over in the same way as a regular dataset.

For example:

import tensorflow as tf

strategy = tf.distribute.experimental.CentralStorageStrategy()  # with 1 CPU and 1 GPU
dataset = tf.data.Dataset.range(10).batch(2)
dist_dataset = strategy.experimental_distribute_dataset(dataset)
for x in dist_dataset:
  print(x)  # Prints the batches [0, 1], [2, 3], ...

Args:
dataset: tf.data.Dataset to be prefetched to device.
options: tf.distribute.InputOptions used to control options on how this dataset is distributed.

Returns:
A "distributed Dataset" that the caller can iterate over.
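Unlike distribute_datasets_from_function, this method handles batching for you: the input is typically batched by the global batch size, and each global batch is split across the replicas in sync. The sketch below models only that split in plain Python, for the evenly divisible, single-worker case; `split_global_batch` is hypothetical and ignores sharding across workers and partial batches.

```python
def split_global_batch(global_batch, num_replicas_in_sync):
    """Hypothetical sketch: split one global batch into per-replica slices."""
    size, rem = divmod(len(global_batch), num_replicas_in_sync)
    if rem:
        raise ValueError("this sketch only handles evenly divisible batches")
    return [global_batch[i * size:(i + 1) * size] for i in range(num_replicas_in_sync)]


global_batches = [[0, 1, 2, 3], [4, 5, 6, 7]]  # dataset batched by a global size of 4
for step, batch in enumerate(global_batches):
    print(step, split_global_batch(batch, num_replicas_in_sync=2))
# 0 [[0, 1], [2, 3]]
# 1 [[4, 5], [6, 7]]
```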


experimental_distribute_values_from_function(value_fn)

Generates tf.distribute.DistributedValues from value_fn.

Use this function to generate tf.distribute.DistributedValues to pass into run, reduce, or other methods that take distributed values when not using datasets.

Args:
value_fn: The function to run to generate values. It is called for each replica with tf.distribute.ValueContext as the sole argument. It must return a Tensor or a type that can be converted to a Tensor.

Returns:
A tf.distribute.DistributedValues containing a value for each replica.

Example usage:

  1. Return constant value per replica:
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
def value_fn(ctx):
  return tf.constant(1.)
distributed_values = (
    strategy.experimental_distribute_values_from_function(value_fn))
local_result = strategy.experimental_local_results(distributed_values)
# local_result:
# (<tf.Tensor: shape=(), dtype=float32, numpy=1.0>,
#  <tf.Tensor: shape=(), dtype=float32, numpy=1.0>)
  2. Distribute values in array based on replica_id:
import numpy as np
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
array_value = np.array([3., 2., 1.])
def value_fn(ctx):
  return array_value[ctx.replica_id_in_sync_group]
distributed_values = (
    strategy.experimental_distribute_values_from_function(value_fn))
local_result = strategy.experimental_local_results(distributed_values)
# local_result: (3.0, 2.0)
  3. Specify values using num_replicas_in_sync:
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
def value_fn(ctx):
  return ctx.num_replicas_in_sync
distributed_values = (
    strategy.experimental_distribute_values_from_function(value_fn))
local_result = strategy.experimental_local_results(distributed_values)
# local_result: (2, 2)
  4. Place values on devices and distribute:
import tensorflow as tf

strategy = tf.distribute.TPUStrategy()
worker_devices = strategy.extended.worker_devices
multiple_values = []
for i in range(strategy.num_replicas_in_sync):
  with tf.device(worker_devices[i]):
    multiple_values.append(tf.constant(1.0))

def value_fn(ctx):
  return multiple_values[ctx.replica_id_in_sync_group]

distributed_values = strategy.experimental_distribute_values_from_function(
    value_fn)
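The calling convention in the examples above (value_fn runs once per replica and receives that replica's context) can be simulated in plain Python. `FakeValueContext` and `distribute_values_from_function` are hypothetical stand-ins that borrow only the two ValueContext attribute names used above; returning a tuple loosely mirrors what experimental_local_results gives back on a single worker.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeValueContext:
    """Hypothetical stand-in with the two ValueContext attributes used above."""
    replica_id_in_sync_group: int
    num_replicas_in_sync: int


def distribute_values_from_function(value_fn, num_replicas_in_sync):
    # Call value_fn once per replica, passing that replica's context.
    return tuple(
        value_fn(FakeValueContext(i, num_replicas_in_sync))
        for i in range(num_replicas_in_sync))


array_value = [3.0, 2.0, 1.0]
print(distribute_values_from_function(
    lambda ctx: array_value[ctx.replica_id_in_sync_group], 2))  # (3.0, 2.0)
print(distribute_values_from_function(lambda ctx: ctx.num_replicas_in_sync, 2))  # (2, 2)
```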


experimental_local_results(value)

Returns the list of all local per-replica values contained in value.

In CentralStorageStrategy there is a single worker, so the returned values are all the values on that worker.

Args:
value: A value returned by run(), extended.call_for_each_replica(), or a variable created in scope.

Returns:
A tuple of values contained in value. If value represents a single value, this returns (value,).
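The return contract (a tuple of per-replica components, or a one-element tuple for a single value) fits in a few lines. `FakePerReplica` and `local_results` are hypothetical plain-Python stand-ins, not TensorFlow classes; they illustrate the contract only.

```python
class FakePerReplica:
    """Hypothetical container holding one value per local replica."""

    def __init__(self, values):
        self.values = tuple(values)


def local_results(value):
    # Distributed values unpack to their per-replica components;
    # anything else is treated as a single value and wrapped as (value,).
    if isinstance(value, FakePerReplica):
        return value.values
    return (value,)


print(local_results(FakePerReplica([1.0, 2.0])))  # (1.0, 2.0)
print(local_results(5))  # (5,)
```

Always returning a tuple lets callers iterate over the results without special-casing the single-replica case.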


gather(value, axis)

Gather value across replicas along axis to the current device.

Given a tf.distribute.DistributedValues or tf.Tensor-like object value, this API gathers and concatenates value across replicas along the axis-th dimension. The result is copied to the "current" device, which would typically be the CPU of the worker on which the program is running. For tf.distribute.TPUStrategy, it is the first TPU host. For multi-client MultiWorkerMirroredStrategy, it is the CPU of each worker.

This API can only be called in the cross-replica context. For a counterpart in the replica context, see tf.distribute.ReplicaContext.all_gather.
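The effect of gathering along different axes can be seen with nested lists. This is a stdlib-only sketch restricted to 2-D values; `gather_2d` is hypothetical and only models the concatenation, not device placement or the cross-replica context requirement.

```python
def gather_2d(per_replica, axis):
    """Hypothetical sketch: concatenate per-replica 2-D lists along axis 0 or 1."""
    if axis == 0:
        # Stack the rows of every replica's value one after another.
        return [row for value in per_replica for row in value]
    if axis == 1:
        # Join row r of every replica's value into one longer row.
        num_rows = len(per_replica[0])
        return [[x for value in per_replica for x in value[r]] for r in range(num_rows)]
    raise ValueError("this sketch only supports axis 0 or 1")


replica_0 = [[1, 2]]
replica_1 = [[3, 4]]
print(gather_2d([replica_0, replica_1], axis=0))  # [[1, 2], [3, 4]]
print(gather_2d([replica_0, replica_1], axis=1))  # [[1, 2, 3, 4]]
```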