TFDSは、TensorFlow、Jax、およびその他の機械学習フレームワークで使用するためのすぐに使用できるデータセットのコレクションを提供します。
データのダウンロードと準備を決定論的に処理し、tf.data.Dataset
(またはnp.array
)を構築します。
![]() | ![]() | ![]() |
インストール
TFDSは2つのパッケージで存在します。
-
pip install tensorflow-datasets
:安定バージョン。数か月ごとにリリースされます。 -
pip install tfds-nightly
:毎日リリースされ、データセットの最新バージョンが含まれています。
このコラボはtfds-nightly
使用します:
pip install -q tfds-nightly tensorflow matplotlib
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
利用可能なデータセットを見つける
すべてのデータセットビルダーは、 tfds.core.DatasetBuilder
サブクラスです。利用可能なビルダーのリストを取得するには、 tfds.list_builders()
使用するか、カタログを参照してください。
tfds.list_builders()
['abstract_reasoning', 'accentdb', 'aeslc', 'aflw2k3d', 'ag_news_subset', 'ai2_arc', 'ai2_arc_with_ir', 'amazon_us_reviews', 'anli', 'arc', 'bair_robot_pushing_small', 'bccd', 'beans', 'big_patent', 'bigearthnet', 'billsum', 'binarized_mnist', 'binary_alpha_digits', 'blimp', 'bool_q', 'c4', 'caltech101', 'caltech_birds2010', 'caltech_birds2011', 'cars196', 'cassava', 'cats_vs_dogs', 'celeb_a', 'celeb_a_hq', 'cfq', 'cherry_blossoms', 'chexpert', 'cifar10', 'cifar100', 'cifar10_1', 'cifar10_corrupted', 'citrus_leaves', 'cityscapes', 'civil_comments', 'clevr', 'clic', 'clinc_oos', 'cmaterdb', 'cnn_dailymail', 'coco', 'coco_captions', 'coil100', 'colorectal_histology', 'colorectal_histology_large', 'common_voice', 'coqa', 'cos_e', 'cosmos_qa', 'covid19sum', 'crema_d', 'curated_breast_imaging_ddsm', 'cycle_gan', 'dart', 'davis', 'deep_weeds', 'definite_pronoun_resolution', 'dementiabank', 'diabetic_retinopathy_detection', 'div2k', 'dmlab', 'downsampled_imagenet', 'drop', 'dsprites', 'dtd', 'duke_ultrasound', 'e2e_cleaned', 'emnist', 'eraser_multi_rc', 'esnli', 'eurosat', 'fashion_mnist', 'flic', 'flores', 'food101', 'forest_fires', 'fuss', 'gap', 'geirhos_conflict_stimuli', 'genomics_ood', 'german_credit_numeric', 'gigaword', 'glue', 'goemotions', 'gpt3', 'gref', 'groove', 'gtzan', 'gtzan_music_speech', 'hellaswag', 'higgs', 'horses_or_humans', 'howell', 'i_naturalist2017', 'imagenet2012', 'imagenet2012_corrupted', 'imagenet2012_real', 'imagenet2012_subset', 'imagenet_a', 'imagenet_r', 'imagenet_resized', 'imagenet_v2', 'imagenette', 'imagewang', 'imdb_reviews', 'irc_disentanglement', 'iris', 'kitti', 'kmnist', 'lambada', 'lfw', 'librispeech', 'librispeech_lm', 'libritts', 'ljspeech', 'lm1b', 'lost_and_found', 'lsun', 'lvis', 'malaria', 'math_dataset', 'mctaco', 'mlqa', 'mnist', 'mnist_corrupted', 'movie_lens', 'movie_rationales', 'movielens', 'moving_mnist', 'multi_news', 'multi_nli', 'multi_nli_mismatch', 'natural_questions', 'natural_questions_open', 'newsroom', 'nsynth', 'nyu_depth_v2', 'omniglot', 'open_images_challenge2019_detection', 'open_images_v4', 'openbookqa', 'opinion_abstracts', 'opinosis', 'opus', 'oxford_flowers102', 'oxford_iiit_pet', 'para_crawl', 'patch_camelyon', 'paws_wiki', 'paws_x_wiki', 'pet_finder', 'pg19', 'piqa', 'places365_small', 'plant_leaves', 'plant_village', 'plantae_k', 'qa4mre', 'qasc', 'quac', 'quickdraw_bitmap', 'race', 'radon', 'reddit', 'reddit_disentanglement', 'reddit_tifu', 'resisc45', 'robonet', 'rock_paper_scissors', 'rock_you', 's3o4d', 'salient_span_wikipedia', 'samsum', 'savee', 'scan', 'scene_parse150', 'scicite', 'scientific_papers', 'sentiment140', 'shapes3d', 'siscore', 'smallnorb', 'snli', 'so2sat', 'speech_commands', 'spoken_digit', 'squad', 'stanford_dogs', 'stanford_online_products', 'starcraft_video', 'stl10', 'story_cloze', 'sun397', 'super_glue', 'svhn_cropped', 'ted_hrlr_translate', 'ted_multi_translate', 'tedlium', 'tf_flowers', 'the300w_lp', 'tiny_shakespeare', 'titanic', 'trec', 'trivia_qa', 'tydi_qa', 'uc_merced', 'ucf101', 'vctk', 'vgg_face2', 'visual_domain_decathlon', 'voc', 'voxceleb', 'voxforge', 'waymo_open_dataset', 'web_nlg', 'web_questions', 'wider_face', 'wiki40b', 'wiki_bio', 'wiki_table_questions', 'wiki_table_text', 'wikihow', 'wikipedia', 'wikipedia_toxicity_subtypes', 'wine_quality', 'winogrande', 'wmt14_translate', 'wmt15_translate', 'wmt16_translate', 'wmt17_translate', 'wmt18_translate', 'wmt19_translate', 'wmt_t2t_translate', 'wmt_translate', 'wordnet', 'wsc273', 'xnli', 'xquad', 'xsum', 'xtreme_pawsx', 'xtreme_xnli', 'yelp_polarity_reviews', 'yes_no']
データセットをロードする
tfds.load
データセットをロードする最も簡単な方法はtfds.load
です。そうなる:
- データをダウンロードして、
tfrecord
ファイルとして保存します。 -
tfrecord
をロードし、tfrecord
を作成しtf.data.Dataset
。
ds = tfds.load('mnist', split='train', shuffle_files=True)
assert isinstance(ds, tf.data.Dataset)
print(ds)
Downloading and preparing dataset 11.06 MiB (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /home/kbuilder/tensorflow_datasets/mnist/3.0.1... Dataset mnist downloaded and prepared to /home/kbuilder/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data. <_OptionsDataset shapes: {image: (28, 28, 1), label: ()}, types: {image: tf.uint8, label: tf.int64}>
いくつかの一般的な議論:
-
split=
:読み取る分割(例:'train'
、['train', 'test']
、'train[80%:]'
、...)。分割APIガイドをご覧ください。 -
shuffle_files=
:各エポック間でファイルをシャッフルするかどうかを制御します(TFDSは大きなデータセットを複数の小さなファイルに保存します)。 -
data_dir=
:データセットが保存される場所(デフォルトは~/tensorflow_datasets/
) -
with_info=True
:データセットメタデータを含むtfds.core.DatasetInfo
を返します download=False
:ダウンロードを無効にする
tfds.builder
tfds.load
は、 tfds.load
の薄いラッパーtfds.core.DatasetBuilder
。 tfds.core.DatasetBuilder
を使用して同じ出力を取得できます。
builder = tfds.builder('mnist')
# 1. Create the tfrecord files (no-op if already exists)
builder.download_and_prepare()
# 2. Load the `tf.data.Dataset`
ds = builder.as_dataset(split='train', shuffle_files=True)
print(ds)
<_OptionsDataset shapes: {image: (28, 28, 1), label: ()}, types: {image: tf.uint8, label: tf.int64}>
tfds build
CLI
特定のデータセットを生成する場合は、 tfds
コマンドラインを使用できます。例えば:
tfds build mnist
使用可能なフラグについては、ドキュメントを参照してください。
データセットを反復処理します
口述として
デフォルトでは、tf.data.Dataset
オブジェクトにはtf.Tensor
のdict
が含まれています。
ds = tfds.load('mnist', split='train')
ds = ds.take(1) # Only take a single example
for example in ds: # example is `{'image': tf.Tensor, 'label': tf.Tensor}`
print(list(example.keys()))
image = example["image"]
label = example["label"]
print(image.shape, label)
['image', 'label'] (28, 28, 1) tf.Tensor(4, shape=(), dtype=int64)
dict
キーの名前と構造を見つけるには、カタログのデータセットのドキュメントをご覧ください。例: mnistドキュメント。
タプルとして( as_supervised=True
)
as_supervised=True
を使用すると、教師ありデータセットの代わりにタプル(features, label)
取得できます。
ds = tfds.load('mnist', split='train', as_supervised=True)
ds = ds.take(1)
for image, label in ds: # example is (image, label)
print(image.shape, label)
(28, 28, 1) tf.Tensor(4, shape=(), dtype=int64)
As numpy( tfds.as_numpy
)
tfds.as_numpy
を使用してtfds.as_numpy
を変換します。
-
tf.Tensor
>np.array
tf.data.Dataset
- >Iterator[Tree[np.array]]
Tree
の任意の入れ子にすることができDict
、Tuple
)
ds = tfds.load('mnist', split='train', as_supervised=True)
ds = ds.take(1)
for image, label in tfds.as_numpy(ds):
print(type(image), type(label), label)
<class 'numpy.ndarray'> <class 'numpy.int64'> 4
バッチ処理されたtf.Tensorとして( batch_size=-1
)
batch_size=-1
を使用すると、データセット全体を1つのバッチで読み込むことができます。
これをas_supervised=True
およびtfds.as_numpy
と組み合わせて、データを(np.array, np.array)
として取得できます。
image, label = tfds.as_numpy(tfds.load(
'mnist',
split='test',
batch_size=-1,
as_supervised=True,
))
print(type(image), image.shape)
<class 'numpy.ndarray'> (10000, 28, 28, 1)
データセットがメモリに収まる可能性があること、およびすべての例が同じ形状であることに注意してください。
エンドツーエンドのパイプラインを構築する
さらに進むには、次を参照してください。
- 完全なトレーニングパイプライン(バッチ処理、シャッフルなど)を確認するためのエンドツーエンドのKerasの例。
- パイプラインの速度を向上させるためのパフォーマンスガイド。
視覚化
tfds.as_dataframe
tf.data.Dataset
オブジェクトがに変換することができpandas.DataFrame
でtfds.as_dataframe
上で可視化することがコラボ。
- 追加
tfds.core.DatasetInfo
2番目の引数としてtfds.as_dataframe
画像、音声、テキスト、ビデオを、視覚化します... -
ds.take(x)
を使用して、最初のx
例のみを表示します。pandas.DataFrame
は、完全なデータセットをメモリ内にロードするため、表示に非常にコストがかかる可能性があります。
ds, info = tfds.load('mnist', split='train', with_info=True)
tfds.as_dataframe(ds.take(4), info)
tfds.show_examples
tfds.show_examples
はmatplotlib.figure.Figure
返します(現在サポートされているのは画像データセットのみ):
ds, info = tfds.load('mnist', split='train', with_info=True)
fig = tfds.show_examples(ds, info)
データセットメタデータにアクセスする
すべてのビルダーには、データセットメタデータを含むtfds.core.DatasetInfo
オブジェクトが含まれてtfds.core.DatasetInfo
ます。
次の方法でアクセスできます。
-
tfds.load
API:
ds, info = tfds.load('mnist', with_info=True)
builder = tfds.builder('mnist')
info = builder.info
データセット情報には、データセットに関する追加情報(バージョン、引用、ホームページ、説明など)が含まれています。
print(info)
tfds.core.DatasetInfo( name='mnist', full_name='mnist/3.0.1', description=""" The MNIST database of handwritten digits. """, homepage='http://yann.lecun.com/exdb/mnist/', data_path='/home/kbuilder/tensorflow_datasets/mnist/3.0.1', download_size=11.06 MiB, dataset_size=21.00 MiB, features=FeaturesDict({ 'image': Image(shape=(28, 28, 1), dtype=tf.uint8), 'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10), }), supervised_keys=('image', 'label'), splits={ 'test': <SplitInfo num_examples=10000, num_shards=1>, 'train': <SplitInfo num_examples=60000, num_shards=1>, }, citation="""@article{lecun2010mnist, title={MNIST handwritten digit database}, author={LeCun, Yann and Cortes, Corinna and Burges, CJ}, journal={ATT Labs [Online]. Available: http://yann.lecun.com/exdb/mnist}, volume={2}, year={2010} }""", )
機能のメタデータ(ラベル名、画像の形状など)
tfds.features.FeatureDict
アクセスしtfds.features.FeatureDict
。
info.features
FeaturesDict({ 'image': Image(shape=(28, 28, 1), dtype=tf.uint8), 'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10), })
クラスの数、ラベル名:
print(info.features["label"].num_classes)
print(info.features["label"].names)
print(info.features["label"].int2str(7)) # Human readable version (8 -> 'cat')
print(info.features["label"].str2int('7'))
10 ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'] 7 7
形状、dtypes:
print(info.features.shape)
print(info.features.dtype)
print(info.features['image'].shape)
print(info.features['image'].dtype)
{'image': (28, 28, 1), 'label': ()} {'image': tf.uint8, 'label': tf.int64} (28, 28, 1) <dtype: 'uint8'>
メタデータの分割(例:名前の分割、例の数など)
tfds.core.SplitDict
アクセスしtfds.core.SplitDict
。
print(info.splits)
{'test': <SplitInfo num_examples=10000, num_shards=1>, 'train': <SplitInfo num_examples=60000, num_shards=1>}
利用可能な分割:
print(list(info.splits.keys()))
['test', 'train']
個々の分割に関する情報を取得します。
print(info.splits['train'].num_examples)
print(info.splits['train'].filenames)
print(info.splits['train'].num_shards)
60000 ['mnist-train.tfrecord-00000-of-00001'] 1
サブスプリットAPIでも機能します。
print(info.splits['train[15%:75%]'].num_examples)
print(info.splits['train[15%:75%]'].file_instructions)
36000 [FileInstruction(filename='mnist-train.tfrecord-00000-of-00001', skip=9000, take=36000, num_examples=36000)]
トラブルシューティング
手動ダウンロード(ダウンロードが失敗した場合)
何らかの理由でダウンロードが失敗した場合(オフラインなど)。いつでも自分でデータを手動でダウンロードして、 manual_dir
配置できます(デフォルトは~/tensorflow_datasets/download/manual/
です。
ダウンロードするURLを見つけるには、以下を調べてください。
新しいデータセット(フォルダーとして実装)の場合:
tensorflow_datasets/
<type>/<dataset_name>/checksums.tsv
。例:tensorflow_datasets/text/bool_q/checksums.tsv
。データセットのソースの場所は、カタログにあります。
古いデータセットの場合:
tensorflow_datasets/url_checksums/<dataset_name>.txt
NonMatchingChecksumError
修正
TFDSは、ダウンロードされたURLのチェックサムを検証することにより、決定論を保証します。 NonMatchingChecksumError
が発生した場合は、次のことを示している可能性があります。
- ウェブサイトがダウンしている可能性があります(例:
503 status code
)。 URLを確認してください。 - GoogleドライブのURLについては、同じURLにアクセスする人が多すぎるとドライブがダウンロードを拒否することがあるため、後でもう一度やり直してください。バグを見る
- 元のデータセットファイルが更新されている可能性があります。この場合、TFDSデータセットビルダーを更新する必要があります。新しいGithubの問題またはPRを開いてください:
- 新しいチェックサムを
tfds build --register_checksums
登録します - 最終的に、データセット生成コードを更新します。
- データセットの
VERSION
更新する - データセットを更新する
RELEASE_NOTES
:チェックサムが変更された原因は何ですか?いくつかの例は変更されましたか? - データセットを引き続き構築できることを確認してください。
- PRを送ってください
- 新しいチェックサムを
引用
論文にtensorflow-datasets
を使用している場合は、使用されているデータセット(データセットカタログにあります)に固有の引用に加えて、次の引用を含めてください。
@misc{TFDS,
title = { {TensorFlow Datasets}, A collection of ready-to-use datasets},
howpublished = {\url{https://www.tensorflow.org/datasets} },
}