Represents a dataset distributed among devices and machines.
tf.distribute.DistributedDataset could be thought of as a "distributed"
dataset. When you use the tf.distribute API to scale training to multiple
devices or machines, you also need to distribute the input data, which leads
to a tf.distribute.DistributedDataset instance, instead of a
tf.data.Dataset instance in the non-distributed case. In TF 2.x,
tf.distribute.DistributedDataset objects are Python iterables.
There are two APIs to create a tf.distribute.DistributedDataset object:
tf.distribute.Strategy.experimental_distribute_dataset(dataset) and
tf.distribute.Strategy.distribute_datasets_from_function(dataset_fn).
When to use which? When you have a
tf.data.Dataset instance, and the
regular batch splitting (i.e. re-batching the input tf.data.Dataset
with a new batch size that is equal to the global batch size divided by the
number of replicas in sync) and autosharding (i.e. the
tf.data.experimental.AutoShardPolicy options) work for you, use the former
API. Otherwise, if you are not using a canonical tf.data.Dataset instance,
or you would like to customize the batch splitting or sharding, you can wrap
this logic in a dataset_fn and use the latter API. Both APIs handle
prefetching to the device for the user. For more details and examples, follow the
links to the APIs.
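As a minimal sketch of the difference between the two creation APIs (assuming a MirroredStrategy and a synthetic dataset; the range and batch sizes here are illustrative, not taken from the examples below):

import tensorflow as tf

global_batch_size = 8
strategy = tf.distribute.MirroredStrategy()

# Former API: pass a dataset batched by the *global* batch size; the strategy
# re-batches it per replica and applies autosharding for you.
dataset = tf.data.Dataset.range(16).batch(global_batch_size)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

# Latter API: build the per-replica pipeline yourself inside a dataset_fn,
# using the tf.distribute.InputContext to size and shard it.
def dataset_fn(input_context):
  per_replica_batch_size = input_context.get_per_replica_batch_size(
      global_batch_size)
  d = tf.data.Dataset.range(16).batch(per_replica_batch_size)
  return d.shard(input_context.num_input_pipelines,
                 input_context.input_pipeline_id)

dist_dataset_from_fn = strategy.distribute_datasets_from_function(dataset_fn)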
There are two main usages of a tf.distribute.DistributedDataset object:
Iterate over it to generate the input for a single device or multiple devices, which is a
tf.distribute.DistributedValues instance. To do this, you can:
- use a pythonic for-loop construct:
global_batch_size = 2