TFDS has always been framework-agnostic. For instance, you can easily load datasets in NumPy format for usage in JAX and PyTorch.
TensorFlow and its data loading solution (tf.data) are first-class citizens in our API by design.
We extended TFDS to support TensorFlow-less, NumPy-only data loading. This can be convenient for usage in ML frameworks such as JAX and PyTorch. Indeed, for users of these frameworks, TensorFlow can:
- reserve GPU/TPU memory;
- increase build time in CI/CD;
- take time to import at runtime.
TensorFlow is no longer a dependency for reading datasets.
ML pipelines need a data loader to load examples, decode them, and present them to the model. Data loaders use the "source/sampler/loader" paradigm:
TFDS dataset        ┌────────────────┐
  on disk           │                │
        ┌──────────►│      Data      │
│..│... │     │     │     source     ├─┐
├──┼────┴─────┤     │                │ │
│12│image12   │     └────────────────┘ │    ┌────────────────┐
├──┼──────────┤                        │    │                │
│13│image13   │                        ├───►│      Data      ├───► ML pipeline
├──┼──────────┤                        │    │     loader     │
│14│image14   │     ┌────────────────┐ │    │                │
├──┼──────────┤     │                │ │    └────────────────┘
│..│...       │     │     Index      ├─┘
                    │     sampler    │
                    │                │
                    └────────────────┘
- The data source is responsible for accessing and decoding examples from a TFDS dataset on the fly.
- The index sampler is responsible for determining the order in which records are processed. This is important for implementing global transformations (e.g., global shuffling, sharding, repeating for multiple epochs) before reading any records.
- The data loader orchestrates the loading by leveraging the data source and the index sampler. It enables performance optimizations (e.g., pre-fetching, multiprocessing or multithreading).
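As a rough illustration of these three roles, here is a toy pipeline in pure Python. It is a sketch, not TFDS code: ListDataSource, shuffled_indices and load are hypothetical names, the sampler is a plain seeded shuffle, and a thread pool stands in for the multithreading a real data loader might use.

```python
import random
from concurrent.futures import ThreadPoolExecutor


class ListDataSource:
    """Hypothetical data source: random access over an in-memory list."""

    def __init__(self, records):
        self._records = list(records)

    def __len__(self):
        return len(self._records)

    def __getitem__(self, index):
        # A real data source would read and decode the record from disk here.
        return {"id": index, "value": self._records[index]}


def shuffled_indices(num_records, seed):
    """Hypothetical index sampler: a global shuffle decided before any read."""
    indices = list(range(num_records))
    random.Random(seed).shuffle(indices)
    return indices


def load(source, indices, num_threads=4):
    """Hypothetical loader: reads records in parallel threads, keeping sampler order."""
    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        # Executor.map yields results in the order of `indices`.
        yield from pool.map(source.__getitem__, indices)


source = ListDataSource(["a", "b", "c", "d", "e"])
indices = shuffled_indices(len(source), seed=0)
for record in load(source, indices):
    print(record)
```

Because the sampler only handles integers, the shuffle order is fixed before any record is read, and the thread pool preserves that order even though reads run concurrently.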
TL;DR
tfds.data_source is an API to create data sources:
- for fast prototyping in pure-Python pipelines;
- to manage data-intensive ML pipelines at scale.
Setup
Let's install and import the needed dependencies:
!pip install array_record
!pip install grain-nightly
!pip install jax jaxlib
!pip install tfds-nightly
import os
os.environ.pop('TFDS_DATA_DIR', None)
import tensorflow_datasets as tfds
Data sources
Data sources are essentially Python sequences, so they need to implement the following protocol:
from typing import Any, Protocol, SupportsIndex
class RandomAccessDataSource(Protocol):
    """Interface for datasources where storage supports efficient random access."""
    def __len__(self) -> int:
        """Number of records in the dataset."""
    def __getitem__(self, key: SupportsIndex) -> Any:
        """Retrieves the record for the given key."""
The underlying file format needs to support efficient random access. At the moment, TFDS relies on array_record.
array_record is a new file format built on top of Riegeli that targets high I/O efficiency. In particular, ArrayRecord supports parallel reads and writes, random access by record index, and the same compression algorithms as Riegeli.
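To see why per-record indexing matters, here is a toy record file in pure Python. It is deliberately not the ArrayRecord format (which builds on Riegeli and adds compression and parallel I/O); write_records and read_record are hypothetical helpers. Keeping each record's byte offset lets a reader seek straight to record i, so fetching the last record costs the same as fetching the first.

```python
import io
import struct


def write_records(stream, records):
    """Writes length-prefixed records and returns the byte offset of each one."""
    offsets = []
    for record in records:
        offsets.append(stream.tell())
        stream.write(struct.pack("<I", len(record)))  # 4-byte little-endian length
        stream.write(record)
    return offsets


def read_record(stream, offsets, index):
    """Random access: seek directly to record `index` without scanning earlier ones."""
    stream.seek(offsets[index])
    (length,) = struct.unpack("<I", stream.read(4))
    return stream.read(length)


buffer = io.BytesIO()
offsets = write_records(buffer, [b"image12", b"image13", b"image14"])
print(read_record(buffer, offsets, 2))  # b'image14'
print(read_record(buffer, offsets, 0))  # b'image12'
```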
fashion_mnist is a common dataset for computer vision. To retrieve an ArrayRecord-based data source with TFDS, simply use:
ds = tfds.data_source('fashion_mnist')
Downloading and preparing dataset 29.45 MiB (download: 29.45 MiB, generated: 36.42 MiB, total: 65.87 MiB) to /home/kbuilder/tensorflow_datasets/fashion_mnist/3.0.1... Dataset fashion_mnist downloaded and prepared to /home/kbuilder/tensorflow_datasets/fashion_mnist/3.0.1. Subsequent calls will reuse this data.
tfds.data_source is a convenient wrapper. It is equivalent to:
builder = tfds.builder('fashion_mnist', file_format='array_record')
builder.download_and_prepare()
ds = builder.as_data_source()
This outputs a dictionary of data sources:
{
'train': DataSource(name=fashion_mnist, split='train', decoders=None),
'test': DataSource(name=fashion_mnist, split='test', decoders=None),
}
Once download_and_prepare has run and generated the record files, TensorFlow is no longer needed. Everything happens in Python/NumPy!
Let's check this by uninstalling TensorFlow and re-loading the data source in another subprocess:
!pip uninstall -y tensorflow
%%writefile no_tensorflow.py
import os
os.environ.pop('TFDS_DATA_DIR', None)
import tensorflow_datasets as tfds
try:
    import tensorflow as tf
except ImportError:
    print('No TensorFlow found...')
ds = tfds.data_source('fashion_mnist')
print('...but the data source could still be loaded...')
ds['train'][0]
print('...and the records can be decoded.')
Writing no_tensorflow.py
!python no_tensorflow.py
No TensorFlow found... ...but the data source could still be loaded... WARNING:absl:OpenCV is not installed. We recommend using OpenCV because it is faster according to our benchmarks. Defaulting to PIL to decode images... ...and the records can be decoded.
In future versions, we are also going to make the dataset preparation TensorFlow-free.
A data source has a length:
len(ds['train'])
60000
Accessing the first element of the dataset:
%%timeit
ds['train'][0]
WARNING:absl:OpenCV is not installed. We recommend using OpenCV because it is faster according to our benchmarks. Defaulting to PIL to decode images... 584 µs ± 2.11 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
...is just as cheap as accessing any other element. This is the definition of random access:
%%timeit
ds['train'][1000]
581 µs ± 2.33 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
Features now use NumPy DTypes (rather than TensorFlow DTypes). You can inspect the features with:
features = tfds.builder('fashion_mnist').info.features
You'll find more information about the features in our documentation. Here, we retrieve the shape of the images and the number of classes:
shape = features['image'].shape
num_classes = features['label'].num_classes
Use in pure Python
You can consume data sources in Python by iterating over them:
for example in ds['train']:
    print(example)
    break
{'image': array([[[ 0], [ 0], [ 0], ..., [ 0], [ 0], [ 0]], ..., [[ 0], [ 0], [ 0], ..., [ 0], [ 0], [ 0]]], dtype=uint8), 'label': 2}
If you inspect elements, you will also notice that all features are already decoded using NumPy. Behind the scenes, we use OpenCV by default because it is fast. If you don't have OpenCV installed, we default to Pillow to provide lightweight and fast image decoding.
{
'image': array([[[0], [0], ..., [0]],
[[0], [0], ..., [0]]], dtype=uint8),
'label': 2,
}
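Since every example is already a dict of NumPy values, batching in pure Python amounts to stacking. A minimal sketch, assuming NumPy is installed: fake_example is a hypothetical stand-in with the same keys, shape (28, 28, 1) and dtype as the record printed above, so the snippet runs without downloading anything.

```python
import numpy as np


def fake_example(index):
    """Hypothetical stand-in for ds['train'][index]: same keys, shape and dtype."""
    image = np.full((28, 28, 1), index % 256, dtype=np.uint8)
    return {"image": image, "label": index % 10}


def batches(get_example, num_records, batch_size):
    """Groups consecutive records into dicts of stacked NumPy arrays."""
    for start in range(0, num_records, batch_size):
        stop = min(start + batch_size, num_records)
        examples = [get_example(i) for i in range(start, stop)]
        yield {
            "image": np.stack([e["image"] for e in examples]),  # (batch, 28, 28, 1)
            "label": np.array([e["label"] for e in examples]),  # (batch,)
        }


for batch in batches(fake_example, num_records=10, batch_size=4):
    print(batch["image"].shape, batch["label"])
```

With the real data source, passing ds['train'].__getitem__ and len(ds['train']) should work the same way, since a data source supports len() and integer indexing.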
Use with PyTorch
PyTorch uses the source/sampler/loader paradigm. In Torch, "data sources" are called "datasets". torch.utils.data contains all the details you need to know to build efficient input pipelines in Torch.
TFDS data sources can be used as regular map-style datasets.
First we install and import Torch:
!pip install torch
from tqdm import tqdm
import torch
We already defined data sources for training and testing (respectively, ds['train'] and ds['test']). We can now define the sampler and the loaders:
batch_size = 128
train_sampler = torch.utils.data.RandomSampler(ds['train'], num_samples=5_000)
train_loader = torch.utils.data.DataLoader(
    ds['train'],
    sampler=train_sampler,
    batch_size=batch_size,
)
test_loader = torch.utils.data.DataLoader(
    ds['test'],
    sampler=None,
    batch_size=batch_size,
)
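With these settings, the training loader should yield 40 batches and the test loader 79, which matches the progress bars in the training output. A stdlib-only sketch of that arithmetic, using random.sample as a stand-in for RandomSampler without replacement (an assumption of this sketch: the exact indices torch draws will differ):

```python
import math
import random

num_train, num_test = 60_000, 10_000  # Fashion-MNIST train and test split sizes
batch_size = 128
num_samples = 5_000

# Stand-in for RandomSampler(..., num_samples=5_000): 5,000 distinct random
# indices drawn from the 60,000 training records.
train_indices = random.Random(0).sample(range(num_train), num_samples)

# The DataLoader groups the sampled indices into batches of 128; by default the
# last, smaller batch is kept.
train_batches = [train_indices[i:i + batch_size] for i in range(0, num_samples, batch_size)]
print(len(train_batches), len(train_batches[-1]))  # 40 8
print(math.ceil(num_test / batch_size))  # 79 (test loader, no sampler)
```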
Using PyTorch, we train a simple logistic regression on 5,000 randomly sampled training examples and evaluate it on the test set:
class LinearClassifier(torch.nn.Module):
    def __init__(self, shape, num_classes):
        super().__init__()
        height, width, channels = shape
        self.classifier = torch.nn.Linear(height * width * channels, num_classes)
    def forward(self, image):
        # Flatten each (height, width, channels) uint8 image into a float vector.
        image = image.view(image.size()[0], -1).to(torch.float32)
        return self.classifier(image)
model = LinearClassifier(shape, num_classes)
optimizer = torch.optim.Adam(model.parameters())
loss_function = torch.nn.CrossEntropyLoss()
print('Training...')
model.train()
for example in tqdm(train_loader):
    image, label = example['image'], example['label']
    prediction = model(image)
    loss = loss_function(prediction, label)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
print('Testing...')
model.eval()
num_examples = 0
true_positives = 0
with torch.no_grad():  # Gradients are not needed for evaluation.
    for example in tqdm(test_loader):
        image, label = example['image'], example['label']
        prediction = model(image)
        num_examples += image.shape[0]
        predicted_label = prediction.argmax(dim=1)
        true_positives += (predicted_label == label).sum().item()
print(f'\nAccuracy: {true_positives/num_examples * 100:.2f}%')
Training... 100%|██████████| 40/40 [00:01<00:00, 31.22it/s] Testing... 100%|██████████| 79/79 [00:02<00:00, 32.66it/s] Accuracy: 65.63%
Use with JAX
Grain is a library for reading data for training and evaluating JAX models. It's open source, fast and deterministic.
Grain uses the source/sampler/loader paradigm, so we can reuse tfds.data_source:
import grain.python as pygrain
import numpy as np
data_source = tfds.data_source("fashion_mnist", split="train")
# To shuffle the data, use a sampler:
sampler = pygrain.IndexSampler(
num_records=5,
num_epochs=1,
shard_options=pygrain.NoSharding(),
shuffle=True,
seed=0,
)
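The sampler only produces record indices; nothing is read yet. A pure-Python sketch of how shuffling, sharding and repetition can all be decided from indices alone: sample_indices is a hypothetical function whose parameters loosely mirror the IndexSampler arguments above, and it is not how Grain implements them.

```python
import random


def sample_indices(num_records, num_epochs, shard_index, shard_count, seed):
    """Hypothetical index sampler: per-epoch global shuffle, then sharding.

    Only integers are produced, so order, sharding and repetition are all
    decided before any record is read.
    """
    for epoch in range(num_epochs):
        indices = list(range(num_records))
        # Seed depends on the epoch: each epoch gets a different, reproducible order.
        random.Random(seed + epoch).shuffle(indices)
        # Each shard (e.g. one per host) keeps every shard_count-th index.
        yield from indices[shard_index::shard_count]


shard_0 = list(sample_indices(10, num_epochs=2, shard_index=0, shard_count=2, seed=0))
shard_1 = list(sample_indices(10, num_epochs=2, shard_index=1, shard_count=2, seed=0))
print(shard_0)
print(shard_1)
```

Both shards use the same seed, so within an epoch they see the same shuffled order and split it into disjoint halves.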
Transformations are defined as classes and can be BatchTransform, FilterTransform or MapTransform:
class ImageToText(pygrain.MapTransform):
    """Maps the label of an example to text."""
    LABEL_TO_TEXT = {
        0: "zero",
        1: "one",
        2: "two",
        3: "three",
        4: "four",
        5: "five",
        6: "six",
        7: "seven",
        8: "eight",
        9: "nine",
    }
    def map(self, element: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
        label = element["label"]
        text = self.LABEL_TO_TEXT[label]
        element["text"] = text
        return element
# You can chain transformations in a list:
operations = [ImageToText()]
Finally, the data loader takes care of orchestrating the loading. You can scale up with multiprocessing to enjoy both the flexibility of Python and the performance of a data loader:
loader = pygrain.DataLoader(
    data_source=data_source,
    operations=operations,
    sampler=sampler,
    worker_count=0,  # Scale to multiple workers in multiprocessing
)
for element in loader:
    print(element["text"])
two one one eight four
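Conceptually, the loader passes each record through the operations list in order. The sketch below mimics map, filter and batch steps with plain Python generators; apply_map, apply_filter, apply_batch and run_pipeline are hypothetical helpers, not Grain APIs, and the real loader can additionally spread the work across worker processes via worker_count.

```python
from itertools import islice


def apply_map(fn):
    """Hypothetical map step: transforms every element."""
    return lambda elements: (fn(e) for e in elements)


def apply_filter(predicate):
    """Hypothetical filter step: keeps elements for which predicate is true."""
    return lambda elements: (e for e in elements if predicate(e))


def apply_batch(batch_size):
    """Hypothetical batch step: groups elements into lists of batch_size."""
    def batched(elements):
        iterator = iter(elements)
        while chunk := list(islice(iterator, batch_size)):
            yield chunk
    return batched


def run_pipeline(elements, operations):
    """Chains the steps lazily, in list order."""
    for operation in operations:
        elements = operation(elements)
    return elements


records = [{"label": label} for label in [2, 1, 1, 8, 4]]  # labels of the records sampled above
toy_operations = [
    apply_filter(lambda e: e["label"] % 2 == 0),  # drop odd labels
    apply_map(lambda e: {**e, "text": str(e["label"])}),
    apply_batch(2),
]
print(list(run_pipeline(records, toy_operations)))
```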
Read more
For more information, please refer to the tfds.data_source API documentation.