このページは Cloud Translation API によって翻訳されました。
Switch to English

TensorFlowデータセット

TFDSは、TensorFlow、Jax、およびその他の機械学習フレームワークで使用するためのすぐに使用できるデータセットのコレクションを提供します。

データをダウンロードして確定的に準備し、 tf.data.Dataset (またはnp.array )を構築します。

TensorFlow.orgで表示 Google Colabで実行 GitHubでソースを表示

取り付け

TFDSは次の2つのパッケージに含まれています。

  • pip install tensorflow-datasets :数か月ごとにリリースされる安定版。
  • pip install tfds-nightly :毎日リリースされ、データセットの最新バージョンが含まれています。

このcolabはtfds-nightly使用します:

pip install -q tfds-nightly tensorflow matplotlib
WARNING: You are using pip version 20.2.2; however, version 20.2.3 is available.
You should consider upgrading via the '/tmpfs/src/tf_docs_env/bin/python -m pip install --upgrade pip' command.

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

import tensorflow_datasets as tfds

利用可能なデータセットを見つける

すべてのデータセットビルダーは、 tfds.core.DatasetBuilderサブクラスです。利用可能なビルダーのリストを取得するには、 tfds.list_builders()使用するか、 カタログを確認してください。

tfds.list_builders()
['abstract_reasoning',
 'aeslc',
 'aflw2k3d',
 'ag_news_subset',
 'ai2_arc',
 'ai2_arc_with_ir',
 'amazon_us_reviews',
 'anli',
 'arc',
 'bair_robot_pushing_small',
 'bccd',
 'beans',
 'big_patent',
 'bigearthnet',
 'billsum',
 'binarized_mnist',
 'binary_alpha_digits',
 'blimp',
 'bool_q',
 'c4',
 'caltech101',
 'caltech_birds2010',
 'caltech_birds2011',
 'cars196',
 'cassava',
 'cats_vs_dogs',
 'celeb_a',
 'celeb_a_hq',
 'cfq',
 'chexpert',
 'cifar10',
 'cifar100',
 'cifar10_1',
 'cifar10_corrupted',
 'citrus_leaves',
 'cityscapes',
 'civil_comments',
 'clevr',
 'clic',
 'clinc_oos',
 'cmaterdb',
 'cnn_dailymail',
 'coco',
 'coco_captions',
 'coil100',
 'colorectal_histology',
 'colorectal_histology_large',
 'common_voice',
 'coqa',
 'cos_e',
 'cosmos_qa',
 'covid19sum',
 'crema_d',
 'curated_breast_imaging_ddsm',
 'cycle_gan',
 'deep_weeds',
 'definite_pronoun_resolution',
 'dementiabank',
 'diabetic_retinopathy_detection',
 'div2k',
 'dmlab',
 'downsampled_imagenet',
 'dsprites',
 'dtd',
 'duke_ultrasound',
 'emnist',
 'eraser_multi_rc',
 'esnli',
 'eurosat',
 'fashion_mnist',
 'flic',
 'flores',
 'food101',
 'forest_fires',
 'fuss',
 'gap',
 'geirhos_conflict_stimuli',
 'genomics_ood',
 'german_credit_numeric',
 'gigaword',
 'glue',
 'goemotions',
 'gpt3',
 'groove',
 'gtzan',
 'gtzan_music_speech',
 'higgs',
 'horses_or_humans',
 'i_naturalist2017',
 'imagenet2012',
 'imagenet2012_corrupted',
 'imagenet2012_real',
 'imagenet2012_subset',
 'imagenet_a',
 'imagenet_r',
 'imagenet_resized',
 'imagenet_v2',
 'imagenette',
 'imagewang',
 'imdb_reviews',
 'irc_disentanglement',
 'iris',
 'kitti',
 'kmnist',
 'lfw',
 'librispeech',
 'librispeech_lm',
 'libritts',
 'ljspeech',
 'lm1b',
 'lost_and_found',
 'lsun',
 'malaria',
 'math_dataset',
 'mctaco',
 'mnist',
 'mnist_corrupted',
 'movie_lens',
 'movie_rationales',
 'movielens',
 'moving_mnist',
 'multi_news',
 'multi_nli',
 'multi_nli_mismatch',
 'natural_questions',
 'newsroom',
 'nsynth',
 'nyu_depth_v2',
 'omniglot',
 'open_images_challenge2019_detection',
 'open_images_v4',
 'openbookqa',
 'opinion_abstracts',
 'opinosis',
 'opus',
 'oxford_flowers102',
 'oxford_iiit_pet',
 'para_crawl',
 'patch_camelyon',
 'paws_wiki',
 'paws_x_wiki',
 'pet_finder',
 'pg19',
 'places365_small',
 'plant_leaves',
 'plant_village',
 'plantae_k',
 'qa4mre',
 'qasc',
 'quickdraw_bitmap',
 'radon',
 'reddit',
 'reddit_disentanglement',
 'reddit_tifu',
 'resisc45',
 'robonet',
 'rock_paper_scissors',
 'rock_you',
 'salient_span_wikipedia',
 'samsum',
 'savee',
 'scan',
 'scene_parse150',
 'scicite',
 'scientific_papers',
 'sentiment140',
 'shapes3d',
 'smallnorb',
 'snli',
 'so2sat',
 'speech_commands',
 'spoken_digit',
 'squad',
 'stanford_dogs',
 'stanford_online_products',
 'starcraft_video',
 'stl10',
 'sun397',
 'super_glue',
 'svhn_cropped',
 'ted_hrlr_translate',
 'ted_multi_translate',
 'tedlium',
 'tf_flowers',
 'the300w_lp',
 'tiny_shakespeare',
 'titanic',
 'trec',
 'trivia_qa',
 'tydi_qa',
 'uc_merced',
 'ucf101',
 'vctk',
 'vgg_face2',
 'visual_domain_decathlon',
 'voc',
 'voxceleb',
 'voxforge',
 'waymo_open_dataset',
 'web_questions',
 'wider_face',
 'wiki40b',
 'wikihow',
 'wikipedia',
 'wikipedia_toxicity_subtypes',
 'wine_quality',
 'winogrande',
 'wmt14_translate',
 'wmt15_translate',
 'wmt16_translate',
 'wmt17_translate',
 'wmt18_translate',
 'wmt19_translate',
 'wmt_t2t_translate',
 'wmt_translate',
 'wordnet',
 'xnli',
 'xquad',
 'xsum',
 'yelp_polarity_reviews',
 'yes_no']

データセットを読み込む

データセットをロードする最も簡単な方法は、 tfds.loadです。そうなる:

  1. データをダウンロードし、 tfrecordファイルとして保存します。
  2. tfrecordをロードしてtfrecordを作成しtf.data.Dataset
ds = tfds.load('mnist', split='train', shuffle_files=True)
assert isinstance(ds, tf.data.Dataset)
print(ds)
WARNING:absl:Dataset mnist is hosted on GCS. It will automatically be downloaded to your
local data directory. If you'd instead prefer to read directly from our public
GCS bucket (recommended if you're running on GCP), you can instead pass
`try_gcs=True` to `tfds.load` or set `data_dir=gs://tfds-data/datasets`.


Downloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /home/kbuilder/tensorflow_datasets/mnist/3.0.1...
Dataset mnist downloaded and prepared to /home/kbuilder/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.
<_OptionsDataset shapes: {image: (28, 28, 1), label: ()}, types: {image: tf.uint8, label: tf.int64}>

いくつかの一般的な引数:

  • split= :読み取るスプリット(例: 'train'['train', 'test']'train[80%:]' 、...)。 分割APIガイドをご覧ください。
  • shuffle_files= :各エポック間でファイルをシャッフルするかどうかを制御します(TFDSは大きなデータセットを複数の小さなファイルに格納します)。
  • data_dir= :データセットが保存される場所(デフォルトは~/tensorflow_datasets/
  • with_info=True :データセットのメタデータを含むtfds.core.DatasetInfoを返します
  • download=False :ダウンロードを無効にします

tfds.loadは、 tfds.loadの薄いラッパーtfds.core.DatasetBuildertfds.core.DatasetBuilder APIを使用して同じ出力を取得できます。

builder = tfds.builder('mnist')
# 1. Create the tfrecord files (no-op if already exists)
builder.download_and_prepare()
# 2. Load the `tf.data.Dataset`
ds = builder.as_dataset(split='train', shuffle_files=True)
print(ds)
<_OptionsDataset shapes: {image: (28, 28, 1), label: ()}, types: {image: tf.uint8, label: tf.int64}>

データセットを反復する

辞書として

デフォルトでは、 tf.data.Datasetオブジェクトにはtf.Tensordictが含まれています。

ds = tfds.load('mnist', split='train')
ds = ds.take(1)  # Only take a single example

for example in ds:  # example is `{'image': tf.Tensor, 'label': tf.Tensor}`
  print(list(example.keys()))
  image = example["image"]
  label = example["label"]
  print(image.shape, label)
['image', 'label']
(28, 28, 1) tf.Tensor(4, shape=(), dtype=int64)

タプルとして( as_supervised=True

as_supervised=Trueを使用すると、監視対象データセットの代わりにタプル(features, label)取得できます。

ds = tfds.load('mnist', split='train', as_supervised=True)
ds = ds.take(1)

for image, label in ds:  # example is (image, label)
  print(image.shape, label)
(28, 28, 1) tf.Tensor(4, shape=(), dtype=int64)

numpyとして( tfds.as_numpy

tfds.as_numpyを使用して変換します。

ds = tfds.load('mnist', split='train', as_supervised=True)
ds = ds.take(1)

for image, label in tfds.as_numpy(ds):
  print(type(image), type(label), label)
<class 'numpy.ndarray'> <class 'numpy.int64'> 4

バッチ化されたtf.Tensorとして( batch_size=-1

batch_size=-1を使用すると、単一のバッチで完全なデータセットをロードできます。

tfds.load返されますdicttupleas_supervised=Trueの) tf.Tensornp.arraytfds.as_numpy )。

データセットがメモリに収まり、すべての例が同じ形状になるように注意してください。

image, label = tfds.as_numpy(tfds.load(
    'mnist',
    split='test', 
    batch_size=-1, 
    as_supervised=True,
))

print(type(image), image.shape)
<class 'numpy.ndarray'> (10000, 28, 28, 1)

エンドツーエンドのパイプラインを構築する

さらに進むには、次のようにします。

可視化

tfds.as_dataframe

tf.data.Datasetオブジェクトは、 pandas.DataFrametfds.as_dataframeしてpandas.DataFrameに変換し、 pandas.DataFrameで視覚化できます。

  • 追加tfds.core.DatasetInfo 2番目の引数としてtfds.as_dataframe画像、音声、テキスト、ビデオを、視覚化します...
  • 最初のx例のみを表示するには、 ds.take(x)を使用します。 pandas.DataFrameは完全なデータセットをメモリ内に読み込み、表示に非常に負荷がかかる可能性があります。
ds, info = tfds.load('mnist', split='train', with_info=True)

tfds.as_dataframe(ds.take(4), info)

tfds.show_examples

tfds.show_examplestfds.show_examples画像のtfds.show_examples (現在サポートされているのは画像データセットのみ):

ds, info = tfds.load('mnist', split='train', with_info=True)

fig = tfds.show_examples(ds, info)

png

データセットのメタデータにアクセスする

すべてのビルダーには、データセットメタデータを含むtfds.core.DatasetInfoオブジェクトが含まれてtfds.core.DatasetInfoます。

次の方法でアクセスできます。

ds, info = tfds.load('mnist', with_info=True)
builder = tfds.builder('mnist')
info = builder.info

データセット情報には、データセットに関する追加情報(バージョン、引用、ホームページ、説明など)が含まれています。

print(info)
tfds.core.DatasetInfo(
    name='mnist',
    version=3.0.1,
    description='The MNIST database of handwritten digits.',
    homepage='http://yann.lecun.com/exdb/mnist/',
    features=FeaturesDict({
        'image': Image(shape=(28, 28, 1), dtype=tf.uint8),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10),
    }),
    total_num_examples=70000,
    splits={
        'test': 10000,
        'train': 60000,
    },
    supervised_keys=('image', 'label'),
    citation="""@article{lecun2010mnist,
      title={MNIST handwritten digit database},
      author={LeCun, Yann and Cortes, Corinna and Burges, CJ},
      journal={ATT Labs [Online]. Available: http://yann.lecun.com/exdb/mnist},
      volume={2},
      year={2010}
    }""",
    redistribution_info=,
)


機能メタデータ(ラベル名、画像形状など)

tfds.features.FeatureDictアクセスしtfds.features.FeatureDict

info.features
FeaturesDict({
    'image': Image(shape=(28, 28, 1), dtype=tf.uint8),
    'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10),
})

クラス数、ラベル名:

print(info.features["label"].num_classes)
print(info.features["label"].names)
print(info.features["label"].int2str(7))  # Human readable version (8 -> 'cat')
print(info.features["label"].str2int('7'))
10
['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
7
7

形状、dtype:

print(info.features.shape)
print(info.features.dtype)
print(info.features['image'].shape)
print(info.features['image'].dtype)
{'image': (28, 28, 1), 'label': ()}
{'image': tf.uint8, 'label': tf.int64}
(28, 28, 1)
<dtype: 'uint8'>

メタデータの分割(例:名前の分割、例の数など)

tfds.core.SplitDictアクセスしtfds.core.SplitDict

print(info.splits)
{'test': <tfds.core.SplitInfo num_examples=10000>, 'train': <tfds.core.SplitInfo num_examples=60000>}

利用可能な分割:

print(list(info.splits.keys()))
['test', 'train']

個別の分割に関する情報を取得します。

print(info.splits['train'].num_examples)
print(info.splits['train'].filenames)
print(info.splits['train'].num_shards)
60000
['mnist-train.tfrecord-00000-of-00001']
1

サブスプリットAPIでも機能します。

print(info.splits['train[15%:75%]'].num_examples)
print(info.splits['train[15%:75%]'].file_instructions)
36000
[FileInstruction(filename='mnist-train.tfrecord-00000-of-00001', skip=9000, take=36000, num_examples=36000)]

引用

論文にtensorflow-datasetsを使用している場合は、使用したデータセットに固有の引用( データセットカタログにあります )に加えて、次の引用を含めてください。

@misc{TFDS,
  title = { {TensorFlow Datasets}, A collection of ready-to-use datasets},
  howpublished = {\url{https://www.tensorflow.org/datasets} },
}