Jax と PyTorch 用の TFDS

TensorFlow.org で表示 Google Colab で実行 GitHub でソースを表示 ノートブックをダウンロード

TFDS は常に フレームワーク非依存型でした。たとえば、NumPy 形式のデータセットを簡単に読み込んで、Jax と PyTorch で使用することができます。

TensorFlow とそのデータ読み込みソリューション(tf.data)は、設計上、API の第一級市民です。

TensorFlow を使用せずに NumPy のみでデータを読み込めるように、TFDS を拡張しました。これは、Jax や PyTorch などの ML での使用に便利であり、実際に PyTorch ユーザーの場合、TensorFlow では以下のことが発生する可能性があります。

  • GPU/TPU メモリの予約
  • CI/CD でのビルド時間の長期化
  • ランタイム時のインポートの長期化

TensorFlow は、データセットを読み取る際の依存関係ではなくなりました。

ML パイプラインがサンプルを読み込んで解読し、モデルに提供するには、データローダーが必要です。データローダーは、「ソース/サンプラー/ローダー」パラダイムを使用します。

 TFDS dataset       ┌────────────────┐
   on disk          │                │
        ┌──────────►│      Data      │
|..|... │     |     │     source     ├─┐
├──┼────┴─────┤     │                │ │
│12│image12   │     └────────────────┘ │    ┌────────────────┐
├──┼──────────┤                        │    │                │
│13│image13   │                        ├───►│      Data      ├───► ML pipeline
├──┼──────────┤                        │    │     loader     │
│14│image14   │     ┌────────────────┐ │    │                │
├──┼──────────┤     │                │ │    └────────────────┘
|..|...       |     │     Index      ├─┘
                    │    sampler     │
                    │                │
                    └────────────────┘
  • データソースは、TFDS データセットからオンザフライ方式でサンプルにアクセスして解読します。
  • インデックスサンプラーは、レコードが処理される順序を決定します。これは、レコードを読み取る前にグローバル変換(グローバルシャッフル、シャーディング、複数のエポックの反復など)を実装するのに重要です。
  • データローダーは、データソースとインデックスサンプラーを利用して、読み込みをオーケストレーションします。パフォーマンスの最適化が可能です(プリフェッチ、マルチプロセッシング、またはマルチスレッドなど)。

要約

tfds.data_source は、データソースを作成する API で、以下を目的としています。

  1. 純粋な Python パイプラインでの高速プロトタイピング
  2. 大規模なデータ集約型 ML パイプラインの管理

セットアップ

必要な依存関係をインストールしてインポートしましょう。

!pip install array_record
!pip install tfds-nightly

import os
os.environ.pop('TFDS_DATA_DIR', None)

import tensorflow_datasets as tfds

データソース

データソースは基本的に Python シーケンスです。そのため、以下のプロトコルを実装する必要があります。

class RandomAccessDataSource(Protocol):
  """Interface for datasources where storage supports efficient random access."""

  def __len__(self) -> int:
    """Number of records in the dataset."""

  def __getitem__(self, record_key: int) -> Sequence[Any]:
    """Retrieves records for the given record_keys."""

警告: この API は現在も活発に開発されています。特に、現時点では、__getitem__ は入力で intlist[int] をサポートする必要があります。将来的には、標準に従って、おそらく int のみがサポートされます。

基盤のファイル形式は有効なランダムアクセスをサポートする必要があります。現時点では、TFDS は array_record に依存しています。

array_record は、Riegeli から派生した新しいファイル形式です。IO 効率の新境地を達成しています。特に、ArrayRecord はレコードインデックスによる同時読み取り、書き込み、およびランダムアクセスをサポートしています。ArrayRecord は Riegeli を基盤としているため、同じ圧縮アルゴリズムをサポートしています。

fashion_mnist はコンピュータビジョン用の共通データセットです。以下を使用するだけで、TFDS で ArrayRecord ベースのデータを取得することができます。

ds = tfds.data_source('fashion_mnist')

tfds.data_source は便利なラッパーで、以下に相当します。

builder = tfds.builder('fashion_mnist', file_format='array_record')
builder.download_and_prepare()
ds = builder.as_data_source()

これは、データソースのディクショナリを出力します。

{
  'train': DataSource(name=fashion_mnist, split='train', decoders=None),
  'test': DataSource(name=fashion_mnist, split='test', decoders=None),
}

download_and_prepare が実行し、レコードファイルを生成したら、TensorFlow は不要になります。すべては Python/NumPy で処理されます!

TensorFlow をアンインストールして、別のサブプロセスでデータソースを読み込みなおして、このことを確認してみましょう。

pip uninstall -y tensorflow
%%writefile no_tensorflow.py
import os
os.environ.pop('TFDS_DATA_DIR', None)

import tensorflow_datasets as tfds

try:
  import tensorflow as tf
except ImportError:
  print('No TensorFlow found...')

ds = tfds.data_source('fashion_mnist')
print('...but the data source could still be loaded...')
ds['train'][0]
print('...and the records can be decoded.')
python no_tensorflow.py

今後のバージョンでは、データセットの準備も TensorFlow を使用せずに行えるようにする予定です。

データソースには長さがあります。

len(ds['train'])

以下のようにして、データセットの最初の要素にアクセスすると...

%%timeit
ds['train'][0]

他の要素へのアクセスと同じように安価に行えます。これが、ランダムアクセスの定義です。

%%timeit
ds['train'][1000]

特徴量は NumPy DTypes(TensorFlow DTypes ではなく)を使用するようになりました。以下のようにして、特徴量を検査することができます。

features = tfds.builder('fashion_mnist').info.features

特徴量の詳細は、ドキュメントで確認できます。ここでは、画像の形状とクラスの数を取得できます。

shape = features['image'].shape
num_classes = features['label'].num_classes

純粋な Python で使用する

Python でデータソースを反復することで、それを消費できます。

for example in ds['train']:
  print(example)
  break

要素を検査すると、すべての特徴量がすでに NumPy を使って解読されているのがわかります。高速であるため、背後では、デフォルトで OpenCV を使用していますが、OpenCV がインストールされていない場合は、軽量で、画像を高速解読できる Pillow がデフォルトで使用されます。

{
  'image': array([[[0], [0], ..., [0]],
                  [[0], [0], ..., [0]]], dtype=uint8),
  'label': 2,
}

注意: 現在、この機能は、TensorImage、および Scalar の特徴量でしか使用できません。AudioVideo 特徴量は、間もなくサポートされる予定です。ご期待ください!

PyTorch で使用する

PyTorch は、ソース/サンプラー/ローダー構成のパラダイムを使用します。Torch では、「データソース」のことを「データセット」と呼んでいます。torch.utils.data には、有効な入力パイプラインを Torch でビルドするために必要なすべての情報が含まれます。

通常のマップスタイルのデータセットとして、TFDS データソースを使用することができます。

まず、Torch をインストールしてインポートします。

!pip install torch

from tqdm import tqdm
import torch

トレーニング用のデータソースとテスト用のデータソースはすでに定義済みです(順に、ds['train']ds['test'])。サンプラーとローダーを定義しましょう。

batch_size = 128
train_sampler = torch.utils.data.RandomSampler(ds['train'], num_samples=5_000)
train_loader = torch.utils.data.DataLoader(
    ds['train'],
    sampler=train_sampler,
    batch_size=batch_size,
)
test_loader = torch.utils.data.DataLoader(
    ds['test'],
    sampler=None,
    batch_size=batch_size,
)

PyTorch で、最初のサンプルを使って、単純なロジスティック回帰をトレーニングし、評価します。

class LinearClassifier(torch.nn.Module):
  def __init__(self, shape, num_classes):
    super(LinearClassifier, self).__init__()
    height, width, channels = shape
    self.classifier = torch.nn.Linear(height * width * channels, num_classes)

  def forward(self, image):
    image = image.view(image.size()[0], -1).to(torch.float32)
    return self.classifier(image)


model = LinearClassifier(shape, num_classes)
optimizer = torch.optim.Adam(model.parameters())
loss_function = torch.nn.CrossEntropyLoss()

print('Training...')
model.train()
for example in tqdm(train_loader):
  image, label = example['image'], example['label']
  prediction = model(image)
  loss = loss_function(prediction, label)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

print('Testing...')
model.eval()
num_examples = 0
true_positives = 0
for example in tqdm(test_loader):
  image, label = example['image'], example['label']
  prediction = model(image)
  num_examples += image.shape[0]
  predicted_label = prediction.argmax(dim=1)
  true_positives += (predicted_label == label).sum().item()
print(f'\nAccuracy: {true_positives/num_examples * 100:.2f}%')

近日公開: JAX と使用する

Grain と緊密に作業を続けています。Grain はオープンソースの高速で決定論的な Python 用データローダーです。ご期待ください!

その他の資料

詳細については、tfds.data_source API ドキュメントをご覧ください。