

An iterator over tf.distribute.DistributedDataset.

tf.distribute.DistributedIterator is the primary mechanism for enumerating elements of a tf.distribute.DistributedDataset. It supports the Python Iterator protocol, which means it can be iterated over using a for-loop or by fetching individual elements explicitly via get_next().

You can create a tf.distribute.DistributedIterator explicitly by calling iter() on a tf.distribute.DistributedDataset, or implicitly by writing a Python for-loop over a tf.distribute.DistributedDataset.
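To illustrate the two access patterns described above (a for-loop versus iter() plus explicit get_next()), here is a minimal pure-Python sketch. ToyDistributedDataset and ToyDistributedIterator are hypothetical stand-ins, not TensorFlow classes, and splitting each global batch into contiguous per-replica slices of size global_batch_size // num_replicas is a simplifying assumption. The toy signals the end of the data with StopIteration, whereas the real get_next() raises tf.errors.OutOfRangeError.

```python
from typing import List, Sequence, Tuple


class ToyDistributedIterator:
    """Hypothetical stand-in for tf.distribute.DistributedIterator.

    Each step takes the next global batch and splits it into contiguous
    per-replica slices of a fixed size (a simplification, not TF's code).
    """

    def __init__(self, global_batches: List[List[int]], per_replica: int, num_replicas: int):
        self._batches = iter(global_batches)
        self._per_replica = per_replica
        self._num_replicas = num_replicas

    def get_next(self) -> Tuple[List[int], ...]:
        # At the end, next() raises StopIteration; TF raises OutOfRangeError instead.
        batch = next(self._batches)
        n = self._per_replica
        return tuple(batch[r * n:(r + 1) * n] for r in range(self._num_replicas))

    # Python iterator protocol: __iter__ returns self, __next__ delegates to get_next.
    def __iter__(self) -> "ToyDistributedIterator":
        return self

    def __next__(self) -> Tuple[List[int], ...]:
        return self.get_next()


class ToyDistributedDataset:
    """Hypothetical stand-in for tf.distribute.DistributedDataset."""

    def __init__(self, values: Sequence[int], global_batch_size: int, num_replicas: int):
        self._values = list(values)
        self._global_batch_size = global_batch_size
        self._num_replicas = num_replicas

    def __iter__(self) -> ToyDistributedIterator:
        # Every call to iter() starts a fresh pass over the data.
        size = self._global_batch_size
        batches = [self._values[i:i + size] for i in range(0, len(self._values), size)]
        return ToyDistributedIterator(batches, size // self._num_replicas, self._num_replicas)


dist_dataset = ToyDistributedDataset(range(6), global_batch_size=4, num_replicas=2)

# 1) Implicit: a for-loop calls iter() and then __next__() until StopIteration.
for per_replica in dist_dataset:
    print(per_replica)
# ([0, 1], [2, 3])
# ([4, 5], [])

# 2) Explicit: create the iterator with iter() and fetch elements with get_next().
iterator = iter(dist_dataset)
print(iterator.get_next())  # ([0, 1], [2, 3])
```

Making __iter__ return self is what lets the same object serve both a for-loop and repeated get_next() calls.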

Visit the tutorial on distributed input for more examples and caveats.

element_spec: The type specification of an element of a tf.distribute.DistributedIterator. For example:

import tensorflow as tf

global_batch_size = 16
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
dataset = tf.data.Dataset.from_tensors(([1.], [2])).repeat(100).batch(global_batch_size)
distributed_iterator = iter(strategy.experimental_distribute_dataset(dataset))
print(distributed_iterator.element_spec)
# (PerReplicaSpec(TensorSpec(shape=(None, 1), dtype=tf.float32, name=None),
#                 TensorSpec(shape=(None, 1), dtype=tf.float32, name=None)),
#  PerReplicaSpec(TensorSpec(shape=(None, 1), dtype=tf.int32, name=None),
#                 TensorSpec(shape=(None, 1), dtype=tf.int32, name=None)))
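The leading None in shape=(None, 1) is the batch dimension: each element has shape (1,), and because batch() is called without drop_remainder=True, tf.data leaves that dimension unknown. The arithmetic below, a plain-Python sketch with a hypothetical helper global_batch_sizes, shows why that matters here: 100 elements in global batches of 16 leave a smaller final batch.

```python
def global_batch_sizes(num_elements: int, batch_size: int) -> list:
    """Batch sizes produced by batching num_elements items without drop_remainder."""
    full, remainder = divmod(num_elements, batch_size)
    return [batch_size] * full + ([remainder] if remainder else [])


num_replicas = 2
sizes = global_batch_sizes(100, 16)  # 100 repeated elements, global_batch_size = 16
print(sizes)                         # [16, 16, 16, 16, 16, 16, 4]
print(16 // num_replicas)            # 8 elements per replica for a full global batch
# Batch sizes differ, so one static batch dimension could not describe them all.
print(len(set(sizes)) > 1)           # True
```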



get_next()

Returns the next input from the iterator for all replicas.

Example use:

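import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
# Global batches of two consecutive int64 values: [0, 1], [2, 3], ...
dataset = tf.data.Dataset.range(100).batch(2)
dist_dataset = strategy.experimental_distribute_dataset(dataset)
dist_dataset_iterator = iter(dist_dataset)

def one_step(input):
  return input

step_num = 5
for _ in range(step_num):
  strategy.run(one_step, args=(dist_dataset_iterator.get_next(),))

# The sixth global batch, [10, 11], is split across the two replicas.
print(strategy.experimental_local_results(dist_dataset_iterator.get_next()))
# (<tf.Tensor: shape=(1,), dtype=int64, numpy=array([10])>,
#  <tf.Tensor: shape=(1,), dtype=int64, numpy=array([11])>)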

Returns: A single tf.Tensor or a tf.distribute.DistributedValues that contains the next input for all replicas.

Raises: tf.errors.OutOfRangeError if the end of the iterator has been reached.
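The [10] and [11] in the example output follow from simple bookkeeping, assuming the dataset yields 0, 1, 2, ... in global batches of 2 over two replicas: five steps consume [0, 1] through [8, 9], so the next global batch is [10, 11], one element per replica. The pure-Python sketch below models that and the end-of-data behavior. make_get_next and OutOfRangeError are hypothetical stand-ins for illustration, not TensorFlow code, and the contiguous per-replica split is an assumption.

```python
from typing import Callable, List, Tuple


class OutOfRangeError(Exception):
    """Hypothetical stand-in for tf.errors.OutOfRangeError."""


def make_get_next(values: List[int], global_batch_size: int,
                  num_replicas: int) -> Callable[[], Tuple[List[int], ...]]:
    """Return a get_next() function over contiguous per-replica slices (assumed split)."""
    per_replica = global_batch_size // num_replicas
    position = 0

    def get_next() -> Tuple[List[int], ...]:
        nonlocal position
        if position >= len(values):
            raise OutOfRangeError("End of sequence")
        batch = values[position:position + global_batch_size]
        position += global_batch_size
        return tuple(batch[r * per_replica:(r + 1) * per_replica]
                     for r in range(num_replicas))

    return get_next


# Values 0, 1, 2, ... in global batches of 2, split across 2 replicas.
get_next = make_get_next(list(range(100)), global_batch_size=2, num_replicas=2)
for _ in range(5):        # five steps consume [0, 1] ... [8, 9]
    get_next()
print(get_next())         # ([10], [11])

# Once the data is exhausted, get_next() raises instead of returning a value.
short = make_get_next([0, 1], global_batch_size=2, num_replicas=2)
print(short())            # ([0], [1])
try:
    short()
except OutOfRangeError:
    print("end of iterator reached")
```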


get_next_as_optional()

Returns a tf.experimental.Optional that contains the next value for all replicas.

If the tf.distribute.DistributedIterator has reached the end of the sequence, the returned tf.experimental.Optional will have no value.

Example usage:

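import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
global_batch_size = 4  # two elements per replica
steps_per_loop = 2
# Five elements: one full global batch, then a partial batch holding only [4].
dataset = tf.data.Dataset.range(5).batch(global_batch_size)
distributed_iterator = iter(
    strategy.experimental_distribute_dataset(dataset))

def step_fn(x):
  # train the model with inputs
  return x

@tf.function
def train_fn(distributed_iterator):
  for _ in tf.range(steps_per_loop):
    optional_data = distributed_iterator.get_next_as_optional()
    if not optional_data.has_value():
      break
    per_replica_results = strategy.run(step_fn, args=(optional_data.get_value(),))
    tf.print(strategy.experimental_local_results(per_replica_results))

train_fn(distributed_iterator)
# ([0 1], [2 3])
# ([4], [])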

Returns: A tf.experimental.Optional object representing the next value from the tf.distribute.DistributedIterator (if it has one) or no value.
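The commented output ([0 1], [2 3]) followed by ([4], []) reads as two replicas each receiving two elements of a four-element global batch, then a partial global batch in which only the first replica gets data. The plain-Python sketch below models that loop. ToyOptional and ToyIterator are hypothetical classes whose has_value()/get_value() names mirror tf.experimental.Optional; the contiguous per-replica split is an assumption, and steps_per_loop=3 is chosen so the third call finds no value and the loop exits through break rather than an exception.

```python
class ToyOptional:
    """Hypothetical stand-in for tf.experimental.Optional (method names mirror it)."""

    def __init__(self, *value):
        self._value = value[0] if value else None
        self._has_value = bool(value)

    def has_value(self):
        return self._has_value

    def get_value(self):
        if not self._has_value:
            raise ValueError("Optional has no value")
        return self._value


class ToyIterator:
    """Models get_next_as_optional with contiguous per-replica slices (assumption)."""

    def __init__(self, values, global_batch_size, num_replicas):
        self._batches = [values[i:i + global_batch_size]
                         for i in range(0, len(values), global_batch_size)]
        self._per_replica = global_batch_size // num_replicas
        self._num_replicas = num_replicas

    def get_next_as_optional(self):
        if not self._batches:
            return ToyOptional()  # end of sequence: an empty Optional, no exception
        batch = self._batches.pop(0)
        n = self._per_replica
        return ToyOptional(tuple(batch[r * n:(r + 1) * n] for r in range(self._num_replicas)))


def train_loop(iterator, steps_per_loop):
    results = []
    for _ in range(steps_per_loop):
        optional_data = iterator.get_next_as_optional()
        if not optional_data.has_value():
            break  # stop cleanly instead of catching an end-of-data exception
        results.append(optional_data.get_value())
    return results


print(train_loop(ToyIterator(list(range(5)), 4, 2), steps_per_loop=3))
# [([0, 1], [2, 3]), ([4], [])]
```

Checking has_value() before get_value() keeps the loop free of exception handling, which is the pattern the real get_next_as_optional() is meant to support inside a tf.function.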