TensorFlow Datasets

TensorFlow Datasets provides a collection of datasets ready to use with TensorFlow. It handles downloading and preparing the data and constructing a tf.data.Dataset.

Copyright 2018 The TensorFlow Datasets Authors, Licensed under the Apache License, Version 2.0


Installation

pip install tensorflow-datasets

Note that tensorflow-datasets expects you to have TensorFlow already installed, and currently depends on tensorflow (or tensorflow-gpu) >= 1.12.0.

!pip install -q tensorflow tensorflow-datasets matplotlib
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

import tensorflow_datasets as tfds

Eager execution

TensorFlow Datasets is compatible with both TensorFlow Eager mode and Graph mode. For this colab, we'll run in Eager mode.

tf.enable_eager_execution()

List the available datasets

Each dataset is implemented as a tfds.core.DatasetBuilder and you can list all available builders with tfds.list_builders().

You can see all the datasets with additional documentation on the datasets documentation page.

tfds.list_builders()
['bair_robot_pushing_small',
 'cats_vs_dogs',
 'celeb_a',
 'celeb_a_hq',
 'cifar10',
 'cifar100',
 'coco2014',
 'diabetic_retinopathy_detection',
 'dummy_dataset_shared_generator',
 'dummy_mnist',
 'fashion_mnist',
 'image_label_folder',
 'imagenet2012',
 'imdb_reviews',
 'lm1b',
 'lsun',
 'mnist',
 'moving_mnist',
 'nsynth',
 'omniglot',
 'open_images_v4',
 'quickdraw_bitmap',
 'squad',
 'starcraft_video',
 'svhn_cropped',
 'tf_flowers',
 'wmt_translate_ende',
 'wmt_translate_enfr']

tfds.load: A dataset in one line

tfds.load is a convenience method that's the simplest way to build and load a tf.data.Dataset.

Below, we load the MNIST training data. Setting download=True will download and prepare the data. Note that it's safe to call load multiple times with download=True as long as the builder name and data_dir remain the same. The prepared data will be reused.

mnist_train = tfds.load(name="mnist", split=tfds.Split.TRAIN)
assert isinstance(mnist_train, tf.data.Dataset)
mnist_train
Downloading / extracting dataset mnist (11.06 MiB) to /root/tensorflow_datasets/mnist/1.0.0...

<_OptionsDataset shapes: {label: (), image: (28, 28, 1)}, types: {label: tf.int64, image: tf.uint8}>

Feature dictionaries

All tfds datasets contain feature dictionaries mapping feature names to Tensor values. A typical dataset, like MNIST, has two keys: "image" and "label". Below we inspect a single example.

mnist_example, = mnist_train.take(1)
image, label = mnist_example["image"], mnist_example["label"]

plt.imshow(image.numpy()[:, :, 0].astype(np.float32), cmap=plt.get_cmap("gray"))
print("Label: %d" % label.numpy())
Label: 3


DatasetBuilder

tfds.load is really a thin convenience wrapper around DatasetBuilder. We can accomplish the same thing as above directly with the MNIST DatasetBuilder.

mnist_builder = tfds.builder("mnist")
mnist_builder.download_and_prepare()
mnist_train = mnist_builder.as_dataset(split=tfds.Split.TRAIN)
mnist_train
<_OptionsDataset shapes: {label: (), image: (28, 28, 1)}, types: {label: tf.int64, image: tf.uint8}>

Input pipelines

Once you have a tf.data.Dataset object, it's simple to define the rest of an input pipeline suitable for model training by using the tf.data API.

Here we'll repeat the dataset so that we have an infinite stream of examples, shuffle, and create batches of 32.

mnist_train = mnist_train.repeat().shuffle(1024).batch(32)

# prefetch will enable the input pipeline to asynchronously fetch batches while
# your model is training.
mnist_train = mnist_train.prefetch(tf.data.experimental.AUTOTUNE)

# Now you could loop over batches of the dataset and train
# for batch in mnist_train:
#   ...
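The pipeline above can be sketched end-to-end on synthetic data shaped like MNIST (a stand-in dataset so the example runs without a download; the repeat/shuffle/batch/prefetch chain itself is the same):

```python
import numpy as np
import tensorflow as tf

# Synthetic stand-in for the MNIST training set: 100 examples of
# {"image": uint8 28x28x1, "label": int64}, so no download is needed.
images = np.random.randint(0, 256, size=(100, 28, 28, 1), dtype=np.uint8)
labels = np.random.randint(0, 10, size=(100,)).astype(np.int64)
dataset = tf.data.Dataset.from_tensor_slices({"image": images, "label": labels})

# Same pipeline as above: repeat -> shuffle -> batch -> prefetch.
dataset = dataset.repeat().shuffle(1024).batch(32)
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

# Pull a single batch and inspect its shapes.
for batch in dataset.take(1):
    print(batch["image"].shape)  # (32, 28, 28, 1)
    print(batch["label"].shape)  # (32,)
```

Because repeat() makes the stream infinite, every batch is a full 32 examples; with a finite dataset the last batch would be smaller unless you pass drop_remainder=True to batch.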

DatasetInfo

After generation, the builder contains useful information on the dataset:

info = mnist_builder.info
print(info)
tfds.core.DatasetInfo(
    name='mnist',
    version=1.0.0,
    description='The MNIST database of handwritten digits.',
    urls=['http://yann.lecun.com/exdb/mnist/'],
    features=FeaturesDict({
        'image': Image(shape=(28, 28, 1), dtype=tf.uint8),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10)
    }),
    total_num_examples=70000,
    splits={
        'test': <tfds.core.SplitInfo num_examples=10000>,
        'train': <tfds.core.SplitInfo num_examples=60000>
    },
    supervised_keys=('image', 'label'),
    citation='"""
        @article{lecun2010mnist,
          title={MNIST handwritten digit database},
          author={LeCun, Yann and Cortes, Corinna and Burges, CJ},
          journal={ATT Labs [Online]. Available: http://yann. lecun. com/exdb/mnist},
          volume={2},
          year={2010}
        }
        
    """',
)
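The supervised_keys=('image', 'label') entry above means you can request (input, target) tuples instead of feature dictionaries by passing as_supervised=True to tfds.load or as_dataset. The effect is equivalent to the map below, sketched here on a synthetic dict dataset so it runs without a download:

```python
import numpy as np
import tensorflow as tf

# Synthetic dict-structured dataset standing in for MNIST.
images = np.zeros((4, 28, 28, 1), dtype=np.uint8)
labels = np.arange(4, dtype=np.int64)
ds = tf.data.Dataset.from_tensor_slices({"image": images, "label": labels})

# as_supervised=True is equivalent to selecting the supervised_keys
# from each feature dictionary:
supervised = ds.map(lambda ex: (ex["image"], ex["label"]))

image, label = next(iter(supervised))
print(image.shape)  # (28, 28, 1)
```

The tuple form plugs directly into APIs that expect (features, labels) pairs, such as Keras Model.fit.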

DatasetInfo also contains useful information about the features:

print(info.features)
print(info.features["label"].num_classes)
print(info.features["label"].names)
FeaturesDict({'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10), 'image': Image(shape=(28, 28, 1), dtype=tf.uint8)})
10
['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']

You can also load the DatasetInfo directly with tfds.load using with_info=True.

dataset, info = tfds.load("mnist", split="test", with_info=True)
print(info)
tfds.core.DatasetInfo(
    name='mnist',
    version=1.0.0,
    description='The MNIST database of handwritten digits.',
    urls=['http://yann.lecun.com/exdb/mnist/'],
    features=FeaturesDict({
        'image': Image(shape=(28, 28, 1), dtype=tf.uint8),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10)
    }),
    total_num_examples=70000,
    splits={
        'test': <tfds.core.SplitInfo num_examples=10000>,
        'train': <tfds.core.SplitInfo num_examples=60000>
    },
    supervised_keys=('image', 'label'),
    citation='"""
        @article{lecun2010mnist,
          title={MNIST handwritten digit database},
          author={LeCun, Yann and Cortes, Corinna and Burges, CJ},
          journal={ATT Labs [Online]. Available: http://yann. lecun. com/exdb/mnist},
          volume={2},
          year={2010}
        }
        
    """',
)