
TFDS and determinism


This document explains:

  • The TFDS guarantees on determinism
  • In which order does TFDS read examples
  • Various caveats and gotchas

Setup

Datasets

Some context is needed to understand how TFDS reads the data.

During generation, TFDS writes the original data into standardized .tfrecord files. For big datasets, multiple .tfrecord files are created, each containing multiple examples. We call each .tfrecord file a shard.

This guide uses imagenet2012, which has 1024 shards:

import re
import tensorflow_datasets as tfds

imagenet = tfds.builder('imagenet2012')

num_shards = imagenet.info.splits['train'].num_shards
num_examples = imagenet.info.splits['train'].num_examples
print(f'imagenet has {num_shards} shards ({num_examples} examples)')
imagenet has 1024 shards (1281167 examples)

Finding the dataset examples ids

You can skip to the following section if you only want to know about determinism.

Each dataset example is uniquely identified by an id (e.g. 'imagenet2012-train.tfrecord-01023-of-01024__32'). You can recover this id by setting read_config.add_tfds_id = True, which adds a 'tfds_id' key to each example dict yielded by the tf.data.Dataset.

In this tutorial, we use a small util, print_ex_ids, which prints the example ids of the dataset (converted to integers to be more human-readable).
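The code of this util is not part of this copy of the guide. Below is a minimal sketch of what print_ex_ids could look like, assuming the tfds_id format shown above (`<name>-<split>.tfrecord-<shard>-of-<num_shards>__<index>`) and that the printed integer is the example's global position: the number of examples in all preceding shards plus the index within its own shard. parse_tfds_id and id_to_int are hypothetical helper names introduced for this sketch.

```python
import re
from typing import Optional, Sequence, Tuple

# Assumed id format, e.g. 'imagenet2012-train.tfrecord-01023-of-01024__32'.
_TFDS_ID_RE = re.compile(
    r'^[\w.]+-(?P<split>\w+)\.tfrecord-(?P<shard>\d+)-of-\d+__(?P<index>\d+)$'
)


def parse_tfds_id(tfds_id: str) -> Tuple[str, int, int]:
  """Hypothetical helper: split a tfds_id into (split, shard, index in shard)."""
  match = _TFDS_ID_RE.match(tfds_id)
  if match is None:
    raise ValueError(f'Unexpected tfds_id format: {tfds_id!r}')
  return match['split'], int(match['shard']), int(match['index'])


def id_to_int(tfds_id: str, shard_lengths: Sequence[int]) -> int:
  """Global id = number of examples in all previous shards + index in this shard."""
  _, shard, index = parse_tfds_id(tfds_id)
  return sum(shard_lengths[:shard]) + index


def print_ex_ids(builder, *, take: int, skip: Optional[int] = None,
                 **as_dataset_kwargs) -> None:
  """Print the integer ids of the examples read from `builder`."""
  # Imported here so the pure-Python helpers above work without TFDS installed.
  import tensorflow_datasets as tfds

  read_config = as_dataset_kwargs.pop('read_config', None)
  if read_config is None:
    read_config = tfds.ReadConfig()
  read_config.add_tfds_id = True  # Adds 'tfds_id' (mutates the passed-in config).
  ds = builder.as_dataset(read_config=read_config, **as_dataset_kwargs)
  if skip:
    ds = ds.skip(skip)
  ds = ds.take(take)  # take=-1 reads the whole (sub)split.
  ids = []
  for ex in ds:
    tfds_id = ex['tfds_id'].numpy().decode('utf-8')
    split, _, _ = parse_tfds_id(tfds_id)
    shard_lengths = builder.info.splits[split].shard_lengths
    ids.append(id_to_int(tfds_id, shard_lengths))
  print(ids)
```

With this sketch, the first example of the second shard maps to 1251 when the first shard holds 1251 examples, which matches the ids printed in the rest of this guide.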

Determinism when reading

This section explains the determinism guarantees of tfds.load.

With shuffle_files=False (default)

By default, TFDS yields examples deterministically (shuffle_files=False):

# Same as: imagenet.as_dataset(split='train').take(20)
print_ex_ids(imagenet, split='train', take=20)
print_ex_ids(imagenet, split='train', take=20)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1251, 1252, 1253, 1254]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1251, 1252, 1253, 1254]

For performance, TFDS reads multiple shards at the same time using tf.data.Dataset.interleave. In this example, TFDS switches to the next shard after reading 16 examples (..., 14, 15, 1251, 1252, ...). More on interleave below.

Similarly, the subsplit API is also deterministic:

print_ex_ids(imagenet, split='train[67%:84%]', take=20)
print_ex_ids(imagenet, split='train[67%:84%]', take=20)
[858382, 858383, 858384, 858385, 858386, 858387, 858388, 858389, 858390, 858391, 858392, 858393, 858394, 858395, 858396, 858397, 859533, 859534, 859535, 859536]
[858382, 858383, 858384, 858385, 858386, 858387, 858388, 858389, 858390, 858391, 858392, 858393, 858394, 858395, 858396, 858397, 859533, 859534, 859535, 859536]

If you're training for more than one epoch, the above setup is not recommended, as all epochs will read the shards in the same order (so randomness is limited to the buffer size of ds = ds.shuffle(buffer)).
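To see why the randomness is bounded, here is a simplified pure-Python model of a streaming shuffle buffer. It illustrates the idea behind ds.shuffle and is not tf.data's implementation; shuffle_buffer is a hypothetical name. With a buffer of size b, the k-th emitted element can only come from the first b + k input elements, so an example stored near the end of the input can never be emitted near the start.

```python
import itertools
import random
from typing import Iterable, Iterator, TypeVar

T = TypeVar('T')


def shuffle_buffer(items: Iterable[T], buffer_size: int, seed: int) -> Iterator[T]:
  """Simplified model of a streaming shuffle buffer (the idea behind ds.shuffle)."""
  rng = random.Random(seed)
  it = iter(items)
  # Fill the buffer with the first `buffer_size` elements.
  buffer = list(itertools.islice(it, buffer_size))
  while buffer:
    i = rng.randrange(len(buffer))
    out = buffer[i]
    try:
      buffer[i] = next(it)  # Refill the emptied slot with the next input element.
    except StopIteration:
      buffer.pop(i)  # Input exhausted: shrink the buffer instead.
    yield out


# With shuffle_files=False, every epoch feeds the same input order (here 0..999).
out = list(shuffle_buffer(range(1000), buffer_size=10, seed=0))
# The k-th output was drawn from at most the first buffer_size + k inputs.
print(all(x < k + 10 for k, x in enumerate(out)))  # prints True
```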

With shuffle_files=True

With shuffle_files=True, shards are shuffled for each epoch, so reading is no longer deterministic.

print_ex_ids(imagenet, split='train', shuffle_files=True, take=20)
print_ex_ids(imagenet, split='train', shuffle_files=True, take=20)
[568017, 329050, 329051, 329052, 329053, 329054, 329056, 329055, 568019, 568020, 568021, 568022, 568023, 568018, 568025, 568024, 568026, 568028, 568030, 568031]
[43790, 43791, 43792, 43793, 43796, 43794, 43797, 43798, 43795, 43799, 43800, 43801, 43802, 43803, 43804, 43805, 43806, 43807, 43809, 43810]

See the recipe below to get deterministic file shuffling.

Determinism caveat: interleave args

Changing read_config.interleave_cycle_length, read_config.interleave_block_length will change the examples order.

TFDS relies on tf.data.Dataset.interleave to load only a few shards at once, which improves performance and reduces memory usage.

The example order is only guaranteed to be the same for fixed values of the interleave args. See the tf.data.Dataset.interleave documentation to understand what cycle_length and block_length correspond to.

  • cycle_length=16, block_length=16 (default, same as above):
print_ex_ids(imagenet, split='train', take=20)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1251, 1252, 1253, 1254]
  • cycle_length=3, block_length=2:
read_config = tfds.ReadConfig(
    interleave_cycle_length=3,
    interleave_block_length=2,
)
print_ex_ids(imagenet, split='train', read_config=read_config, take=20)
[0, 1, 1251, 1252, 2502, 2503, 2, 3, 1253, 1254, 2504, 2505, 4, 5, 1255, 1256, 2506, 2507, 6, 7]

In the second example, the dataset reads 2 (block_length=2) examples from a shard, then switches to the next shard. Every 2 * 3 (cycle_length=3) examples, it goes back to the first shard (shard0-ex0, shard0-ex1, shard1-ex0, shard1-ex1, shard2-ex0, shard2-ex1, shard0-ex2, shard0-ex3, shard1-ex2, shard1-ex3, shard2-ex2, ...).
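This cycle/block pattern can be reproduced with a small pure-Python model of a deterministic interleave over shards. interleave_order is a hypothetical helper, and the model assumes that, as in tf.data, an exhausted shard frees its slot in the cycle and the next unread shard is opened in that slot when the cycle comes back to it. It also assumes the first imagenet shards hold 1251 examples each (consistent with the ids 0, 1251 and 2502 printed above); real shard sizes vary slightly.

```python
import itertools
from typing import Iterator, Sequence


def interleave_order(shard_lengths: Sequence[int], cycle_length: int,
                     block_length: int) -> Iterator[int]:
  """Yield global example ids in the order a deterministic interleave reads them."""
  if cycle_length < 1 or block_length < 1:
    raise ValueError('cycle_length and block_length must be >= 1')
  starts = itertools.accumulate(shard_lengths, initial=0)
  # One iterator per shard, opened lazily in shard order.
  pending = (iter(range(start, start + n)) for start, n in zip(starts, shard_lengths))
  slots = [None] * cycle_length  # Shards currently open in the cycle.
  input_done = False
  while not (input_done and all(slot is None for slot in slots)):
    for i in range(cycle_length):
      if slots[i] is None:
        if input_done:
          continue
        slots[i] = next(pending, None)  # Open the next shard in this slot.
        if slots[i] is None:
          input_done = True
          continue
      # Read up to `block_length` examples from this shard, then move on.
      block = list(itertools.islice(slots[i], block_length))
      yield from block
      if len(block) < block_length:
        slots[i] = None  # Shard exhausted: free the slot.


# Same ids as the cycle_length=3, block_length=2 example above.
print(list(itertools.islice(interleave_order([1251] * 3, cycle_length=3, block_length=2), 20)))
```

Changing either argument changes the emitted order, which is why the example order is only stable for fixed interleave args.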

Subsplit and example order

Each example has an id 0, 1, ..., num_examples-1. The subsplit API selects a slice of examples (e.g. train[:x] selects 0, 1, ..., x-1).

However, within the subsplit, examples are not read in increasing id order (due to shards and interleave).

More specifically, ds.take(x) and split='train[:x]' are not equivalent!

This can be seen easily in the above interleave example where examples come from different shards.

print_ex_ids(imagenet, split='train', take=25)  # tfds.load(..., split='train').take(25)
print_ex_ids(imagenet, split='train[:25]', take=-1)  # tfds.load(..., split='train[:25]')
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1251, 1252, 1253, 1254, 1255, 1256, 1257, 1258, 1259]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]

After the first 16 (block_length) examples, .take(25) switches to the next shard, while train[:25] continues reading examples from the first shard.

Recipes

Get deterministic file shuffling

There are 2 ways to have deterministic shuffling:

  1. Setting the shuffle_seed. Note: this requires changing the seed at each epoch; otherwise, shards will be read in the same order across epochs.
read_config = tfds.ReadConfig(
    shuffle_seed=32,
)

# Deterministic order, different from the default shuffle_files=False above
print_ex_ids(imagenet, split='train', shuffle_files=True, read_config=read_config, take=22)
print_ex_ids(imagenet, split='train', shuffle_files=True, read_config=read_config, take=22)
[176411, 176412, 176413, 176414, 176415, 176416, 176417, 176418, 176419, 176420, 176421, 176422, 176423, 176424, 176425, 176426, 710647, 710648, 710649, 710650, 710651, 710652]
[176411, 176412, 176413, 176414, 176415, 176416, 176417, 176418, 176419, 176420, 176421, 176422, 176423, 176424, 176425, 176426, 710647, 710648, 710649, 710650, 710651, 710652]
  2. Using experimental_interleave_sort_fn: this gives full control over which shards are read and in which order, rather than relying on ds.shuffle order.
def _reverse_order(file_instructions):
  return list(reversed(file_instructions))

read_config = tfds.ReadConfig(
    experimental_interleave_sort_fn=_reverse_order,
)

# Last shard (01023-of-01024) is read first
print_ex_ids(imagenet, split='train', read_config=read_config, take=5)
[1279916, 1279917, 1279918, 1279919, 1279920]
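The two recipes can be combined so that the shard order is reproducible yet different at every epoch. A minimal sketch, assuming the file instructions can be treated as a plain list, as `_reverse_order` does; make_epoch_sort_fn is a hypothetical helper, and it shuffles with Python's random module rather than TFDS's own file shuffling, so the resulting order will not match the one produced by shuffle_seed.

```python
import random
from typing import Callable, List, Sequence, TypeVar

T = TypeVar('T')


def make_epoch_sort_fn(seed: int, epoch: int) -> Callable[[Sequence[T]], List[T]]:
  """Return a sort fn that orders shards reproducibly, differently per epoch."""
  def _sort_fn(file_instructions: Sequence[T]) -> List[T]:
    shuffled = list(file_instructions)  # Copy: do not mutate the input list.
    # random.Random hashes a str seed deterministically, so the same
    # (seed, epoch) pair always yields the same shard order.
    random.Random(f'{seed}-{epoch}').shuffle(shuffled)
    return shuffled
  return _sort_fn
```

For each epoch, one could then build `tfds.ReadConfig(experimental_interleave_sort_fn=make_epoch_sort_fn(32, epoch))` and load the dataset with it, keeping the default shuffle_files=False as in the `_reverse_order` example.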

Get deterministic preemptable pipeline

This one is more complicated. There is no easy, satisfactory solution.

  1. Without ds.shuffle and with deterministic shuffling, it should in theory be possible to count the examples that have been read and deduce which examples have been read within each shard (as a function of cycle_length, block_length and shard order). Then the skip/take for each shard could be injected through experimental_interleave_sort_fn.

  2. With ds.shuffle, it's likely impossible without replaying the full training pipeline. It would require saving the ds.shuffle buffer state to deduce which examples have been read. Examples could be non-contiguous (e.g. shard5_ex2 and shard5_ex4 read, but not shard5_ex3).

  3. With ds.shuffle, one way would be to save all shard_ids/example_ids read (deduced from tfds_id), then deduce the file instructions from that.
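The bookkeeping of option 3 can be sketched in pure Python: given the tfds_ids already consumed and the number of examples in each shard, compute every contiguous run of unread examples as a (filename, skip, take) triple. This is a minimal sketch under stated assumptions: remaining_reads is a hypothetical helper, ids are assumed to follow the `<filename>__<index>` format shown earlier, and turning the triples into actual reads (for example through experimental_interleave_sort_fn, as option 1 suggests) is not shown.

```python
from collections import defaultdict
from typing import Dict, Iterable, List, Set, Tuple


def remaining_reads(
    read_ids: Iterable[str], shard_lengths: Dict[str, int]
) -> List[Tuple[str, int, int]]:
  """Return (filename, skip, take) for every contiguous run of unread examples."""
  seen: Dict[str, Set[int]] = defaultdict(set)
  for tfds_id in read_ids:
    filename, index = tfds_id.rsplit('__', 1)
    seen[filename].add(int(index))

  reads = []
  for filename, length in shard_lengths.items():
    start = None  # Start of the current run of unread examples, if any.
    for i in range(length + 1):  # i == length closes a run ending at the shard end.
      unread = i < length and i not in seen[filename]
      if unread and start is None:
        start = i
      elif not unread and start is not None:
        reads.append((filename, start, i - start))
        start = None
  return reads
```

Because shuffled reading can leave holes (such as shard5_ex3 above), one shard may produce several triples.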

The simplest case for 1. is to have .skip(x).take(y) match train[x:x+y]. It requires:

  • Set cycle_length=1 (so shards are read sequentially)
  • Set shuffle_files=False
  • Do not use ds.shuffle

It should only be used on huge datasets where training lasts only 1 epoch. Examples are read in the default, unshuffled order.

read_config = tfds.ReadConfig(
    interleave_cycle_length=1,  # Read shards sequentially
)

print_ex_ids(imagenet, split='train', read_config=read_config, skip=40, take=22)
# If the job gets preempted, using the subsplit API will skip at most `len(shard0)`
print_ex_ids(imagenet, split='train[40:]', read_config=read_config, take=22)
[40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61]
[40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61]
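With this setup, resuming only requires remembering how many examples were consumed. A minimal sketch, assuming the counter is checkpointed to a small JSON file next to the model checkpoint; save_progress, load_progress and resume_split are hypothetical helpers, and the split is assumed to be a plain split name such as 'train' (not an already-sliced subsplit).

```python
import json
import os


def save_progress(path: str, num_consumed: int) -> None:
  """Write the number of examples consumed so far."""
  tmp_path = path + '.tmp'
  with open(tmp_path, 'w') as f:
    json.dump({'num_consumed': num_consumed}, f)
  os.replace(tmp_path, path)  # Avoid a half-written file if preempted mid-write.


def load_progress(path: str) -> int:
  """Return the saved count, or 0 when starting from scratch."""
  if not os.path.exists(path):
    return 0
  with open(path) as f:
    return json.load(f)['num_consumed']


def resume_split(split: str, num_consumed: int) -> str:
  """Subsplit that skips the examples already read, e.g. 'train[40:]'."""
  return f'{split}[{num_consumed}:]' if num_consumed else split
```

The returned string would then be passed as `split=` together with the `interleave_cycle_length=1` read config above.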

Find which shards/examples are read for a given subsplit

With the tfds.core.DatasetInfo, you have direct access to the read instructions.

imagenet.info.splits['train[44%:45%]'].file_instructions
[FileInstruction(filename='imagenet2012-train.tfrecord-00450-of-01024', skip=700, take=-1, num_examples=551),
 FileInstruction(filename='imagenet2012-train.tfrecord-00451-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00452-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00453-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00454-of-01024', skip=0, take=-1, num_examples=1252),
 FileInstruction(filename='imagenet2012-train.tfrecord-00455-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00456-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00457-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00458-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00459-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00460-of-01024', skip=0, take=1001, num_examples=1001)]
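The mapping from a percent slice to these instructions can be reproduced with some arithmetic. A minimal sketch, assuming percent boundaries are rounded to the nearest example (tie-breaking here follows Python's round, which may differ from TFDS) and then clipped against each shard's [offset, offset + length) range; percent_to_abs and slice_to_reads are hypothetical helpers, and they return explicit take counts where TFDS prints take=-1 for reads that run to the end of a shard.

```python
from typing import List, Sequence, Tuple


def percent_to_abs(percent: float, num_examples: int) -> int:
  """Convert a percent boundary (e.g. 44 in 'train[44%:45%]') to an example id."""
  return round(percent * num_examples / 100)


def slice_to_reads(shard_lengths: Sequence[int], start: int,
                   stop: int) -> List[Tuple[int, int, int]]:
  """Map ids [start, stop) to (shard_index, skip, take) for each overlapping shard."""
  reads = []
  offset = 0  # Global id of the first example in the current shard.
  for shard, length in enumerate(shard_lengths):
    lo, hi = max(start, offset), min(stop, offset + length)
    if lo < hi:
      reads.append((shard, lo - offset, hi - lo))
    offset += length
  return reads


num_examples = 1281167  # imagenet2012 train, as printed at the top of this guide.
start, stop = percent_to_abs(44, num_examples), percent_to_abs(45, num_examples)
print(start, stop, stop - start)  # 563713 576525 12812
```

The 12812 examples match the sum of the num_examples values listed above (551 + 8 × 1251 + 1252 + 1001).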