tf.distribute.DistributedIterator

An iterator over tf.distribute.DistributedDataset.

tf.distribute.DistributedIterator is the primary mechanism for enumerating elements of a tf.distribute.DistributedDataset. It supports the Python Iterator protocol, which means it can be iterated over using a for-loop or by fetching individual elements explicitly via get_next().

You can create a tf.distribute.DistributedIterator either by calling iter() on a tf.distribute.DistributedDataset or implicitly, by writing a Python for-loop over a tf.distribute.DistributedDataset.
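The two ways of obtaining an iterator described above can be sketched as follows. This is a minimal illustration, not the canonical usage: the dataset contents, the batch size, and the variable names (dist_iterator, first_values, num_batches) are assumptions chosen for the example, and it assumes eager execution with whatever devices MirroredStrategy finds by default.

```python
import tensorflow as tf

# With no arguments, MirroredStrategy uses all visible GPUs,
# or falls back to a single CPU replica if there are none.
strategy = tf.distribute.MirroredStrategy()

# Illustrative dataset: 8 elements in global batches of 4, i.e. 2 batches.
dataset = tf.data.Dataset.range(8).batch(4)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

# Option 1: create the iterator explicitly and fetch one element.
dist_iterator = iter(dist_dataset)
first = dist_iterator.get_next()
# Unpack the distributed value into a tuple with one entry per local replica.
first_values = strategy.experimental_local_results(first)

# Option 2: a for-loop creates a fresh iterator behind the scenes.
num_batches = 0
for batch in dist_dataset:
    num_batches += 1
```

The explicit form is useful when you need to control exactly how many elements are consumed, while the for-loop form runs until the dataset is exhausted.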

Visit the tutorial on distributed input for more examples and caveats.

Attributes

element_spec

The type specification of an element of tf.distribute.DistributedIterator.

import tensorflow as tf

global_batch_size = 16
strategy = tf.distribute.MirroredStrategy()
dataset = tf.data.Dataset.from_tensors(([1.], [2])).repeat(100).batch(global_batch_size)
distributed_iterator = iter(strategy.experimental_distribute_dataset(dataset))
print(distributed_iterator.element_spec)
# Output with a single device:
# (TensorSpec(shape=(None, 1), dtype=tf.float32, name=None),
#  TensorSpec(shape=(None, 1), dtype=tf.int32, name=None))

The above example corresponds to the case where you have only one device. If you have two devices, for example,

strategy = tf.distribute.MirroredStrategy(['/gpu:0', '/gpu:1'])

then the final line prints:

(PerReplicaSpec(TensorSpec(shape=(None, 1), dtype=tf.float32, name=None),
                TensorSpec(shape=(None, 1), dtype=tf.float32, name=None)),
 PerReplicaSpec(TensorSpec(shape=(None, 1), dtype=tf.int32, name=None),
                TensorSpec(shape=(None, 1), dtype=tf.int32, name=None)))

Methods

get_next

Returns the next input from the iterator for all replicas.

Example use:

strategy = tf.distribute.MirroredStrategy()
dataset = tf.data.Dataset.range(100).batch(2)
dist_dataset = strategy.experimental_distribute_dataset(dataset)
dist_dataset_iterator = iter(dist_dataset)
@tf.function
def one_step(input):
  return input
step_num = 5