TFDS provides a collection of ready-to-use datasets for use with TensorFlow, Jax, and other Machine Learning frameworks.
It handles downloading and preparing the data deterministically and constructing a tf.data.Dataset (or np.array).
Installation
TFDS exists in two packages:

- pip install tensorflow-datasets: the stable version, released every few months.
- pip install tfds-nightly: released every day, contains the latest versions of the datasets.

This colab uses tfds-nightly:
pip install -q tfds-nightly tensorflow matplotlib
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
Find available datasets
All dataset builders are subclasses of tfds.core.DatasetBuilder. To get the list of available builders, use tfds.list_builders() or look at our catalog.
tfds.list_builders()
['abstract_reasoning', 'accentdb', 'aeslc', 'aflw2k3d', 'ag_news_subset', 'ai2_arc', 'ai2_arc_with_ir', 'amazon_us_reviews', 'anli', 'arc', 'bair_robot_pushing_small', 'bccd', 'beans', 'big_patent', 'bigearthnet', 'billsum', 'binarized_mnist', 'binary_alpha_digits', 'blimp', 'bool_q', 'c4', 'caltech101', 'caltech_birds2010', 'caltech_birds2011', 'cars196', 'cassava', 'cats_vs_dogs', 'celeb_a', 'celeb_a_hq', 'cfq', 'cherry_blossoms', 'chexpert', 'cifar10', 'cifar100', 'cifar10_1', 'cifar10_corrupted', 'citrus_leaves', 'cityscapes', 'civil_comments', 'clevr', 'clic', 'clinc_oos', 'cmaterdb', 'cnn_dailymail', 'coco', 'coco_captions', 'coil100', 'colorectal_histology', 'colorectal_histology_large', 'common_voice', 'coqa', 'cos_e', 'cosmos_qa', 'covid19sum', 'crema_d', 'curated_breast_imaging_ddsm', 'cycle_gan', 'dart', 'davis', 'deep_weeds', 'definite_pronoun_resolution', 'dementiabank', 'diabetic_retinopathy_detection', 'div2k', 'dmlab', 'downsampled_imagenet', 'drop', 'dsprites', 'dtd', 'duke_ultrasound', 'e2e_cleaned', 'emnist', 'eraser_multi_rc', 'esnli', 'eurosat', 'fashion_mnist', 'flic', 'flores', 'food101', 'forest_fires', 'fuss', 'gap', 'geirhos_conflict_stimuli', 'genomics_ood', 'german_credit_numeric', 'gigaword', 'glue', 'goemotions', 'gpt3', 'gref', 'groove', 'gtzan', 'gtzan_music_speech', 'hellaswag', 'higgs', 'horses_or_humans', 'howell', 'i_naturalist2017', 'imagenet2012', 'imagenet2012_corrupted', 'imagenet2012_real', 'imagenet2012_subset', 'imagenet_a', 'imagenet_r', 'imagenet_resized', 'imagenet_v2', 'imagenette', 'imagewang', 'imdb_reviews', 'irc_disentanglement', 'iris', 'kitti', 'kmnist', 'lambada', 'lfw', 'librispeech', 'librispeech_lm', 'libritts', 'ljspeech', 'lm1b', 'lost_and_found', 'lsun', 'lvis', 'malaria', 'math_dataset', 'mctaco', 'mlqa', 'mnist', 'mnist_corrupted', 'movie_lens', 'movie_rationales', 'movielens', 'moving_mnist', 'multi_news', 'multi_nli', 'multi_nli_mismatch', 'natural_questions', 'natural_questions_open', 'newsroom', 'nsynth', 'nyu_depth_v2', 'omniglot', 'open_images_challenge2019_detection', 'open_images_v4', 'openbookqa', 'opinion_abstracts', 'opinosis', 'opus', 'oxford_flowers102', 'oxford_iiit_pet', 'para_crawl', 'patch_camelyon', 'paws_wiki', 'paws_x_wiki', 'pet_finder', 'pg19', 'piqa', 'places365_small', 'plant_leaves', 'plant_village', 'plantae_k', 'qa4mre', 'qasc', 'quac', 'quickdraw_bitmap', 'race', 'radon', 'reddit', 'reddit_disentanglement', 'reddit_tifu', 'resisc45', 'robonet', 'rock_paper_scissors', 'rock_you', 's3o4d', 'salient_span_wikipedia', 'samsum', 'savee', 'scan', 'scene_parse150', 'scicite', 'scientific_papers', 'sentiment140', 'shapes3d', 'siscore', 'smallnorb', 'snli', 'so2sat', 'speech_commands', 'spoken_digit', 'squad', 'stanford_dogs', 'stanford_online_products', 'starcraft_video', 'stl10', 'story_cloze', 'sun397', 'super_glue', 'svhn_cropped', 'ted_hrlr_translate', 'ted_multi_translate', 'tedlium', 'tf_flowers', 'the300w_lp', 'tiny_shakespeare', 'titanic', 'trec', 'trivia_qa', 'tydi_qa', 'uc_merced', 'ucf101', 'vctk', 'vgg_face2', 'visual_domain_decathlon', 'voc', 'voxceleb', 'voxforge', 'waymo_open_dataset', 'web_nlg', 'web_questions', 'wider_face', 'wiki40b', 'wiki_bio', 'wiki_table_questions', 'wiki_table_text', 'wikihow', 'wikipedia', 'wikipedia_toxicity_subtypes', 'wine_quality', 'winogrande', 'wmt14_translate', 'wmt15_translate', 'wmt16_translate', 'wmt17_translate', 'wmt18_translate', 'wmt19_translate', 'wmt_t2t_translate', 'wmt_translate', 'wordnet', 'wsc273', 'xnli', 'xquad', 'xsum', 'xtreme_pawsx', 
'xtreme_xnli', 'yelp_polarity_reviews', 'yes_no']
Load a dataset
tfds.load
The easiest way of loading a dataset is tfds.load. It will:

- Download the data and save it as tfrecord files.
- Load the tfrecord files and create the tf.data.Dataset.
ds = tfds.load('mnist', split='train', shuffle_files=True)
assert isinstance(ds, tf.data.Dataset)
print(ds)
Downloading and preparing dataset 11.06 MiB (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /home/kbuilder/tensorflow_datasets/mnist/3.0.1... Dataset mnist downloaded and prepared to /home/kbuilder/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data. <_OptionsDataset shapes: {image: (28, 28, 1), label: ()}, types: {image: tf.uint8, label: tf.int64}>
Some common arguments:

- split=: which split to read (e.g. 'train', ['train', 'test'], 'train[80%:]',...). See our split API guide.
- shuffle_files=: controls whether to shuffle the files between each epoch (TFDS stores big datasets in multiple smaller files).
- data_dir=: location where the dataset is saved (defaults to ~/tensorflow_datasets/).
- with_info=True: returns the tfds.core.DatasetInfo containing dataset metadata.
- download=False: disables the download.
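For example, a single call can combine several of these arguments (the data_dir path and split slice below are illustrative choices, not requirements):

# Load the last 20% of the train split from a custom location,
# together with the dataset metadata.
ds, info = tfds.load(
    'mnist',
    split='train[80%:]',
    shuffle_files=True,
    data_dir='~/my_tfds_data',  # illustrative; defaults to ~/tensorflow_datasets/
    with_info=True,
)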
tfds.builder
tfds.load is a thin wrapper around tfds.core.DatasetBuilder. You can get the same output using the tfds.core.DatasetBuilder API:
builder = tfds.builder('mnist')
# 1. Create the tfrecord files (no-op if already exists)
builder.download_and_prepare()
# 2. Load the `tf.data.Dataset`
ds = builder.as_dataset(split='train', shuffle_files=True)
print(ds)
<_OptionsDataset shapes: {image: (28, 28, 1), label: ()}, types: {image: tf.uint8, label: tf.int64}>
tfds build CLI

If you want to generate a specific dataset, you can use the tfds command line. For example:
tfds build mnist
See the doc for available flags.
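Flags can customize the generation; for instance, a sketch assuming the --data_dir flag (run tfds build --help for the authoritative list):

tfds build mnist --data_dir=/path/to/tensorflow_datasets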
Iterate over a dataset
As dict

By default, the tf.data.Dataset object contains a dict of tf.Tensors:
ds = tfds.load('mnist', split='train')
ds = ds.take(1) # Only take a single example
for example in ds:  # example is `{'image': tf.Tensor, 'label': tf.Tensor}`
  print(list(example.keys()))
  image = example["image"]
  label = example["label"]
  print(image.shape, label)
['image', 'label'] (28, 28, 1) tf.Tensor(4, shape=(), dtype=int64)
To find out the dict key names and structure, look at the dataset documentation in our catalog. For example: mnist documentation.
As tuple (as_supervised=True)

By using as_supervised=True, you can get a tuple (features, label) instead for supervised datasets.
ds = tfds.load('mnist', split='train', as_supervised=True)
ds = ds.take(1)
for image, label in ds:  # example is (image, label)
  print(image.shape, label)
(28, 28, 1) tf.Tensor(4, shape=(), dtype=int64)
As numpy (tfds.as_numpy)

Use tfds.as_numpy to convert:

- tf.Tensor -> np.array
- tf.data.Dataset -> Iterator[Tree[np.array]] (Tree can be arbitrary nested Dict, Tuple)
ds = tfds.load('mnist', split='train', as_supervised=True)
ds = ds.take(1)
for image, label in tfds.as_numpy(ds):
  print(type(image), type(label), label)
<class 'numpy.ndarray'> <class 'numpy.int64'> 4
As batched tf.Tensor (batch_size=-1)

By using batch_size=-1, you can load the full dataset in a single batch.

This can be combined with as_supervised=True and tfds.as_numpy to get the data as (np.array, np.array):
image, label = tfds.as_numpy(tfds.load(
    'mnist',
    split='test',
    batch_size=-1,
    as_supervised=True,
))
print(type(image), image.shape)
<class 'numpy.ndarray'> (10000, 28, 28, 1)
Be careful that your dataset can fit in memory, and that all examples have the same shape.
Build end-to-end pipeline
To go further, you can look at:

- Our end-to-end Keras example to see a full training pipeline (with batching, shuffling,...).
- Our performance guide to improve the speed of your pipelines.
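As a minimal sketch of such a pipeline (the normalization, batch size, and shuffle buffer are illustrative choices; tf.data.AUTOTUNE assumes TF 2.4+):

ds = tfds.load('mnist', split='train', as_supervised=True, shuffle_files=True)
# Normalize images to [0, 1], then cache, shuffle, batch, and prefetch.
ds = ds.map(lambda image, label: (tf.cast(image, tf.float32) / 255., label),
            num_parallel_calls=tf.data.AUTOTUNE)
ds = ds.cache().shuffle(10_000).batch(128).prefetch(tf.data.AUTOTUNE)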
Visualization
tfds.as_dataframe

tf.data.Dataset objects can be converted to pandas.DataFrame with tfds.as_dataframe to be visualized on Colab.
- Add the tfds.core.DatasetInfo as second argument of tfds.as_dataframe to visualize images, audio, texts, videos,...
- Use ds.take(x) to only display the first x examples. pandas.DataFrame will load the full dataset in memory, and can be very expensive to display.
ds, info = tfds.load('mnist', split='train', with_info=True)
tfds.as_dataframe(ds.take(4), info)
tfds.show_examples
tfds.show_examples returns a matplotlib.figure.Figure (only image datasets are supported for now):
ds, info = tfds.load('mnist', split='train', with_info=True)
fig = tfds.show_examples(ds, info)
Access the dataset metadata
All builders include a tfds.core.DatasetInfo object containing the dataset metadata.

It can be accessed through:

- The tfds.load API:
ds, info = tfds.load('mnist', with_info=True)
- The tfds.core.DatasetBuilder API:
builder = tfds.builder('mnist')
info = builder.info
The dataset info contains additional information about the dataset (version, citation, homepage, description,...).
print(info)
tfds.core.DatasetInfo( name='mnist', full_name='mnist/3.0.1', description=""" The MNIST database of handwritten digits. """, homepage='http://yann.lecun.com/exdb/mnist/', data_path='/home/kbuilder/tensorflow_datasets/mnist/3.0.1', download_size=11.06 MiB, dataset_size=21.00 MiB, features=FeaturesDict({ 'image': Image(shape=(28, 28, 1), dtype=tf.uint8), 'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10), }), supervised_keys=('image', 'label'), splits={ 'test': <SplitInfo num_examples=10000, num_shards=1>, 'train': <SplitInfo num_examples=60000, num_shards=1>, }, citation="""@article{lecun2010mnist, title={MNIST handwritten digit database}, author={LeCun, Yann and Cortes, Corinna and Burges, CJ}, journal={ATT Labs [Online]. Available: http://yann.lecun.com/exdb/mnist}, volume={2}, year={2010} }""", )
Features metadata (label names, image shape,...)
Access the tfds.features.FeaturesDict:
info.features
FeaturesDict({ 'image': Image(shape=(28, 28, 1), dtype=tf.uint8), 'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10), })
Number of classes, label names:
print(info.features["label"].num_classes)
print(info.features["label"].names)
print(info.features["label"].int2str(7)) # Human readable version (8 -> 'cat')
print(info.features["label"].str2int('7'))
10 ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'] 7 7
Shapes, dtypes:
print(info.features.shape)
print(info.features.dtype)
print(info.features['image'].shape)
print(info.features['image'].dtype)
{'image': (28, 28, 1), 'label': ()} {'image': tf.uint8, 'label': tf.int64} (28, 28, 1) <dtype: 'uint8'>
Split metadata (e.g. split names, number of examples,...)
Access the tfds.core.SplitDict:
print(info.splits)
{'test': <SplitInfo num_examples=10000, num_shards=1>, 'train': <SplitInfo num_examples=60000, num_shards=1>}
Available splits:
print(list(info.splits.keys()))
['test', 'train']
Get info on an individual split:
print(info.splits['train'].num_examples)
print(info.splits['train'].filenames)
print(info.splits['train'].num_shards)
60000 ['mnist-train.tfrecord-00000-of-00001'] 1
It also works with the subsplit API:
print(info.splits['train[15%:75%]'].num_examples)
print(info.splits['train[15%:75%]'].file_instructions)
36000 [FileInstruction(filename='mnist-train.tfrecord-00000-of-00001', skip=9000, take=36000, num_examples=36000)]
Troubleshooting
Manual download (if download fails)
If the download fails for some reason (e.g. you are offline,...), you can always manually download the data yourself and place it in the manual_dir (defaults to ~/tensorflow_datasets/downloads/manual/).
To find out which urls to download, look into:

- For new datasets (implemented as a folder): tensorflow_datasets/<type>/<dataset_name>/checksums.tsv. For example: tensorflow_datasets/text/bool_q/checksums.tsv. You can find the dataset source location in our catalog.
- For old datasets: tensorflow_datasets/url_checksums/<dataset_name>.txt
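For example, a sketch of the manual workflow (the urls come from the dataset's checksums file; the directory shown assumes the default manual_dir):

mkdir -p ~/tensorflow_datasets/downloads/manual/
# Download each url listed in the checksums file into that directory,
# then re-run tfds.load(...) or tfds build <dataset_name> as usual.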
Fixing NonMatchingChecksumError
TFDS ensures determinism by validating the checksums of downloaded urls. If NonMatchingChecksumError is raised, it might indicate:

- The website may be down (e.g. 503 status code). Please check the url.
- For Google Drive URLs, try again later as Drive sometimes rejects downloads when too many people access the same URL. See bug.
- The original dataset files may have been updated. In this case the TFDS dataset builder should be updated. Please open a new Github issue or PR:
  - Register the new checksums with tfds build --register_checksums
  - If needed, update the dataset generation code.
  - Update the dataset VERSION
  - Update the dataset RELEASE_NOTES: what caused the checksums to change? Did some examples change?
  - Make sure the dataset can still be built.
  - Send us a PR
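For example (my_dataset is a placeholder for the dataset being updated; --register_checksums is the flag named above):

tfds build my_dataset --register_checksums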
Citation
If you're using tensorflow-datasets for a paper, please include the following citation, in addition to any citation specific to the used datasets (which can be found in the dataset catalog).
@misc{TFDS,
title = { {TensorFlow Datasets}, A collection of ready-to-use datasets},
howpublished = {\url{https://www.tensorflow.org/datasets} },
}