

A distribution strategy for synchronous training on multiple workers.

Inherits From: MultiWorkerMirroredStrategy, Strategy


This strategy implements synchronous distributed training across multiple workers, each with potentially multiple GPUs. Similar to tf.distribute.MirroredStrategy, it replicates all variables and computations to each local device. The difference is that it uses a distributed collective implementation (e.g. all-reduce), so that multiple workers can work together.

You need to launch your program on each worker and configure cluster_resolver correctly. For example, if you are using tf.distribute.cluster_resolver.TFConfigClusterResolver, each worker needs its corresponding task_type and task_id set in the TF_CONFIG environment variable. An example TF_CONFIG on worker-0 of a two-worker cluster is:

TF_CONFIG = '{"cluster": {"worker": ["localhost:12345", "localhost:23456"]}, "task": {"type": "worker", "index": 0} }'
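Writing this JSON by hand for every worker is easy to get wrong. Below is a minimal sketch, assuming a plain Python launcher, that builds the TF_CONFIG string for any worker index; the helper make_tf_config and the WORKERS list are hypothetical names chosen for this example, and the addresses mirror the example above.

```python
import json

# Hypothetical two-worker cluster; addresses mirror the TF_CONFIG example above.
WORKERS = ["localhost:12345", "localhost:23456"]


def make_tf_config(worker_index, workers=WORKERS):
    """Return the TF_CONFIG JSON string for the worker at `worker_index`."""
    if not 0 <= worker_index < len(workers):
        raise ValueError(f"worker_index {worker_index} is out of range")
    config = {
        # Every worker sees the same cluster description...
        "cluster": {"worker": list(workers)},
        # ...but a different task index.
        "task": {"type": "worker", "index": worker_index},
    }
    return json.dumps(config)
```

On worker-0 you would run os.environ["TF_CONFIG"] = make_tf_config(0) before creating the strategy, and make_tf_config(1) on worker-1. Only the task index differs between workers; the cluster description must be identical everywhere.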

Your program runs on each worker as-is. Note that collectives require every worker to participate. Both tf.distribute and non-tf.distribute APIs may use collectives internally; checkpointing and saving are examples, because reading a tf.Variable with tf.VariableSynchronization.ON_READ all-reduces its value. It is therefore recommended to run exactly the same program on each worker; dispatching based on the worker's task_type or task_id is error-prone.
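To see why every worker must participate, here is a toy simulation in plain Python. It is not how TensorFlow implements collectives: each "worker" is a thread that contributes a value and then waits on a threading.Barrier until all workers have arrived, the way a sum all-reduce must. The function name simulated_all_reduce and the timeout value are illustrative assumptions.

```python
import threading


def simulated_all_reduce(values, participants=None, timeout=0.5):
    """Toy sum all-reduce: worker i contributes values[i], then waits for all others.

    `participants` lists the worker indices that actually join the collective
    (default: all of them). Returns {worker_index: reduced_value_or_exception}.
    """
    num_workers = len(values)
    if participants is None:
        participants = list(range(num_workers))
    contributions = [None] * num_workers
    results = {}
    # The barrier releases only once *all* num_workers parties have arrived.
    barrier = threading.Barrier(num_workers, timeout=timeout)

    def worker(i):
        contributions[i] = values[i]
        try:
            barrier.wait()
            # Past the barrier, every contribution is in place.
            results[i] = sum(contributions)
        except threading.BrokenBarrierError as err:
            # A missing worker means the collective can never complete.
            results[i] = err

    threads = [threading.Thread(target=worker, args=(i,)) for i in participants]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results
```

With all three workers, simulated_all_reduce([1, 2, 3]) gives every worker the same sum. If one worker skips the call, for instance because the program branches on task_id, the others fail with BrokenBarrierError in this simulation; in a real cluster they would block waiting for the missing worker.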

cluster_resolver.num_accelerators() determines the number of GPUs the strategy uses. If it's zero, the strategy uses the CPU. All workers need to use the same number of devices, otherwise the behavior is undefined.
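Assuming every worker has the same device count, as required above, the total number of replicas is the number of workers times the devices per worker, with a CPU-only worker contributing a single replica. This is the quantity the num_replicas_in_sync attribute reports. A plain-Python sketch of that arithmetic follows; the helper expected_num_replicas is hypothetical and not part of tf.distribute.

```python
def expected_num_replicas(num_workers, gpus_per_worker):
    """Replica count for a homogeneous cluster: one replica per GPU,
    or one per worker when the workers have no GPUs (CPU fallback)."""
    if num_workers < 1 or gpus_per_worker < 0:
        raise ValueError("need at least one worker and a non-negative GPU count")
    devices_per_worker = gpus_per_worker if gpus_per_worker > 0 else 1
    return num_workers * devices_per_worker
```

For example, two workers with four GPUs each give eight replicas, while three CPU-only workers give three.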

This strategy is not intended for TPU. Use tf.distribute.TPUStrategy instead.

After setting up TF_CONFIG, using this strategy is similar to using tf.distribute.MirroredStrategy and tf.distribute.TPUStrategy.

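import numpy as np
import tensorflow as tf

strategy = tf.distribute.MultiWorkerMirroredStrategy()

with strategy.scope():
  # Variables created inside the scope are mirrored across all replicas.
  model = tf.keras.Sequential([
    tf.keras.layers.Dense(2, input_shape=(5,)),
  ])
  optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

def dataset_fn(ctx):
  # ctx is a tf.distribute.InputContext; this toy dataset does not use it.
  x = np.random.random((2, 5)).astype(np.float32)
  y = np.random.randint(2, size=(2, 1))
  dataset = tf.data.Dataset.from_tensor_slices((x, y))
  return dataset.repeat().batch(1, drop_remainder=True)
dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)

model.compile(
    optimizer=optimizer,
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
# The dataset repeats forever, so bound each epoch explicitly.
model.fit(dist_dataset, epochs=1, steps_per_epoch=2)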


You can also write your own training loop:

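# Number of training steps; 10 is an arbitrary choice for this example.
NUM_STEP = 10
iterator = iter(dist_dataset)

@tf.function
def train_step(iterator):

  def step_fn(inputs):
    features, labels = inputs
    with tf.GradientTape() as tape:
      logits = model(features, training=True)
      loss = tf.keras.losses.sparse_categorical_crossentropy(
          labels, logits, from_logits=True)

    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

  # Run step_fn on every replica with that replica's part of the next batch.
  strategy.run(step_fn, args=(next(iterator),))

for _ in range(NUM_STEP):
  train_step(iterator)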

See Multi-worker training with Keras for a detailed tutorial.


Saving

You need to save and checkpoint on all workers instead of just one, because saving a variable with synchronization=ON_READ triggers an aggregation, which is a collective that every worker must join. It's recommended to save to a different path on each worker to avoid race conditions; each worker saves the same content. See the Multi-worker training with Keras tutorial for examples.
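One way to give each worker its own path is sketched below, loosely modeled on the pattern used in the Multi-worker training with Keras tutorial. The helper names is_chief and write_filepath and the workertemp_ directory prefix are hypothetical, and the chief rule assumes a cluster with only "worker" tasks, as in the TF_CONFIG example above; with a dedicated "chief" task you would test task_type == "chief" instead.

```python
import os


def is_chief(task_type, task_id):
    """Assumes only "worker" tasks: worker 0 plays the chief.
    task_type is None when running without a cluster (single machine)."""
    return task_type is None or (task_type == "worker" and task_id == 0)


def write_filepath(filepath, task_type, task_id):
    """Real path on the chief; a per-worker subdirectory on every other worker,
    so concurrent saves never write to the same files."""
    if is_chief(task_type, task_id):
        return filepath
    dirpath, base = os.path.split(filepath)
    return os.path.join(dirpath, f"workertemp_{task_id}", base)
```

Every worker still calls save, which keeps the collectives matched; only the destination differs. The task type and id can come, for example, from strategy.cluster_resolver.task_type and strategy.cluster_resolver.task_id, and the non-chief copies can be deleted after saving.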

Known Issues

Args

communication optional tf.distribute.experimental.CommunicationImplementation. This is a hint on the preferred collective communication implementation. Possible values include AUTO, RING, and NCCL.
cluster_resolver optional tf.distribute.cluster_resolver.ClusterResolver. If None, tf.distribute.cluster_resolver.TFConfigClusterResolver is used.

Attributes

cluster_resolver Returns the cluster resolver associated with this strategy.

As a multi-worker strategy, tf.distribute.MultiWorkerMirroredStrategy provides the associated tf.distribute.cluster_resolver.ClusterResolver. If the user provides one in __init__, that instance is returned; if the user does not, a default TFConfigClusterResolver is provided.

extended tf.distribute.StrategyExtended with additional methods.
num_replicas_in_sync Returns the number of replicas over which gradients are aggregated.



Methods

distribute_datasets_from_function(dataset_fn, options=None)

Distributes datasets created by calls to dataset_fn.

The dataset_fn argument is an input function that takes a tf.distribute.InputContext argument and returns a dataset instance. The dataset returned from dataset_fn is expected to be already batched by the per-replica batch size (i.e. the global batch size divided by the number of replicas in sync) and sharded; tf.distribute.Strategy.distribute_datasets_from_function does not batch or shard the dataset returned from the input function. dataset_fn is called on the CPU device of each worker, and each call generates a dataset from which every replica on that worker dequeues one batch of inputs per step (i.e. if a worker has two replicas, two batches are dequeued from its dataset every step).
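The batch-size bookkeeping described above can be checked with simple arithmetic. Inside a real dataset_fn, tf.distribute.InputContext.get_per_replica_batch_size(global_batch_size) performs the division for you; the plain-Python helpers below are hypothetical and only spell out the numbers.

```python
def per_replica_batch_size(global_batch_size, num_replicas_in_sync):
    """Batch size dataset_fn should batch by: the global size split evenly across replicas."""
    if global_batch_size % num_replicas_in_sync != 0:
        raise ValueError(
            f"global batch size {global_batch_size} is not divisible by "
            f"{num_replicas_in_sync} replicas")
    return global_batch_size // num_replicas_in_sync


def examples_per_worker_step(global_batch_size, num_replicas_in_sync, replicas_on_worker):
    """Each replica on a worker dequeues one per-replica batch from that worker's dataset."""
    return replicas_on_worker * per_replica_batch_size(global_batch_size, num_replicas_in_sync)
```

With a global batch size of 64 on two workers with two replicas each (four replicas in sync), dataset_fn should batch by 16, and each worker consumes two batches, 32 examples, per step.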

This method can be used for several purposes. First, it allows you to specify your own batching and sharding logic. (In contrast, tf.distribute.Strategy.experimental_distribute_dataset does batching and sharding for you.) For example, where experimental_distribute_dataset is unable to shard the input files, this method can be used to shard the dataset manually, avoiding the slow fallback behavior in experimental_distribute_dataset. If the dataset is infinite, this sharding can be done by creating dataset replicas that differ only in their random seed.
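Both manual strategies can be sketched in plain Python. Inside a real dataset_fn you would take the pipeline count and index from the InputContext attributes num_input_pipelines and input_pipeline_id, and apply the file split with Dataset.shard(num_shards, index), which keeps every num_shards-th element starting at index. The helper names and the base_seed + pipeline-id seeding rule are assumptions for illustration.

```python
import random


def shard_files(filenames, num_input_pipelines, input_pipeline_id):
    """Round-robin file assignment: the same selection Dataset.shard makes on a file list."""
    return filenames[input_pipeline_id::num_input_pipelines]


def infinite_stream(base_seed, input_pipeline_id, n):
    """First n samples of an 'infinite' random stream; only the seed differs per pipeline."""
    rng = random.Random(base_seed + input_pipeline_id)
    return [rng.random() for _ in range(n)]
```

With five files and two pipelines, pipeline 0 reads part-0, part-2, and part-4 while pipeline 1 reads part-1 and part-3, so no file is read twice. The seeded streams are reproducible per pipeline but differ across pipelines.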

The dataset_fn should take a tf.distribute.InputContext instance, from which information about batching and input replication can be accessed.

You can use the element_spec property of the tf.distribute.DistributedDataset returned by this API to query the tf.TypeSpec of the elements returned by the iterator. This can be used to set the input_signature property of a tf.function. See tf.distribute.DistributedDataset.element_spec for an example.

For more on the usage and properties of this method, refer to the tutorial on distributed input. If you are interested in last partial batch handling, see the corresponding section of that tutorial.

Args

dataset_fn A function taking a tf.distribute.InputContext instance and returning a dataset.
options tf.distribute.InputOptions used to control options on how this dataset is distributed.

Returns

A tf.distribute.DistributedDataset.


experimental_distribute_dataset(dataset, options=None)

Creates a tf.distribute.DistributedDataset from a dataset.

The returned tf.distribute.DistributedDataset can be iterated over like a regular dataset. NOTE: The user cannot add any more transformations to a tf.distribute.DistributedDataset. You can only create an iterator or examine the tf.TypeSpec of the data generated by it. See API docs of tf.distribute.DistributedDataset to learn more.

The following is an example:

import tensorflow as tf

global_batch_size = 2
# Passing the devices is optional.
strategy = tf.distribute.MirroredStrategy(devices=["GPU:0", "GPU:1"])
# Create a dataset: [0, 1] and [2, 3], batched by the global batch size
dataset = tf.data.Dataset.range(4).batch(global_batch_size)
# Distribute that dataset
dist_dataset = strategy.experimental_distribute_dataset(dataset)

@tf.function
def replica_fn(inputs):
  return inputs * 2

result = []
# Iterate over the `tf.distribute.DistributedDataset`
for x in dist_dataset:
  # process dataset elements
  result.append(strategy.run(replica_fn, args=(x,)))
print(result)
# Output with two GPUs:
# [PerReplica:{
#   0: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([0])>,
#   1: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([2])>
# }, PerReplica:{
#   0: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([4])>,
#   1: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([6])>
# }]
Three key actions happening under the hood of this method are batching, sharding, and prefetching.

In the code snippet above, dataset is batched by global_batch_size, and calling experimental_distribute_dataset on it rebatches dataset to a new batch size equal to the global batch size divided by the number of replicas in sync. We iterate through it using a Pythonic for loop. x is a tf.distribute.DistributedValues containing data for all replicas, and each replica gets data of the new batch size. Running replica_fn through the strategy feeds the right per-replica data in x to the replica_fn executed on each replica.
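The rebatching step can be mimicked with plain Python lists to check the printed output above. This sketch is not TensorFlow code: it assumes each global batch divides evenly across replicas (partial batches are ignored) and that the split is contiguous along the batch dimension; split_global_batch is a hypothetical helper.

```python
def split_global_batch(batch, num_replicas):
    """Split one global batch into equal, contiguous per-replica batches."""
    if len(batch) % num_replicas != 0:
        raise ValueError("global batch must divide evenly across replicas")
    size = len(batch) // num_replicas
    return {r: batch[r * size:(r + 1) * size] for r in range(num_replicas)}


def replica_fn(x):
    # Same computation as the TensorFlow example, on a list.
    return [v * 2 for v in x]


global_batch_size = 2
elements = list(range(4))
# Equivalent of range(4).batch(2): [[0, 1], [2, 3]]
dataset = [elements[i:i + global_batch_size]
           for i in range(0, len(elements), global_batch_size)]
result = [{r: replica_fn(x) for r, x in split_global_batch(batch, 2).items()}
          for batch in dataset]
print(result)  # [{0: [0], 1: [2]}, {0: [4], 1: [6]}]
```

Each global batch of two elements becomes two per-replica batches of one element, which matches the two PerReplica values printed by the TensorFlow example.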

Sharding includes autosharding across multiple workers and within each worker. First, in multi-worker distributed training (i.e. when you use tf.distribute.experimental.MultiWorkerMirroredStrategy or tf.distribute.TPUStrategy), autosharding a dataset over a set of workers means that each worker is assigned a subset of the entire dataset (if the right auto-shard policy is set). This ensures that at each step the workers together process a global batch of non-overlapping dataset elements. Autosharding supports a couple of different options. Then, sharding within each worker means the method splits the data among all of the worker's devices (if more than one is present). This happens regardless of multi-worker autosharding.

By default, this method adds a prefetch transformation at the end of the user-provided dataset. The prefetch buffer_size argument is equal to the number of replicas in sync.

If the above batch splitting and dataset sharding logic is undesirable, please use tf.distribute.Strategy.distribute_datasets_from_function instead, which does not do any automatic batching or sharding for you.

For more on the usage and properties of this method, refer to the tutorial on distributed input. If you are interested in last partial batch handling, see the corresponding section of that tutorial.

Args

dataset A dataset that will be sharded across all replicas using the rules stated above.
options tf.distribute.InputOptions used to control options on how this dataset is distributed.

Returns

A tf.distribute.DistributedDataset.


experimental_distribute_values_from_function(value_fn)

Generates tf.distribute.DistributedValues from value_fn.

This method generates tf.distribute.DistributedValues to pass into run, reduce, or other methods that take distributed values, when you are not using datasets.
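Conceptually, value_fn is called once per replica and the results are gathered into one distributed value keyed by replica. The plain-Python simulation below shows only that calling pattern. FakeValueContext is a hypothetical stand-in for the context object value_fn receives, and its two field names are assumptions for this sketch.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeValueContext:
    """Hypothetical stand-in for the per-replica context passed to value_fn."""
    replica_id_in_sync_group: int
    num_replicas_in_sync: int


def distribute_values(value_fn, num_replicas):
    """Call value_fn once per replica and collect the results by replica id."""
    return {r: value_fn(FakeValueContext(r, num_replicas)) for r in range(num_replicas)}
```

For example, distribute_values(lambda ctx: ctx.replica_id_in_sync_group * 10, 2) yields {0: 0, 1: 10}, giving each replica its own starting value without building a dataset.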

Args

value_fn The function to run to generate values. It is called for each replica with tf.distribute.Value