TensorFlow Datasets

TFDS は、TensorFlow、Jax、およびその他の機械学習フレームワークですぐに使用できる一連のデータセットを提供しています。

データの確定的なダウンロードと準備、および tf.data.Dataset(または np.array)の構築を行います。

注意: TFDS(このライブラリ)と tf.data(有効なデータパイプラインを構築する TensorFlow API)を混同しないようにしてください。TFDS は tf.data を囲む高レベルのラッパーです。この API をよく知らない方は、まず tf.data の公式ガイドを読むことをお勧めします。

TensorFlow.org で表示 Google Colab で実行 GitHub でソースを表示 ノートブックをダウンロード

インストール

TFDS は 2 つのパッケージで存在します。

  • pip install tensorflow-datasets: 安定バージョン。数か月おきにリリースされます。
  • pip install tfds-nightly: 毎日リリースされ、データセットの最終バージョンが含まれます。

この Colabでは、tfds-nightly を使用します。

pip install -q tfds-nightly tensorflow matplotlib
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

import tensorflow_datasets as tfds

利用可能なデータセットを特定する

すべてのデータセットビルダーは、tfds.core.DatasetBuilder のサブクラスです。利用可能なビルダーのリストを取得するには、tfds.list_builders() を使用するか、カタログをご覧ください。

tfds.list_builders()

データセットを読み込む

tfds.load

データセットを最も簡単に読み込むには、tfds.load を使用します。次の内容が行われます。

  1. データをダウンロードし、tfrecord ファイルとして保存します。
  2. tfrecord を読み込んで、tf.data.Dataset を作成します。
ds = tfds.load('mnist', split='train', shuffle_files=True)
assert isinstance(ds, tf.data.Dataset)
print(ds)

次ような一般的な属性があります。

  • split=: 分割して読み取ります('train'['train', 'test']'train[80%:]'、など)。split API ガイドをご覧ください。
  • shuffle_files=: エポック間でファイルをシャッフルするかどうかを制御します(TFDS は大規模なデータセットを複数の小さなファイルに保存します)。
  • data_dir=: データセットが保存される場所(デフォルトは ~/tensorflow_datasets/
  • with_info=True: データセットのメタデータを含む tfds.core.DatasetInfo を返します。
  • download=False: ダウンロードを無効にします。

tfds.builder

tfds.load は、tfds.core.DatasetBuilder の新ラッパーです。tfds.core.DatasetBuilder API を使用して同じ出力を取得することができます。

builder = tfds.builder('mnist')
# 1. Create the tfrecord files (no-op if already exists)
builder.download_and_prepare()
# 2. Load the `tf.data.Dataset`
ds = builder.as_dataset(split='train', shuffle_files=True)
print(ds)

tfds build CLI

特定のデータセットを生成する場合は、tfds コマンドライン を使用できます。以下に例を示します。

tfds build mnist

利用できるフラグについては、ドキュメントをご覧ください。

データセットをイテレートする

dict 型

デフォルトでは、tf.data.Dataset オブジェクトには、tf.Tensordict が含まれます。

ds = tfds.load('mnist', split='train')
ds = ds.take(1)  # Only take a single example

for example in ds:  # example is `{'image': tf.Tensor, 'label': tf.Tensor}`
  print(list(example.keys()))
  image = example["image"]
  label = example["label"]
  print(image.shape, label)

dict のキー名と構造を知るには、カタログのデータセットドキュメントを確認します。例: mnist ドキュメント

タプル (as_supervised=True)

as_supervised=True を使用すると、スーパーバイズされたデータセットの代わりにタプル型 (features, label) を取得することができます。

ds = tfds.load('mnist', split='train', as_supervised=True)
ds = ds.take(1)

for image, label in ds:  # example is (image, label)
  print(image.shape, label)

NumPy 配列 (tfds.as_numpy)

tfds.as_numpy を使用して、次のように変換します。

  • tf.Tensor -> np.array
  • tf.data.Dataset -> Iterator[Tree[np.array]]Tree は任意のネストされた DictTuple
ds = tfds.load('mnist', split='train', as_supervised=True)
ds = ds.take(1)

for image, label in tfds.as_numpy(ds):
  print(type(image), type(label), label)

バッチ化された tf.Tensor(batch_size=-1

batch_size=-1 を使用すると、単一のバッチで全データセットを読み込むことができます。

これは as_supervised=Truetfds.as_numpy に組み合わせることで、データを (np.array, np.array) として取得できます。

image, label = tfds.as_numpy(tfds.load(
    'mnist',
    split='test',
    batch_size=-1,
    as_supervised=True,
))

print(type(image), image.shape)

データセットがメモリに収まる可能性があり、すべてのサンプルの形状が同じであることに注意してください。

データセットのベンチマークを作成する

データセットのベンチマーク作成は、単純な tfds.benchmark 呼び出しを任意のイテラブル(tf.data.Datasettfds.as_numpy など)に対して行います。

ds = tfds.load('mnist', split='train')
ds = ds.batch(32).prefetch(1)

tfds.benchmark(ds, batch_size=32)
tfds.benchmark(ds, batch_size=32)  # Second epoch much faster due to auto-caching
  • batch_size= kwarg でバッチサイズごとに結果を必ず正規化してください。
  • まとめると、最初のウォームアップバッチは、tf.data.Dataset の追加のセットアップ時間(バッファの初期化など)をキャプチャするために、ほかのバッチとは分離されます。
  • 2 つ目のイテレーションが、TFDS auto-caching にいよってはるかに高速に行われるのがわかります。
  • tfds.benchmark は以降の分析で検査できる tfds.core.BenchmarkResult を返します。

エンドツーエンドパイプラインを構築する

先に進むには、次の項目をご覧ください。

視覚化する

tfds.as_dataframe

tf.data.Dataset オブジェクトは、pandas.DataFrame に変換可能で、その tfds.as_dataframeColab で視覚化できます。

  • 画像、オーディオ、テキスト、動画などを視覚化するには、tfds.as_dataframe の 2 つ目の引数として tfds.core.DatasetInfo を追加します。
  • 最初の x サンプルのみを表示するには、ds.take(x) を使用します。pandas.DataFrame はメモリ内の全データセットを読み込むため、表示するには非常にコストがかかる可能性があります。
ds, info = tfds.load('mnist', split='train', with_info=True)

tfds.as_dataframe(ds.take(4), info)

tfds.show_examples

tfds.show_examplesmatplotlib.figure.Figure を返します(現在は画像データセットのみがサポートされています)。

ds, info = tfds.load('mnist', split='train', with_info=True)

fig = tfds.show_examples(ds, info)

データセットのメタデータにアクセスする

すべてのビルダーには、データセットのメタデータを含む tfds.core.DatasetInfo オブジェクトが含まれます。

次の方法でアクセスできます。

ds, info = tfds.load('mnist', with_info=True)
builder = tfds.builder('mnist')
info = builder.info

データセット情報には、データセットに関する追加情報(バージョン、引用、ホームページ、説明など)が含まれます。

print(info)

特徴量メタデータ(ラベル名、画像形状など)

次のようにして tfds.features.FeatureDict にアクセスします。

info.features

クラス数、ラベル名:

print(info.features["label"].num_classes)
print(info.features["label"].names)
print(info.features["label"].int2str(7))  # Human readable version (8 -> 'cat')
print(info.features["label"].str2int('7'))

形状、dtype:

print(info.features.shape)
print(info.features.dtype)
print(info.features['image'].shape)
print(info.features['image'].dtype)

分割メタデータ(分割名、サンプル数など)

次のようにして tfds.core.SplitDict にアクセスします。

print(info.splits)

利用可能な分割:

print(list(info.splits.keys()))

個々の分割に関する情報の取得:

print(info.splits['train'].num_examples)
print(info.splits['train'].filenames)
print(info.splits['train'].num_shards)

subsplit API でも動作します。

print(info.splits['train[15%:75%]'].num_examples)
print(info.splits['train[15%:75%]'].file_instructions)

トラブルシューティング

手動ダウンロード(ダウンロードに失敗した場合)

何らかの理由(オフラインなど)でダウンロードに失敗した場合は、手動でデータをダウンロードして manual_dir(デフォルトは ~/tensorflow_datasets/download/manual/)にダウンロードすることができます。

ダウンロードする URL を見つけるには、次を確認してください。

NonMatchingChecksumError を修正する

TFDS では、ダウンロード URL のチェックサムを検証することで、決定性を確保しています。NonMatchingChecksumError が発生する場合は、以下のことが考えられます。

  • ウェブサイトがダウンしている(503 ステータスコードなど)。URL を確認してください。
  • Google Drive の URL の場合は、Drive の同じ URL に多くの人がアクセスしている場合にダウンロードを拒否することがあるため、後でもう一度試してみてください。バグをご覧ください。
  • 元のデータセットファイルが更新されている。この場合、TFDS データセットビルダーを更新する必要があります。新しい Github issue か PR を発行してください。
    • tfds build --register_checksums で新しいチェックサムを登録します。
    • 最終的に、データセットの生成コードを更新します。
    • データセットの VERSION を更新します。
    • データセットの RELEASE_NOTES を更新します。チェックサムの変更理由、Example の変更有無など。
    • データセットが構築可能であることを確認します。
    • PR を送信します。

注意: ~/tensorflow_datasets/download/ でダウンロードされたファイルを検査することも可能です。

引用

論文で tensorflow-datasets を使用する場合は、使用したデータセットの固有の引用(dataset catalog で確認できます)のほかに次の引用を含めてください。

@misc{TFDS,
  title = { {TensorFlow Datasets}, A collection of ready-to-use datasets},
  howpublished = {\url{https://www.tensorflow.org/datasets} },
}