
Represents a dataset distributed among devices and machines.

A tf.distribute.DistributedDataset can be thought of as a "distributed" dataset. When you use the tf.distribute API to scale training to multiple devices or machines, you also need to distribute the input data. This gives you a tf.distribute.DistributedDataset instance instead of the tf.data.Dataset instance you would use in the non-distributed case. In TF 2.x, tf.distribute.DistributedDataset objects are Python iterables.

There are two APIs to create a tf.distribute.DistributedDataset object: tf.distribute.Strategy.experimental_distribute_dataset(dataset) and tf.distribute.Strategy.experimental_distribute_datasets_from_function(dataset_fn). When should you use which? If you have a tf.data.Dataset instance, and both the regular batch splitting (i.e. re-batching the input tf.data.Dataset instance with a new batch size equal to the global batch size divided by the number of replicas in sync) and autosharding (i.e. the tf.data.experimental.AutoShardPolicy options) work for you, use the former API. Otherwise, if you are not using a canonical tf.data.Dataset instance, or you would like to customize the batch splitting or sharding, wrap that logic in a dataset_fn and use the latter API. Both APIs handle prefetching to the device for you. For more details and examples, see the documentation of each API.
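The batch-splitting and sharding arithmetic described above can be modeled in plain Python, without TensorFlow. This is a minimal sketch under stated assumptions: it assumes element-wise sharding, where element i goes to input pipeline i % num_input_pipelines (the semantics of tf.data.Dataset.shard), and a divisibility check like the one in tf.distribute.InputContext.get_per_replica_batch_size. The helpers per_replica_batch_size, shard, and batch are hypothetical illustrations, not the tf.distribute implementation; a real dataset_fn would call the tf.data and InputContext APIs instead.

```python
from typing import List, Sequence


def per_replica_batch_size(global_batch_size: int, num_replicas_in_sync: int) -> int:
    """Split the global batch size evenly among the replicas training in sync."""
    if global_batch_size % num_replicas_in_sync != 0:
        raise ValueError(
            f"global batch size {global_batch_size} is not divisible by "
            f"{num_replicas_in_sync} replicas in sync")
    return global_batch_size // num_replicas_in_sync


def shard(elements: Sequence[int], num_shards: int, index: int) -> List[int]:
    """Keep every num_shards-th element, starting at position `index`."""
    return [e for i, e in enumerate(elements) if i % num_shards == index]


def batch(elements: Sequence[int], batch_size: int) -> List[List[int]]:
    """Group consecutive elements; the last group may be smaller."""
    return [list(elements[i:i + batch_size])
            for i in range(0, len(elements), batch_size)]


# Hypothetical setup: 2 workers (one input pipeline each), 2 replicas per worker.
num_workers = 2
num_replicas_in_sync = num_workers * 2
global_batch_size = 8
examples = list(range(16))

batch_size = per_replica_batch_size(global_batch_size, num_replicas_in_sync)  # 8 // 4 == 2
worker_batches = {
    worker: batch(shard(examples, num_workers, worker), batch_size)
    for worker in range(num_workers)
}
```

Each worker sees half of the examples and hands its replicas batches of 2, so one synchronized step consumes 2 workers × 2 replicas × 2 examples = 8 examples, which is the global batch size.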

There are two main ways to use a DistributedDataset object:

  1. Iterate over it to generate the input for a single device or multiple devices, which is a tf.distribute.DistributedValues instance. To do this, you can:

    • use a Pythonic for-loop construct:
    import tensorflow as tf

    global_batch_size = 2
    strategy = tf.distribute.MirroredStrategy()
    dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(4).batch(global_batch_size)
    dist_dataset = strategy.experimental_distribute_dataset(dataset)

    def train_step(input):
      features, labels = input
      return labels - 0.3 * features

    for x in dist_dataset:
      # strategy.run calls train_step on each replica with its part of the batch
      loss = strategy.run(train_step, args=(x,))
      print("Loss is", loss)
    # Output with a single device:
    # Loss is tf.Tensor(
    # [[0.7]
    #  [0.7]], shape=(2, 1), dtype=float32)
    # Loss is tf.Tensor(
    # [[0.7]
    #  [0.7]], shape=(2, 1), dtype=float32)
Placing the loop inside a tf.function will give a performance boost. However, break and return statements are currently not supported if the loop is placed inside a tf.function. We also don't support placing the loop inside a tf.function when using tf.distribute.experimental.MultiWorkerMirroredStrategy or tf.distribute.experimental.TPUStrategy with multiple workers.

    • use iter() to create an explicit iterator, which is of type tf.distribute.DistributedIterator:
    import tensorflow as tf

    global_batch_size = 4
    strategy = tf.distribute.MirroredStrategy()
    train_dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(50).batch(global_batch_size)
    train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)

    def distributed_train_step(dataset_inputs):
      def train_step(input):
        loss = tf.constant(0.1)
        return loss
      per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
      # Add the per-replica losses together into a single value
      return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

    EPOCHS = 2
    STEPS = 3
    for epoch in range(EPOCHS):
      total_loss = 0.0
      num_batches = 0
      dist_dataset_iterator = iter(train_dist_dataset)
      for _ in range(STEPS):
        total_loss += distributed_train_step(next(dist_dataset_iterator))
        num_batches += 1
      average_train_loss = total_loss / num_batches
      template = "Epoch {}, Loss: {}"
      print(template.format(epoch + 1, average_train_loss))
    # Output with a single device:
    # Epoch 1, Loss: 0.10000000894069672
    # Epoch 2, Loss: 0.10000000894069672

To achieve a further performance improvement, you can also wrap the strategy.run call in a tf.range loop inside a tf.function. This runs multiple steps in a single tf.function call, and AutoGraph converts the loop into a tf.while_loop on the worker. However, this is less flexible than running a single step inside a tf.function: for example, you cannot run things eagerly or execute arbitrary Python code within the steps.

  2. Inspect the tf.TypeSpec of the data generated by DistributedDataset.

    tf.distribute.DistributedDataset generates tf.distribute.DistributedValues as input to the devices. If you pass the input to a tf.function and would like to specify the shape and type of each Tensor argument to the function, you can pass a tf.TypeSpec object to the input_signature argument of the tf.function. To get the tf.TypeSpec of the input, you can use the element_spec property of the tf.distribute.DistributedDataset or tf.distribute.DistributedIterator object.

    For example:

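  import tensorflow as tf

  global_batch_size = 2
  epochs = 1
  steps_per_epoch = 1
  mirrored_strategy = tf.distribute.MirroredStrategy()
  dataset = tf.data.Dataset.from_tensors(([2.])).repeat(100).batch(global_batch_size)
  dist_dataset = mirrored_strategy.experimental_distribute_dataset(dataset)

  # Use the distributed dataset's element_spec as the tf.function input signature
  @tf.function(input_signature=[dist_dataset.element_spec])
  def train_step(per_replica_inputs):
    def step_fn(inputs):
      return tf.square(inputs)
    return mirrored_strategy.run(step_fn, args=(per_replica_inputs,))

  for _ in range(epochs):
    iterator = iter(dist_dataset)
    for _ in range(steps_per_epoch):
      output = train_step(next(iterator))
      print(output)
  # Output with a single device:
  # tf.Tensor(
  # [[4.]
  #  [4.]], shape=(2, 1), dtype=float32)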

Visit the tutorial on distributed input for more examples and caveats.
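In the explicit-iterator example, strategy.reduce(tf.distribute.ReduceOp.SUM, ...) adds the per-replica values together before the epoch average is taken, so the printed loss depends on the number of replicas. The plain-Python sketch below models only that arithmetic: each replica's loss is a Python float, and the helper epoch_average is hypothetical, not part of tf.distribute. Python floats are 64-bit, so the digits differ slightly from the float32 output shown earlier.

```python
from typing import List


def epoch_average(per_step_replica_losses: List[List[float]]) -> float:
    """Sum each step's per-replica losses across replicas, then average over steps."""
    total_loss = 0.0
    num_batches = 0
    for replica_losses in per_step_replica_losses:
        total_loss += sum(replica_losses)  # models ReduceOp.SUM with axis=None
        num_batches += 1
    return total_loss / num_batches


STEPS = 3
# Every replica returns a constant loss of 0.1, as the toy train_step does.
one_replica = epoch_average([[0.1]] * STEPS)
two_replicas = epoch_average([[0.1, 0.1]] * STEPS)
print(one_replica, two_replicas)
```

With two replicas the reported value roughly doubles to 0.2. If you want a value that does not grow with the replica count, tf.distribute.ReduceOp.MEAN averages across replicas instead of summing.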

element_spec: The type specification of an element of this tf.distribute.DistributedDataset.

import tensorflow as tf

global_batch_size = 16
strategy = tf.distribute.MirroredStrategy()
dataset = tf.data.Dataset.from_tensors(([1.], [2])).repeat(100).batch(global_batch_size)
dist_dataset = strategy.experimental_distribute_dataset(dataset)
print(dist_dataset.element_spec)
# (TensorSpec(shape=(None, 1), dtype=tf.float32, name=None),
#  TensorSpec(shape=(None, 1), dtype=tf.int32, name=None))

The above example corresponds to the case where you have only one device. If you have two devices, for example,

strategy = tf.distribute.MirroredStrategy(['/gpu:0', '/gpu:1'])

Then the final line will print out:

(PerReplicaSpec(TensorSpec(shape=(None, 1), dtype=tf.float32, name=None),
TensorSpec(shape=(None, 1), dtype=tf.float32, name=None)),
PerReplicaSpec(TensorSpec(shape=(None, 1), dtype=tf.int32, name=None),
TensorSpec(shape=(None, 1), dtype=tf.int32, name=None)))
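The nesting pattern in these two outputs can be sketched in plain Python: with one replica each component spec appears as-is, and with several replicas each component is wrapped in a PerReplicaSpec that holds one spec per replica. The namedtuples TensorSpecModel and PerReplicaSpecModel and the helper distributed_element_spec are hypothetical stand-ins for illustration, not tf.TensorSpec or the real PerReplicaSpec class.

```python
from collections import namedtuple

# Hypothetical stand-ins for tf.TensorSpec and PerReplicaSpec.
TensorSpecModel = namedtuple("TensorSpecModel", ["shape", "dtype"])
PerReplicaSpecModel = namedtuple("PerReplicaSpecModel", ["specs"])


def distributed_element_spec(component_specs, num_replicas):
    """Keep component specs unchanged for one replica; otherwise wrap each
    component in a per-replica spec holding one copy per replica."""
    if num_replicas == 1:
        return tuple(component_specs)
    return tuple(PerReplicaSpecModel(specs=(spec,) * num_replicas)
                 for spec in component_specs)


# Mirrors the (features, labels) element from the example above.
components = (TensorSpecModel((None, 1), "float32"),
              TensorSpecModel((None, 1), "int32"))
single = distributed_element_spec(components, 1)
double = distributed_element_spec(components, 2)
```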



__iter__()

Creates an iterator for the tf.distribute.DistributedDataset.

The returned iterator implements the Python Iterator protocol.
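Because the iterator follows the standard protocol, each next() call yields one distributed element, and a finite dataset eventually raises StopIteration. A dataset built with .repeat() and no count never does, so loops over it need an explicit step count. The toy classes below are hypothetical plain-Python stand-ins, not tf.distribute.DistributedDataset or tf.distribute.DistributedIterator. They model each distributed element as a tuple holding one slice of the global batch per replica, and they assume every global batch divides evenly among the replicas.

```python
from typing import List, Sequence, Tuple


class ToyDistributedIterator:
    """Hypothetical stand-in: yields each global batch split into per-replica slices."""

    def __init__(self, global_batches: Sequence[List[int]], num_replicas: int):
        self._batches = iter(global_batches)
        self._num_replicas = num_replicas

    def __iter__(self) -> "ToyDistributedIterator":
        return self

    def __next__(self) -> Tuple[List[int], ...]:
        global_batch = next(self._batches)  # raises StopIteration when exhausted
        size = len(global_batch) // self._num_replicas  # assumes even division
        return tuple(global_batch[r * size:(r + 1) * size]
                     for r in range(self._num_replicas))


class ToyDistributedDataset:
    """Hypothetical stand-in: every iter() call starts a fresh pass over the data."""

    def __init__(self, global_batches: Sequence[List[int]], num_replicas: int):
        self._global_batches = list(global_batches)
        self._num_replicas = num_replicas

    def __iter__(self) -> ToyDistributedIterator:
        return ToyDistributedIterator(self._global_batches, self._num_replicas)


dist = ToyDistributedDataset([[1, 2, 3, 4], [5, 6, 7, 8]], num_replicas=2)
it = iter(dist)
first = next(it)
second = next(it)
```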

Example usage:

import tensorflow as tf

global_batch_size = 4
strategy = tf.distribute.MirroredStrategy()
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4]).repeat().batch(global_batch_size)
distributed_iterator = iter(strategy.experimental_distribute_dataset(dataset))
print(next(distributed_iterator))
# tf.Tensor([1 2 3 4], shape=(4,), dtype=int32)

The above example corresponds to the case where you have only one device. If you have two d