TensorFlow 数据集

TensorFlow Datasets 提供了一系列可以和 TensorFlow 配合使用的数据集。它负责下载和准备数据,以及构建 tf.data.Dataset

在 tensorflow.google.cn 上查看 在 Google Colab 运行 在 Github 上查看源代码 下载 notebook

安装

pip install tensorflow-datasets

注意使用 tensorflow-datasets 的前提是已经安装好 TensorFlow,目前支持的版本是 tensorflow (或者 tensorflow-gpu) >= 1.15.0

pip install -q tensorflow tensorflow-datasets matplotlib
import tensorflow as tf
from __future__ import absolute_import, division, print_function, unicode_literals

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

import tensorflow_datasets as tfds

Eager execution

TensorFlow 数据集同时兼容 TensorFlow Eager 模式 和图模式。在这个 colab 环境里面,我们的代码将通过 Eager 模式执行。

tf.compat.v1.enable_eager_execution()

列出可用的数据集

每一个数据集(dataset)都实现了抽象基类 tfds.core.DatasetBuilder 来构建。你可以通过 tfds.list_builders() 列出所有可用的数据集。

你可以通过 数据集文档页 查看所有支持的数据集及其补充文档。

tfds.list_builders()
['abstract_reasoning',
 'aeslc',
 'aflw2k3d',
 'amazon_us_reviews',
 'arc',
 'bair_robot_pushing_small',
 'beans',
 'big_patent',
 'bigearthnet',
 'billsum',
 'binarized_mnist',
 'binary_alpha_digits',
 'c4',
 'caltech101',
 'caltech_birds2010',
 'caltech_birds2011',
 'cars196',
 'cassava',
 'cats_vs_dogs',
 'celeb_a',
 'celeb_a_hq',
 'cfq',
 'chexpert',
 'cifar10',
 'cifar100',
 'cifar10_1',
 'cifar10_corrupted',
 'citrus_leaves',
 'cityscapes',
 'civil_comments',
 'clevr',
 'cmaterdb',
 'cnn_dailymail',
 'coco',
 'coil100',
 'colorectal_histology',
 'colorectal_histology_large',
 'cos_e',
 'curated_breast_imaging_ddsm',
 'cycle_gan',
 'deep_weeds',
 'definite_pronoun_resolution',
 'diabetic_retinopathy_detection',
 'div2k',
 'dmlab',
 'downsampled_imagenet',
 'dsprites',
 'dtd',
 'duke_ultrasound',
 'dummy_dataset_shared_generator',
 'dummy_mnist',
 'emnist',
 'eraser_multi_rc',
 'esnli',
 'eurosat',
 'fashion_mnist',
 'flic',
 'flores',
 'food101',
 'gap',
 'gigaword',
 'glue',
 'groove',
 'higgs',
 'horses_or_humans',
 'i_naturalist2017',
 'image_label_folder',
 'imagenet2012',
 'imagenet2012_corrupted',
 'imagenet_resized',
 'imagenette',
 'imagewang',
 'imdb_reviews',
 'iris',
 'kitti',
 'kmnist',
 'lfw',
 'librispeech',
 'librispeech_lm',
 'libritts',
 'lm1b',
 'lost_and_found',
 'lsun',
 'malaria',
 'math_dataset',
 'mnist',
 'mnist_corrupted',
 'movie_rationales',
 'moving_mnist',
 'multi_news',
 'multi_nli',
 'multi_nli_mismatch',
 'natural_questions',
 'newsroom',
 'nsynth',
 'omniglot',
 'open_images_v4',
 'opinosis',
 'oxford_flowers102',
 'oxford_iiit_pet',
 'para_crawl',
 'patch_camelyon',
 'pet_finder',
 'places365_small',
 'plant_leaves',
 'plant_village',
 'plantae_k',
 'qa4mre',
 'quickdraw_bitmap',
 'reddit_tifu',
 'resisc45',
 'rock_paper_scissors',
 'rock_you',
 'scan',
 'scene_parse150',
 'scicite',
 'scientific_papers',
 'shapes3d',
 'smallnorb',
 'snli',
 'so2sat',
 'speech_commands',
 'squad',
 'stanford_dogs',
 'stanford_online_products',
 'starcraft_video',
 'sun397',
 'super_glue',
 'svhn_cropped',
 'ted_hrlr_translate',
 'ted_multi_translate',
 'tf_flowers',
 'the300w_lp',
 'tiny_shakespeare',
 'titanic',
 'trivia_qa',
 'uc_merced',
 'ucf101',
 'vgg_face2',
 'visual_domain_decathlon',
 'voc',
 'wider_face',
 'wikihow',
 'wikipedia',
 'wmt14_translate',
 'wmt15_translate',
 'wmt16_translate',
 'wmt17_translate',
 'wmt18_translate',
 'wmt19_translate',
 'wmt_t2t_translate',
 'wmt_translate',
 'xnli',
 'xsum',
 'yelp_polarity_reviews']

tfds.load: 一行代码获取数据集

tfds.load 是构建并加载 tf.data.Dataset 最简单的方式。

tf.data.Dataset 是构建输入流水线的标准 TensorFlow 接口。如果你对这个接口不熟悉,我们强烈建议你阅读 TensorFlow 官方指南

下面,我们先加载 MNIST 训练数据。这个步骤会下载并准备好该数据,除非你显式指定 download=False。值得注意的是,一旦该数据准备好了,后续的 load 命令便不会重新下载,可以重复使用准备好的数据。 你可以通过指定 data_dir= (默认是 ~/tensorflow_datasets/) 来自定义数据保存/加载的路径。

mnist_train = tfds.load(name="mnist", split="train")
assert isinstance(mnist_train, tf.data.Dataset)
print(mnist_train)
Downloading and preparing dataset mnist/3.0.0 (download: 11.06 MiB, generated: Unknown size, total: 11.06 MiB) to /home/kbuilder/tensorflow_datasets/mnist/3.0.0...

WARNING:absl:Dataset mnist is hosted on GCS. It will automatically be downloaded to your
local data directory. If you'd instead prefer to read directly from our public
GCS bucket (recommended if you're running on GCP), you can instead set
data_dir=gs://tfds-data/datasets.


HBox(children=(FloatProgress(value=0.0, description='Dl Completed...', max=4.0, style=ProgressStyle(descriptio…


Dataset mnist downloaded and prepared to /home/kbuilder/tensorflow_datasets/mnist/3.0.0. Subsequent calls will reuse this data.
<DatasetV1Adapter shapes: {image: (28, 28, 1), label: ()}, types: {image: tf.uint8, label: tf.int64}>

加载数据集时,将使用规范的默认版本。 但是,建议你指定要使用的数据集的主版本,并在结果中表明使用了哪个版本的数据集。更多详情请参考数据集版本控制文档

mnist = tfds.load("mnist:1.*.*")
WARNING:absl:Found a different version 3.0.0 of dataset mnist in data_dir /home/kbuilder/tensorflow_datasets. Using currently defined version 1.0.0.

Downloading and preparing dataset mnist/1.0.0 (download: 11.06 MiB, generated: Unknown size, total: 11.06 MiB) to /home/kbuilder/tensorflow_datasets/mnist/1.0.0...

WARNING:absl:Dataset mnist is hosted on GCS. It will automatically be downloaded to your
local data directory. If you'd instead prefer to read directly from our public
GCS bucket (recommended if you're running on GCP), you can instead set
data_dir=gs://tfds-data/datasets.


HBox(children=(FloatProgress(value=0.0, description='Dl Completed...', max=13.0, style=ProgressStyle(descripti…


Dataset mnist downloaded and prepared to /home/kbuilder/tensorflow_datasets/mnist/1.0.0. Subsequent calls will reuse this data.

特征字典

所有 tfds 数据集都包含将特征名称映射到 Tensor 值的特征字典。 典型的数据集(如 MNIST)将具有2个键:"image""label"。 下面我们看一个例子。

注意:在图模式(graph mode)下,请参阅 tf.data 指南 以了解如何在 tf.data.Dataset 上进行迭代。

for mnist_example in mnist_train.take(1):  # 只取一个样本
    image, label = mnist_example["image"], mnist_example["label"]

    plt.imshow(image.numpy()[:, :, 0].astype(np.float32), cmap=plt.get_cmap("gray"))
    print("Label: %d" % label.numpy())
Label: 4

png

DatasetBuilder

tfds.load 实际上是一个基于 DatasetBuilder 的简单方便的包装器。我们可以直接使用 MNIST DatasetBuilder 实现与上述相同的操作。

mnist_builder = tfds.builder("mnist")
mnist_builder.download_and_prepare()
mnist_train = mnist_builder.as_dataset(split="train")
mnist_train
<DatasetV1Adapter shapes: {image: (28, 28, 1), label: ()}, types: {image: tf.uint8, label: tf.int64}>

输入流水线

一旦有了 tf.data.Dataset 对象,就可以使用 tf.data 接口 定义适合模型训练的输入流水线的其余部分。

在这里,我们将重复使用数据集,以便有无限的样本流,然后随机打乱并创建大小为32的批。

mnist_train = mnist_train.repeat().shuffle(1024).batch(32)

# prefetch 将使输入流水线可以在模型训练时异步获取批处理。
mnist_train = mnist_train.prefetch(tf.data.experimental.AUTOTUNE)

# 现在你可以遍历数据集的批次并在 mnist_train 中训练批次:
#   ...

数据集信息(DatasetInfo)

生成后,构建器将包含有关数据集的有用信息:

info = mnist_builder.info
print(info)
tfds.core.DatasetInfo(
    name='mnist',
    version=3.0.0,
    description='The MNIST database of handwritten digits.',
    homepage='http://yann.lecun.com/exdb/mnist/',
    features=FeaturesDict({
        'image': Image(shape=(28, 28, 1), dtype=tf.uint8),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10),
    }),
    total_num_examples=70000,
    splits={
        'test': 10000,
        'train': 60000,
    },
    supervised_keys=('image', 'label'),
    citation="""@article{lecun2010mnist,
      title={MNIST handwritten digit database},
      author={LeCun, Yann and Cortes, Corinna and Burges, CJ},
      journal={ATT Labs [Online]. Available: http://yann. lecun. com/exdb/mnist},
      volume={2},
      year={2010}
    }""",
    redistribution_info=,
)

DatasetInfo 还包含特征相关的有用信息:

print(info.features)
print(info.features["label"].num_classes)
print(info.features["label"].names)
FeaturesDict({
    'image': Image(shape=(28, 28, 1), dtype=tf.uint8),
    'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10),
})
10
['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']

你也可以通过 tfds.load 使用 with_info = True 加载 DatasetInfo

mnist_test, info = tfds.load("mnist", split="test", with_info=True)
print(info)
tfds.core.DatasetInfo(
    name='mnist',
    version=3.0.0,
    description='The MNIST database of handwritten digits.',
    homepage='http://yann.lecun.com/exdb/mnist/',
    features=FeaturesDict({
        'image': Image(shape=(28, 28, 1), dtype=tf.uint8),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10),
    }),
    total_num_examples=70000,
    splits={
        'test': 10000,
        'train': 60000,
    },
    supervised_keys=('image', 'label'),
    citation="""@article{lecun2010mnist,
      title={MNIST handwritten digit database},
      author={LeCun, Yann and Cortes, Corinna and Burges, CJ},
      journal={ATT Labs [Online]. Available: http://yann. lecun. com/exdb/mnist},
      volume={2},
      year={2010}
    }""",
    redistribution_info=,
)

可视化

对于图像分类数据集,你可以使用 tfds.show_examples 来可视化一些样本。

fig = tfds.show_examples(info, mnist_test)

png