このページは Cloud Translation API によって翻訳されました。
Switch to English

tf.data:TensorFlow入力パイプラインを構築する

TensorFlow.orgで表示 GoogleColabで実行 GitHubでソースを表示ノートブックをダウンロード

tf.data APIを使用すると、単純で再利用可能な部分から複雑な入力パイプラインを構築できます。たとえば、画像モデルのパイプラインは、分散ファイルシステム内のファイルからデータを集約し、各画像にランダムな摂動を適用し、ランダムに選択された画像をトレーニング用のバッチにマージする場合があります。テキストモデルのパイプラインには、生のテキストデータからシンボルを抽出し、それらをルックアップテーブルを使用して埋め込み識別子に変換し、さまざまな長さのシーケンスをバッチ処理することが含まれる場合があります。 tf.data APIを使用すると、大量のデータを処理したり、さまざまなデータ形式から読み取ったり、複雑な変換を実行したりできます。

tf.data APIは、要素のシーケンスを表すtf.data.Dataset抽象化を導入します。各要素は、1つ以上のコンポーネントで構成されます。たとえば、画像パイプラインでは、要素は単一のトレーニング例であり、画像とそのラベルを表すテンソルコンポーネントのペアがあります。

データセットを作成するには、2つの異なる方法があります。

  • データソースは、メモリまたは1つ以上のファイルに格納されているDatasetからデータDatasetを構築します。

  • データ変換は、1つ以上のtf.data.Datasetオブジェクトからデータセットを構築します。

import tensorflow as tf
import pathlib
import os
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

np.set_printoptions(precision=4)

基本的な仕組み

入力パイプラインを作成するには、データソースから開始する必要があります。たとえば、メモリ内のDatasetからデータDatasetを構築するには、 tf.data.Dataset.from_tensors()またはtf.data.Dataset.from_tensor_slices()使用できます。または、入力データが推奨されるTFRecord形式のファイルに保存されている場合は、 tf.data.TFRecordDataset()使用できます。

Datasetオブジェクトをtf.data.Datasettf.data.Datasetオブジェクトのメソッド呼び出しをチェーンすることで、それを新しいDataset変換できます。たとえば、次のような要素毎の変換を適用することができるDataset.map()およびなどの多要素変換Dataset.batch()変換の完全なリストについては、 tf.data.Datasetのドキュメントを参照してください。

DatasetオブジェクトはPythonの反復可能オブジェクトです。これにより、forループを使用してその要素を消費することが可能になります。

dataset = tf.data.Dataset.from_tensor_slices([8, 3, 0, 8, 2, 1])
dataset
<TensorSliceDataset shapes: (), types: tf.int32>
for elem in dataset:
  print(elem.numpy())
8
3
0
8
2
1

または、 iterを使用してPythonイテレーターを明示的に作成し、 nextを使用してその要素を消費することによって:

it = iter(dataset)

print(next(it).numpy())
8

または、 reduce変換を使用してデータセット要素を使用することもできます。これにより、すべての要素が縮小されて単一の結果が生成されます。次の例は、 reduce変換を使用して整数のデータセットの合計を計算する方法を示しています。

print(dataset.reduce(0, lambda state, value: state + value).numpy())
22

データセットの構造

データセットは、それぞれが同一の(ネストされた)構造を有し、構造体の個々の成分により、任意のタイプの表現であることができる要素含まtf.TypeSpec含む、 tf.Tensortf.sparse.SparseTensortf.RaggedTensortf.TensorArray 、またはtf.data.Dataset

Dataset.element_specプロパティを使用すると、各要素コンポーネントのタイプを検査できます。プロパティは、要素の構造に一致するtf.TypeSpecオブジェクトのネストされた構造を返します。これは、単一のコンポーネント、コンポーネントのタプル、またはコンポーネントのネストされたタプルの場合があります。例えば:

dataset1 = tf.data.Dataset.from_tensor_slices(tf.random.uniform([4, 10]))

dataset1.element_spec
TensorSpec(shape=(10,), dtype=tf.float32, name=None)
dataset2 = tf.data.Dataset.from_tensor_slices(
   (tf.random.uniform([4]),
    tf.random.uniform([4, 100], maxval=100, dtype=tf.int32)))

dataset2.element_spec
(TensorSpec(shape=(), dtype=tf.float32, name=None),
 TensorSpec(shape=(100,), dtype=tf.int32, name=None))
dataset3 = tf.data.Dataset.zip((dataset1, dataset2))

dataset3.element_spec
(TensorSpec(shape=(10,), dtype=tf.float32, name=None),
 (TensorSpec(shape=(), dtype=tf.float32, name=None),
  TensorSpec(shape=(100,), dtype=tf.int32, name=None)))
# Dataset containing a sparse tensor.
dataset4 = tf.data.Dataset.from_tensors(tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4]))

dataset4.element_spec
SparseTensorSpec(TensorShape([3, 4]), tf.int32)
# Use value_type to see the type of value represented by the element spec
dataset4.element_spec.value_type
tensorflow.python.framework.sparse_tensor.SparseTensor

Dataset変換は、任意の構造のデータセットをサポートします。使用する場合Dataset.map()及びDataset.filter()各要素に関数を適用する変換を、素子構造は、関数の引数を決定します。

dataset1 = tf.data.Dataset.from_tensor_slices(
    tf.random.uniform([4, 10], minval=1, maxval=10, dtype=tf.int32))

dataset1
<TensorSliceDataset shapes: (10,), types: tf.int32>
for z in dataset1:
  print(z.numpy())
[1 2 5 3 8 9 6 7 9 3]
[5 4 3 3 9 7 3 1 5 4]
[7 3 4 8 4 8 5 9 4 1]
[8 6 8 2 6 3 5 8 3 3]

dataset2 = tf.data.Dataset.from_tensor_slices(
   (tf.random.uniform([4]),
    tf.random.uniform([4, 100], maxval=100, dtype=tf.int32)))

dataset2
<TensorSliceDataset shapes: ((), (100,)), types: (tf.float32, tf.int32)>
dataset3 = tf.data.Dataset.zip((dataset1, dataset2))

dataset3
<ZipDataset shapes: ((10,), ((), (100,))), types: (tf.int32, (tf.float32, tf.int32))>
for a, (b,c) in dataset3:
  print('shapes: {a.shape}, {b.shape}, {c.shape}'.format(a=a, b=b, c=c))
shapes: (10,), (), (100,)
shapes: (10,), (), (100,)
shapes: (10,), (), (100,)
shapes: (10,), (), (100,)

入力データの読み取り

NumPy配列の消費

その他の例については、 NumPy配列の読み込みを参照してください。

すべての入力データがメモリに収まる場合、それらからDatasetを作成する最も簡単な方法は、それらをtf.Tensorオブジェクトに変換し、 Dataset.from_tensor_slices()を使用することDataset.from_tensor_slices()

train, test = tf.keras.datasets.fashion_mnist.load_data()
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
32768/29515 [=================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
26427392/26421880 [==============================] - 1s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
8192/5148 [===============================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz
4423680/4422102 [==============================] - 0s 0us/step

images, labels = train
images = images/255

dataset = tf.data.Dataset.from_tensor_slices((images, labels))
dataset
<TensorSliceDataset shapes: ((28, 28), ()), types: (tf.float64, tf.uint8)>

Pythonジェネレーターの使用

tf.data.Datasetとして簡単に取り込むことができるもう1つの一般的なデータソースは、Pythonジェネレーターです。

def count(stop):
  i = 0
  while i<stop:
    yield i
    i += 1
for n in count(5):
  print(n)
0
1
2
3
4

Dataset.from_generatorコンストラクターは、Pythonジェネレーターを完全に機能するtf.data.Datasetます。

コンストラクターは、イテレーターではなく、呼び出し可能オブジェクトを入力として受け取ります。これにより、ジェネレーターが最後に到達したときにジェネレーターを再起動できます。オプションのargs引数を取ります。これは、呼び出し可能引数として渡されます。

output_typesための引数が必要であるtf.data構築tf.Graph内部、およびグラフのエッジは必要tf.dtype

ds_counter = tf.data.Dataset.from_generator(count, args=[25], output_types=tf.int32, output_shapes = (), )
for count_batch in ds_counter.repeat().batch(10).take(10):
  print(count_batch.numpy())
[0 1 2 3 4 5 6 7 8 9]
[10 11 12 13 14 15 16 17 18 19]
[20 21 22 23 24  0  1  2  3  4]
[ 5  6  7  8  9 10 11 12 13 14]
[15 16 17 18 19 20 21 22 23 24]
[0 1 2 3 4 5 6 7 8 9]
[10 11 12 13 14 15 16 17 18 19]
[20 21 22 23 24  0  1  2  3  4]
[ 5  6  7  8  9 10 11 12 13 14]
[15 16 17 18 19 20 21 22 23 24]

output_shapes引数は必須ではありませが、多くのテンソルフロー操作はランクが不明なテンソルをサポートしていないため、強くお勧めします。特定の軸の長さが不明または可変である場合は、 output_shapes Noneに設定します。

output_shapesoutput_typesは、他のデータセットメソッドと同じネストルールに従うことに注意することも重要です。

これは両方の側面を示すジェネレーターの例です。配列のタプルを返します。2番目の配列は長さが不明なベクトルです。

def gen_series():
  i = 0
  while True:
    size = np.random.randint(0, 10)
    yield i, np.random.normal(size=(size,))
    i += 1
for i, series in gen_series():
  print(i, ":", str(series))
  if i > 5:
    break
0 : [1.2226]
1 : [ 0.4785  1.1887 -0.2828  0.6047  1.3367 -0.4387  0.1822]
2 : [ 1.1343 -0.2676  0.0224 -0.111  -0.1384 -1.9315]
3 : [-0.8651]
4 : [0.6275]
5 : [ 0.8034  2.0773  0.6183 -0.3746]
6 : [-0.9439 -0.6686]

最初の出力はint32 、2番目の出力はfloat32です。

最初の項目はスカラーであるshape ()であり、2番目の項目は未知の長さのベクトルであるshape (None,)

ds_series = tf.data.Dataset.from_generator(
    gen_series, 
    output_types=(tf.int32, tf.float32), 
    output_shapes=((), (None,)))

ds_series
<FlatMapDataset shapes: ((), (None,)), types: (tf.int32, tf.float32)>

これで、通常のtf.data.Datasetように使用できます。可変形状のデータセットをバッチ処理する場合は、 Dataset.padded_batchを使用する必要があることにDataset.padded_batch

ds_series_batch = ds_series.shuffle(20).padded_batch(10)

ids, sequence_batch = next(iter(ds_series_batch))
print(ids.numpy())
print()
print(sequence_batch.numpy())
[18  1  8 22  5 13  6  2 14 28]

[[ 0.0196 -1.007   0.1843  0.0289  0.0735 -0.6279 -1.0877  0.    ]
 [ 0.8238 -1.0284  0.      0.      0.      0.      0.      0.    ]
 [ 0.544  -1.1061  1.2368  0.3975  0.      0.      0.      0.    ]
 [-1.1933 -0.7535  1.0497  0.1764  1.5319  0.9202  0.4027  0.6844]
 [-1.2025 -2.7148  1.0702  1.3893  0.      0.      0.      0.    ]
 [ 1.3787  0.6817  0.1197 -0.1178  0.9764  0.8895  0.      0.    ]
 [-1.523  -0.7722 -2.13    0.2761 -1.1094  0.      0.      0.    ]
 [ 0.203   2.1858 -0.722   1.2554  1.2208  0.1813 -2.3427  0.    ]
 [ 0.7685  1.7138 -1.2376 -2.6168  0.2565  0.0753  1.5653  0.    ]
 [ 0.0908 -0.4535  0.1257 -0.5122  0.      0.      0.      0.    ]]

より現実的な例については、ラッピングしてみてくださいpreprocessing.image.ImageDataGenerator通りtf.data.Dataset

最初にデータをダウンロードします。

flowers = tf.keras.utils.get_file(
    'flower_photos',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True)
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
228818944/228813984 [==============================] - 9s 0us/step

image.ImageDataGenerator作成します

img_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255, rotation_range=20)
images, labels = next(img_gen.flow_from_directory(flowers))
Found 3670 images belonging to 5 classes.

print(images.dtype, images.shape)
print(labels.dtype, labels.shape)
float32 (32, 256, 256, 3)
float32 (32, 5)

ds = tf.data.Dataset.from_generator(
    lambda: img_gen.flow_from_directory(flowers), 
    output_types=(tf.float32, tf.float32), 
    output_shapes=([32,256,256,3], [32,5])
)

ds.element_spec
(TensorSpec(shape=(32, 256, 256, 3), dtype=tf.float32, name=None),
 TensorSpec(shape=(32, 5), dtype=tf.float32, name=None))
for images, label in ds.take(1):
  print('images.shape: ', images.shape)
  print('labels.shape: ', labels.shape)

Found 3670 images belonging to 5 classes.
images.shape:  (32, 256, 256, 3)
labels.shape:  (32, 5)

TFRecordデータの消費

エンドツーエンドの例については、 TFRecordのロードを参照してください。

tf.data APIはさまざまなファイル形式をサポートしているため、メモリに収まらない大きなデータセットを処理できます。たとえば、TFRecordファイル形式は、多くのTensorFlowアプリケーションがデータのトレーニングに使用する単純なレコード指向のバイナリ形式です。 tf.data.TFRecordDatasetクラスを使用すると、入力パイプラインの一部として1つ以上のTFRecordファイルのコンテンツをストリーミングできます。

これは、フランスの道路標識(FSNS)のテストファイルを使用した例です。

# Creates a dataset that reads all of the examples from two files.
fsns_test_file = tf.keras.utils.get_file("fsns.tfrec", "https://storage.googleapis.com/download.tensorflow.org/data/fsns-20160927/testdata/fsns-00000-of-00001")
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/fsns-20160927/testdata/fsns-00000-of-00001
7905280/7904079 [==============================] - 0s 0us/step

filenames引数TFRecordDatasetいずれかの文字列、文字列のリスト、またはすることができイニシャライザtf.Tensor文字列の。したがって、トレーニングと検証の目的で2セットのファイルがある場合は、ファイル名を入力引数として使用して、データセットを生成するファクトリメソッドを作成できます。

dataset = tf.data.TFRecordDataset(filenames = [fsns_test_file])
dataset
<TFRecordDatasetV2 shapes: (), types: tf.string>

多くのTensorFlowプロジェクトは、TFRecordファイルでシリアル化されたtf.train.Exampleレコードを使用します。これらは、検査する前にデコードする必要があります。

raw_example = next(iter(dataset))
parsed = tf.train.Example.FromString(raw_example.numpy())

parsed.features.feature['image/text']
bytes_list {
  value: "Rue Perreyon"
}

テキストデータの消費

エンドツーエンドの例については、テキストの読み込みを参照してください。

多くのデータセットは、1つ以上のテキストファイルとして配布されます。 tf.data.TextLineDatasetは、1つ以上のテキストファイルから行を抽出する簡単な方法を提供します。 1つ以上のファイル名を指定すると、 TextLineDatasetは、それらのファイルの1行ごとに1つの文字列値要素を生成します。

directory_url = 'https://storage.googleapis.com/download.tensorflow.org/data/illiad/'
file_names = ['cowper.txt', 'derby.txt', 'butler.txt']

file_paths = [
    tf.keras.utils.get_file(file_name, directory_url + file_name)
    for file_name in file_names
]
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/illiad/cowper.txt
819200/815980 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/illiad/derby.txt
811008/809730 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/illiad/butler.txt
811008/807992 [==============================] - 0s 0us/step

dataset = tf.data.TextLineDataset(file_paths)

最初のファイルの最初の数行は次のとおりです。

for line in dataset.take(5):
  print(line.numpy())
b"\xef\xbb\xbfAchilles sing, O Goddess! Peleus' son;"
b'His wrath pernicious, who ten thousand woes'
b"Caused to Achaia's host, sent many a soul"
b'Illustrious into Ades premature,'
b'And Heroes gave (so stood the will of Jove)'

ファイル間で行をDataset.interleave使用しDataset.interleave 。これにより、ファイルをまとめてシャッフルするのが簡単になります。各翻訳の1行目、2行目、3行目は次のとおりです。

files_ds = tf.data.Dataset.from_tensor_slices(file_paths)
lines_ds = files_ds.interleave(tf.data.TextLineDataset, cycle_length=3)

for i, line in enumerate(lines_ds.take(9)):
  if i % 3 == 0:
    print()
  print(line.numpy())

b"\xef\xbb\xbfAchilles sing, O Goddess! Peleus' son;"
b"\xef\xbb\xbfOf Peleus' son, Achilles, sing, O Muse,"
b'\xef\xbb\xbfSing, O goddess, the anger of Achilles son of Peleus, that brought'

b'His wrath pernicious, who ten thousand woes'
b'The vengeance, deep and deadly; whence to Greece'
b'countless ills upon the Achaeans. Many a brave soul did it send'

b"Caused to Achaia's host, sent many a soul"
b'Unnumbered ills arose; which many a soul'
b'hurrying down to Hades, and many a hero did it yield a prey to dogs and'

デフォルトでは、 TextLineDatasetは各ファイルのすべての行を生成します。これは、たとえば、ファイルがヘッダー行で始まる場合やコメントが含まれている場合など、望ましくない場合があります。これらの行は、 Dataset.skip()またはDataset.filter()変換を使用して削除できます。ここでは、最初の行をスキップしてから、フィルターをかけて生存者のみを検索します。

titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
titanic_lines = tf.data.TextLineDataset(titanic_file)
Downloading data from https://storage.googleapis.com/tf-datasets/titanic/train.csv
32768/30874 [===============================] - 0s 0us/step

for line in titanic_lines.take(10):
  print(line.numpy())
b'survived,sex,age,n_siblings_spouses,parch,fare,class,deck,embark_town,alone'
b'0,male,22.0,1,0,7.25,Third,unknown,Southampton,n'
b'1,female,38.0,1,0,71.2833,First,C,Cherbourg,n'
b'1,female,26.0,0,0,7.925,Third,unknown,Southampton,y'
b'1,female,35.0,1,0,53.1,First,C,Southampton,n'
b'0,male,28.0,0,0,8.4583,Third,unknown,Queenstown,y'
b'0,male,2.0,3,1,21.075,Third,unknown,Southampton,n'
b'1,female,27.0,0,2,11.1333,Third,unknown,Southampton,n'
b'1,female,14.0,1,0,30.0708,Second,unknown,Cherbourg,n'
b'1,female,4.0,1,1,16.7,Third,G,Southampton,n'

def survived(line):
  return tf.not_equal(tf.strings.substr(line, 0, 1), "0")

survivors = titanic_lines.skip(1).filter(survived)
for line in survivors.take(10):
  print(line.numpy())
b'1,female,38.0,1,0,71.2833,First,C,Cherbourg,n'
b'1,female,26.0,0,0,7.925,Third,unknown,Southampton,y'
b'1,female,35.0,1,0,53.1,First,C,Southampton,n'
b'1,female,27.0,0,2,11.1333,Third,unknown,Southampton,n'
b'1,female,14.0,1,0,30.0708,Second,unknown,Cherbourg,n'
b'1,female,4.0,1,1,16.7,Third,G,Southampton,n'
b'1,male,28.0,0,0,13.0,Second,unknown,Southampton,y'
b'1,female,28.0,0,0,7.225,Third,unknown,Cherbourg,y'
b'1,male,28.0,0,0,35.5,First,A,Southampton,y'
b'1,female,38.0,1,5,31.3875,Third,unknown,Southampton,n'

CSVデータの消費

その他の例については、 CSVファイルの読み込みPandasDataFrameの読み込みを参照してください。

CSVファイル形式は、表形式のデータをプレーンテキストで保存するための一般的な形式です。

例えば:

titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
df = pd.read_csv(titanic_file, index_col=None)
df.head()

データがメモリに収まる場合、同じDataset.from_tensor_slicesメソッドが辞書で機能し、このデータを簡単にインポートできるようになります。

titanic_slices = tf.data.Dataset.from_tensor_slices(dict(df))

for feature_batch in titanic_slices.take(1):
  for key, value in feature_batch.items():
    print("  {!r:20s}: {}".format(key, value))
  'survived'          : 0
  'sex'               : b'male'
  'age'               : 22.0
  'n_siblings_spouses': 1
  'parch'             : 0
  'fare'              : 7.25
  'class'             : b'Third'
  'deck'              : b'unknown'
  'embark_town'       : b'Southampton'
  'alone'             : b'n'

よりスケーラブルなアプローチは、必要に応じてディスクからロードすることです。

tf.dataモジュールは、 tf.data準拠する1つ以上のCSVファイルからレコードを抽出するメソッドを提供します

experimental.make_csv_dataset関数は、csvファイルのセットを読み取るための高レベルのインターフェイスです。カラムタイプの推論や、バッチ処理やシャッフルなどの他の多くの機能をサポートして、使用法を簡単にします。

titanic_batches = tf.data.experimental.make_csv_dataset(
    titanic_file, batch_size=4,
    label_name="survived")
for feature_batch, label_batch in titanic_batches.take(1):
  print("'survived': {}".format(label_batch))
  print("features:")
  for key, value in feature_batch.items():
    print("  {!r:20s}: {}".format(key, value))
'survived': [0 0 0 0]
features:
  'sex'               : [b'male' b'male' b'male' b'male']
  'age'               : [18. 19. 31. 34.]
  'n_siblings_spouses': [0 0 0 1]
  'parch'             : [0 0 0 0]
  'fare'              : [ 8.3    8.05   7.775 26.   ]
  'class'             : [b'Third' b'Third' b'Third' b'Second']
  'deck'              : [b'unknown' b'unknown' b'unknown' b'unknown']
  'embark_town'       : [b'Southampton' b'Southampton' b'Southampton' b'Southampton']
  'alone'             : [b'y' b'y' b'y' b'n']

列のサブセットのみが必要な場合は、 select_columns引数を使用できます。

titanic_batches = tf.data.experimental.make_csv_dataset(
    titanic_file, batch_size=4,
    label_name="survived", select_columns=['class', 'fare', 'survived'])
for feature_batch, label_batch in titanic_batches.take(1):
  print("'survived': {}".format(label_batch))
  for key, value in feature_batch.items():
    print("  {!r:20s}: {}".format(key, value))
'survived': [1 0 0 1]
  'fare'              : [ 20.25     9.     110.8833  59.4   ]
  'class'             : [b'Third' b'Third' b'First' b'First']

よりきめ細かい制御を提供する低レベルのexperimental.CsvDatasetクラスもあります。列タイプの推論はサポートしていません。代わりに、各列のタイプを指定する必要があります。

titanic_types  = [tf.int32, tf.string, tf.float32, tf.int32, tf.int32, tf.float32, tf.string, tf.string, tf.string, tf.string] 
dataset = tf.data.experimental.CsvDataset(titanic_file, titanic_types , header=True)

for line in dataset.take(10):
  print([item.numpy() for item in line])
[0, b'male', 22.0, 1, 0, 7.25, b'Third', b'unknown', b'Southampton', b'n']
[1, b'female', 38.0, 1, 0, 71.2833, b'First', b'C', b'Cherbourg', b'n']
[1, b'female', 26.0, 0, 0, 7.925, b'Third', b'unknown', b'Southampton', b'y']
[1, b'female', 35.0, 1, 0, 53.1, b'First', b'C', b'Southampton', b'n']
[0, b'male', 28.0, 0, 0, 8.4583, b'Third', b'unknown', b'Queenstown', b'y']
[0, b'male', 2.0, 3, 1, 21.075, b'Third', b'unknown', b'Southampton', b'n']
[1, b'female', 27.0, 0, 2, 11.1333, b'Third', b'unknown', b'Southampton', b'n']
[1, b'female', 14.0, 1, 0, 30.0708, b'Second', b'unknown', b'Cherbourg', b'n']
[1, b'female', 4.0, 1, 1, 16.7, b'Third', b'G', b'Southampton', b'n']
[0, b'male', 20.0, 0, 0, 8.05, b'Third', b'unknown', b'Southampton', b'y']

一部の列が空の場合、この低レベルのインターフェースを使用すると、列タイプの代わりにデフォルト値を提供できます。

%%writefile missing.csv
1,2,3,4
,2,3,4
1,,3,4
1,2,,4
1,2,3,
,,,
Writing missing.csv

# Creates a dataset that reads all of the records from two CSV files, each with
# four float columns which may have missing values.

record_defaults = [999,999,999,999]
dataset = tf.data.experimental.CsvDataset("missing.csv", record_defaults)
dataset = dataset.map(lambda *items: tf.stack(items))
dataset
<MapDataset shapes: (4,), types: tf.int32>
for line in dataset:
  print(line.numpy())
[1 2 3 4]
[999   2   3   4]
[  1 999   3   4]
[  1   2 999   4]
[  1   2   3 999]
[999 999 999 999]

デフォルトでは、 CsvDatasetはファイルのすべての行のすべての列を生成します。これは、たとえば、ファイルが無視する必要のあるヘッダー行で始まる場合や、入力に一部の列が不要な場合など、望ましくない場合があります。これらのラインおよびフィールドを用いて除去することができheaderselect_colsそれぞれ引数。

# Creates a dataset that reads all of the records from two CSV files with
# headers, extracting float data from columns 2 and 4.
record_defaults = [999, 999] # Only provide defaults for the selected columns
dataset = tf.data.experimental.CsvDataset("missing.csv", record_defaults, select_cols=[1, 3])
dataset = dataset.map(lambda *items: tf.stack(items))
dataset
<MapDataset shapes: (2,), types: tf.int32>
for line in dataset:
  print(line.numpy())
[2 4]
[2 4]
[999   4]
[2 4]
[  2 999]
[999 999]

ファイルのセットを消費する

ファイルのセットとして配布される多くのデータセットがあり、各ファイルが例です。

flowers_root = tf.keras.utils.get_file(
    'flower_photos',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True)
flowers_root = pathlib.Path(flowers_root)

ルートディレクトリには、各クラスのディレクトリが含まれています。

for item in flowers_root.glob("*"):
  print(item.name)
sunflowers
daisy
LICENSE.txt
roses
tulips
dandelion

各クラスディレクトリのファイルは例です。

list_ds = tf.data.Dataset.list_files(str(flowers_root/'*/*'))

for f in list_ds.take(5):
  print(f.numpy())
b'/home/kbuilder/.keras/datasets/flower_photos/sunflowers/4933229095_f7e4218b28.jpg'
b'/home/kbuilder/.keras/datasets/flower_photos/dandelion/18282528206_7fb3166041.jpg'
b'/home/kbuilder/.keras/datasets/flower_photos/daisy/3711723108_65247a3170.jpg'
b'/home/kbuilder/.keras/datasets/flower_photos/sunflowers/4019748730_ee09b39a43.jpg'
b'/home/kbuilder/.keras/datasets/flower_photos/roses/475936554_a2b38aaa8e.jpg'

tf.io.read_file関数を使用してデータを読み取り、パスからラベルを抽出して、 (image, label)ペアを返します。

def process_path(file_path):
  label = tf.strings.split(file_path, os.sep)[-2]
  return tf.io.read_file(file_path), label

labeled_ds = list_ds.map(process_path)
for image_raw, label_text in labeled_ds.take(1):
  print(repr(image_raw.numpy()[:100]))
  print()
  print(label_text.numpy())
b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00\xff\xe2\x0cXICC_PROFILE\x00\x01\x01\x00\x00\x0cHLino\x02\x10\x00\x00mntrRGB XYZ \x07\xce\x00\x02\x00\t\x00\x06\x001\x00\x00acspMSFT\x00\x00\x00\x00IEC sRGB\x00\x00\x00\x00\x00\x00'

b'roses'

データセット要素のバッチ処理

簡単なバッチ処理

バッチ処理の最も単純な形式は、データセットのn連続する要素を1つの要素にスタックします。 Dataset.batch()変換は、要素の各コンポーネントに適用されるtf.stack()演算子と同じ制約で、これを正確に実行します。つまり、各コンポーネントiについて、すべての要素はまったく同じ形状のテンソルを持っている必要があります。

inc_dataset = tf.data.Dataset.range(100)
dec_dataset = tf.data.Dataset.range(0, -100, -1)
dataset = tf.data.Dataset.zip((inc_dataset, dec_dataset))
batched_dataset = dataset.batch(4)

for batch in batched_dataset.take(4):
  print([arr.numpy() for arr in batch])
[array([0, 1, 2, 3]), array([ 0, -1, -2, -3])]
[array([4, 5, 6, 7]), array([-4, -5, -6, -7])]
[array([ 8,  9, 10, 11]), array([ -8,  -9, -10, -11])]
[array([12, 13, 14, 15]), array([-12, -13, -14, -15])]

tf.dataは形状情報を伝達しようとしますが、 Dataset.batchのデフォルト設定では、最後のバッチがいっぱいでない可能性があるため、バッチサイズが不明になります。形状のNone注意してください。

batched_dataset
<BatchDataset shapes: ((None,), (None,)), types: (tf.int64, tf.int64)>

drop_remainder引数を使用して、その最後のバッチを無視し、完全な形状の伝播を取得します。

batched_dataset = dataset.batch(7, drop_remainder=True)
batched_dataset
<BatchDataset shapes: ((7,), (7,)), types: (tf.int64, tf.int64)>

パディングによるテンソルのバッチ処理

上記のレシピは、すべて同じサイズのテンソルで機能します。ただし、多くのモデル(シーケンスモデルなど)は、さまざまなサイズ(長さの異なるシーケンスなど)の入力データを処理します。この場合を処理するために、 Dataset.padded_batchトランスフォーメーションを使用すると、 Dataset.padded_batchできる1つ以上の次元を指定することにより、異なる形状のテンソルをバッチ処理できます。

dataset = tf.data.Dataset.range(100)
dataset = dataset.map(lambda x: tf.fill([tf.cast(x, tf.int32)], x))
dataset = dataset.padded_batch(4, padded_shapes=(None,))

for batch in dataset.take(2):
  print(batch.numpy())
  print()

[[0 0 0]
 [1 0 0]
 [2 2 0]
 [3 3 3]]

[[4 4 4 4 0 0 0]
 [5 5 5 5 5 0 0]
 [6 6 6 6 6 6 0]
 [7 7 7 7 7 7 7]]


Dataset.padded_batchトランスフォーメーションを使用すると、各コンポーネントのディメンションごとに異なるパディングを設定できます。これは、可変長(上記の例ではNoneで示されます)または一定長の場合があります。デフォルトが0であるパディング値をオーバーライドすることもできます。

トレーニングワークフロー

複数のエポックを処理する

tf.data APIは、同じデータの複数のエポックを処理する2つの主な方法を提供します。

複数のエポックでデータセットを反復処理する最も簡単な方法は、 Dataset.repeat()変換を使用することです。まず、チタンデータのデータセットを作成します。

titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
titanic_lines = tf.data.TextLineDataset(titanic_file)
def plot_batch_sizes(ds):
  batch_sizes = [batch.shape[0] for batch in ds]
  plt.bar(range(len(batch_sizes)), batch_sizes)
  plt.xlabel('Batch number')
  plt.ylabel('Batch size')

引数なしでDataset.repeat()変換を適用すると、入力が無期限に繰り返されます。

Dataset.repeatトランスフォーメーションは、あるエポックの終了と次のエポックの開始を通知せずに、引数を連結します。このため、 Dataset.batch後に適用されたDataset.repeatは、エポック境界にまたがるバッチを生成します。

titanic_batches = titanic_lines.repeat(3).batch(128)
plot_batch_sizes(titanic_batches)

png

明確なエポック分離が必要な場合は、繰り返しの前にDataset.batch配置しDataset.batch

titanic_batches = titanic_lines.batch(128).repeat(3)

plot_batch_sizes(titanic_batches)

png

各エポックの最後にカスタム計算(統計の収集など)を実行する場合は、各エポックでデータセットの反復を再開するのが最も簡単です。

epochs = 3
dataset = titanic_lines.batch(128)

for epoch in range(epochs):
  for batch in dataset:
    print(batch.shape)
  print("End of epoch: ", epoch)
(128,)
(128,)
(128,)
(128,)
(116,)
End of epoch:  0
(128,)
(128,)
(128,)
(128,)
(116,)
End of epoch:  1
(128,)
(128,)
(128,)
(128,)
(116,)
End of epoch:  2

入力データをランダムにシャッフルする

Dataset.shuffle()トランスフォーメーションは、固定サイズのバッファーを維持し、そのバッファーからランダムに次の要素を均一に選択します。

効果を確認できるように、データセットにインデックスを追加します。

lines = tf.data.TextLineDataset(titanic_file)
counter = tf.data.experimental.Counter()

dataset = tf.data.Dataset.zip((counter, lines))
dataset = dataset.shuffle(buffer_size=100)
dataset = dataset.batch(20)
dataset
<BatchDataset shapes: ((None,), (None,)), types: (tf.int64, tf.string)>

buffer_sizeが100で、バッチサイズが20であるため、最初のバッチには、インデックスが120を超える要素は含まれていません。

n,line_batch = next(iter(dataset))
print(n.numpy())
[ 29  39  94  10  81  85  74   3  91  53  87  17   1  64  54 107  20  63
 116  62]

Dataset.batch同様に、 Dataset.batchに関連するDataset.repeat重要です。

Dataset.shuffleは、シャッフルバッファーが空になるまで、エポックの終了をDataset.shuffleしません。したがって、リピートの前に配置されたシャッフルは、次のエポックに移動する前に、あるエポックのすべての要素を表示します。

dataset = tf.data.Dataset.zip((counter, lines))
shuffled = dataset.shuffle(buffer_size=100).batch(10).repeat(2)

print("Here are the item ID's near the epoch boundary:\n")
for n, line_batch in shuffled.skip(60).take(5):
  print(n.numpy())
Here are the item ID's near the epoch boundary:

[469 560 614 431 557 558 618 616 571 532]
[464 403 509 589 365 610 596 428 600 605]
[539 467 163 599 624 545 552 548]
[92 30 73 51 72 71 56 83 32 46]
[ 74 106  75  78  26  21   3  54  80  50]

shuffle_repeat = [n.numpy().mean() for n, line_batch in shuffled]
plt.plot(shuffle_repeat, label="shuffle().repeat()")
plt.ylabel("Mean item ID")
plt.legend()
<matplotlib.legend.Legend at 0x7f79e43368d0>

png

しかし、シャッフルの前に繰り返すと、エポックの境界が混ざり合います。

dataset = tf.data.Dataset.zip((counter, lines))
shuffled = dataset.repeat(2).shuffle(buffer_size=100).batch(10)

print("Here are the item ID's near the epoch boundary:\n")
for n, line_batch in shuffled.skip(55).take(15):
  print(n.numpy())
Here are the item ID's near the epoch boundary:

[618 517  22  13 602 608  25 621  12 579]
[568 620 534 547 319   1  17 346  29 581]
[537 395 560  33 617 557 561 610 585 627]
[428 607  32   0 455 496 605 624   6 596]
[511 588  14 572   4  21  66 471  41  36]
[583 622  49  54 616  75  65  57  76   7]
[ 77  37 574 586  74  56  28 419  73   5]
[599  83  31  61  91 481 577  98  26  50]
[  2 625  85  23 606 102 550  20 612  15]
[ 24  63 592 615 388  39  45  51  68  97]
[ 46  34  67  55  79  94 548  47  72 436]
[127  89  30  90 126 601 360  87 104  52]
[ 59 130 106 584 118 265  48  70 597 121]
[626 105 135  10  38  53 132 136 101   9]
[512 542 109 154 108  42 114  11  58 166]

repeat_shuffle = [n.numpy().mean() for n, line_batch in shuffled]

plt.plot(shuffle_repeat, label="shuffle().repeat()")
plt.plot(repeat_shuffle, label="repeat().shuffle()")
plt.ylabel("Mean item ID")
plt.legend()
<matplotlib.legend.Legend at 0x7f79e42b60b8>

png

データの前処理

Dataset.map(f)変換は、入力データセットの各要素に特定の関数fを適用することにより、新しいデータセットを生成します。これは、関数型プログラミング言語のリスト(およびその他の構造)に一般的に適用されるmap()関数に基づいていmap() 。関数fかかりtf.Tensor入力における単一の要素を表すオブジェクトを、そして返しtf.Tensorそれが新しいデータセット内の単一の要素を表現するオブジェクト。その実装では、標準のTensorFlow操作を使用して、ある要素を別の要素に変換します。

このセクションでは、 Dataset.map()使用方法の一般的な例について説明します。

画像データのデコードとサイズ変更

実世界の画像データでニューラルネットワークをトレーニングする場合、固定サイズにバッチ処理できるように、さまざまなサイズの画像を共通のサイズに変換する必要があることがよくあります。

花のファイル名データセットを再構築します。

list_ds = tf.data.Dataset.list_files(str(flowers_root/'*/*'))

データセット要素を操作する関数を記述します。

# Reads an image from a file, decodes it into a dense tensor, and resizes it
# to a fixed shape.
def parse_image(filename):
  parts = tf.strings.split(filename, os.sep)
  label = parts[-2]

  image = tf.io.read_file(filename)
  image = tf.image.decode_jpeg(image)
  image = tf.image.convert_image_dtype(image, tf.float32)
  image = tf.image.resize(image, [128, 128])
  return image, label

それが機能することをテストします。

file_path = next(iter(list_ds))
image, label = parse_image(file_path)

def show(image, label):
  plt.figure()
  plt.imshow(image)
  plt.title(label.numpy().decode('utf-8'))
  plt.axis('off')

show(image, label)

png

データセットにマッピングします。

images_ds = list_ds.map(parse_image)

for image, label in images_ds.take(2):
  show(image, label)

png

png

任意のPythonロジックを適用する

パフォーマンス上の理由から、可能な限りデータを前処理するためにTensorFlow操作を使用してください。ただし、入力データを解析するときに外部Pythonライブラリを呼び出すと便利な場合があります。 Dataset.map()トランスフォーメーションでtf.py_function()操作を使用できます。

たとえば、ランダムな回転を適用する場合、 tf.imageモジュールにはtf.image.rot90しかありません。これは、画像の拡張にはあまり役立ちません。

tf.py_functionを示すには、代わりにscipy.ndimage.rotate関数を使用してみてください。

import scipy.ndimage as ndimage

def random_rotate_image(image):
  image = ndimage.rotate(image, np.random.uniform(-30, 30), reshape=False)
  return image
image, label = next(iter(images_ds))
image = random_rotate_image(image)
show(image, label)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

png

で、この機能を使用するにはDataset.map同じ警告を持つように適用Dataset.from_generator 、あなたが関数を適用する場合、リターン形状や種類を記述する必要があります。

def tf_random_rotate_image(image, label):
  im_shape = image.shape
  [image,] = tf.py_function(random_rotate_image, [image], [tf.float32])
  image.set_shape(im_shape)
  return image, label
rot_ds = images_ds.map(tf_random_rotate_image)

for image, label in rot_ds.take(2):
  show(image, label)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

png

png

tf.Exampleプロトコルバッファメッセージの解析

多くの入力パイプラインは、TFRecord形式からtf.train.Exampleプロトコルバッファメッセージを抽出します。各tf.train.Exampleレコードには、1つ以上の「機能」が含まれており、入力パイプラインは通常、これらの機能をテンソルに変換します。

fsns_test_file = tf.keras.utils.get_file("fsns.tfrec", "https://storage.googleapis.com/download.tensorflow.org/data/fsns-20160927/testdata/fsns-00000-of-00001")
dataset = tf.data.TFRecordDataset(filenames = [fsns_test_file])
dataset
<TFRecordDatasetV2 shapes: (), types: tf.string>

tf.train.Example外部でtf.data.Datasettf.data.Datasetして、データを理解できます。

raw_example = next(iter(dataset))
parsed = tf.train.Example.FromString(raw_example.numpy())

feature = parsed.features.feature
raw_img = feature['image/encoded'].bytes_list.value[0]
img = tf.image.decode_png(raw_img)
plt.imshow(img)
plt.axis('off')
_ = plt.title(feature["image/text"].bytes_list.value[0])

png

raw_example = next(iter(dataset))
def tf_parse(eg):
  example = tf.io.parse_example(
      eg[tf.newaxis], {
          'image/encoded': tf.io.FixedLenFeature(shape=(), dtype=tf.string),
          'image/text': tf.io.FixedLenFeature(shape=(), dtype=tf.string)
      })
  return example['image/encoded'][0], example['image/text'][0]
img, txt = tf_parse(raw_example)
print(txt.numpy())
print(repr(img.numpy()[:20]), "...")
b'Rue Perreyon'
b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x02X' ...

decoded = dataset.map(tf_parse)
decoded
<MapDataset shapes: ((), ()), types: (tf.string, tf.string)>
image_batch, text_batch = next(iter(decoded.batch(10)))
image_batch.shape
TensorShape([10])

時系列ウィンドウ処理

エンドツーエンドの時系列の例については、時系列予測を参照してください。

時系列データは、多くの場合、時間軸をそのままにして編成されます。

簡単なDataset.rangeを使用して、 Dataset.rangeことを示します。

range_ds = tf.data.Dataset.range(100000)

通常、この種のデータに基づくモデルでは、連続したタイムスライスが必要になります。

最も簡単なアプローチは、データをバッチ処理することです。

batch使用

batches = range_ds.batch(10, drop_remainder=True)

for batch in batches.take(5):
  print(batch.numpy())
[0 1 2 3 4 5 6 7 8 9]
[10 11 12 13 14 15 16 17 18 19]
[20 21 22 23 24 25 26 27 28 29]
[30 31 32 33 34 35 36 37 38 39]
[40 41 42 43 44 45 46 47 48 49]

または、将来の1つのステップで密な予測を行うには、フィーチャとラベルを相互に1ステップずつシフトします。

def dense_1_step(batch):
  # Shift features and labels one step relative to each other.
  return batch[:-1], batch[1:]

predict_dense_1_step = batches.map(dense_1_step)

for features, label in predict_dense_1_step.take(3):
  print(features.numpy(), " => ", label.numpy())
[0 1 2 3 4 5 6 7 8]  =>  [1 2 3 4 5 6 7 8 9]
[10 11 12 13 14 15 16 17 18]  =>  [11 12 13 14 15 16 17 18 19]
[20 21 22 23 24 25 26 27 28]  =>  [21 22 23 24 25 26 27 28 29]

固定オフセットではなくウィンドウ全体を予測するには、バッチを2つの部分に分割します。

batches = range_ds.batch(15, drop_remainder=True)

def label_next_5_steps(batch):
  return (batch[:-5],   # Take the first 5 steps
          batch[-5:])   # take the remainder

predict_5_steps = batches.map(label_next_5_steps)

for features, label in predict_5_steps.take(3):
  print(features.numpy(), " => ", label.numpy())
[0 1 2 3 4 5 6 7 8 9]  =>  [10 11 12 13 14]
[15 16 17 18 19 20 21 22 23 24]  =>  [25 26 27 28 29]
[30 31 32 33 34 35 36 37 38 39]  =>  [40 41 42 43 44]

あるバッチの機能と別のバッチのラベル​​の重複を許可するには、 Dataset.zip使用しDataset.zip

feature_length = 10
label_length = 5

features = range_ds.batch(feature_length, drop_remainder=True)
labels = range_ds.batch(feature_length).skip(1).map(lambda labels: labels[:-5])

predict_5_steps = tf.data.Dataset.zip((features, labels))

for features, label in predict_5_steps.take(3):
  print(features.numpy(), " => ", label.numpy())
[0 1 2 3 4 5 6 7 8 9]  =>  [10 11 12 13 14]
[10 11 12 13 14 15 16 17 18 19]  =>  [20 21 22 23 24]
[20 21 22 23 24 25 26 27 28 29]  =>  [30 31 32 33 34]

windowを使用する

Dataset.batch使用はDataset.batchますが、より細かい制御が必要になる場合があります。 Dataset.windowメソッドは完全な制御を提供しますが、注意が必要です。 DatasetDatasets Datasetを返します。詳細については、データセットの構造を参照してください。

window_size = 5

windows = range_ds.window(window_size, shift=1)
for sub_ds in windows.take(5):
  print(sub_ds)
<_VariantDataset shapes: (), types: tf.int64>
<_VariantDataset shapes: (), types: tf.int64>
<_VariantDataset shapes: (), types: tf.int64>
<_VariantDataset shapes: (), types: tf.int64>
<_VariantDataset shapes: (), types: tf.int64>

Dataset.flat_mapメソッドは、データセットのデータセットを取得し、それを単一のデータセットにフラット化できます。

 for x in windows.flat_map(lambda x: x).take(30):
   print(x.numpy(), end=' ')
WARNING:tensorflow:AutoGraph could not transform <function <lambda> at 0x7f79e44309d8> and will run it as-is.
Cause: could not parse the source code:

for x in windows.flat_map(lambda x: x).take(30):

This error may be avoided by creating the lambda in a standalone statement.

To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
WARNING: AutoGraph could not transform <function <lambda> at 0x7f79e44309d8> and will run it as-is.
Cause: could not parse the source code:

for x in windows.flat_map(lambda x: x).take(30):

This error may be avoided by creating the lambda in a standalone statement.

To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
0 1 2 3 4 1 2 3 4 5 2 3 4 5 6 3 4 5 6 7 4 5 6 7 8 5 6 7 8 9 

ほとんどすべての場合、最初にデータセットを.batchする必要があります。

def sub_to_batch(sub):
  return sub.batch(window_size, drop_remainder=True)

for example in windows.flat_map(sub_to_batch).take(5):
  print(example.numpy())
[0 1 2 3 4]
[1 2 3 4 5]
[2 3 4 5 6]
[3 4 5 6 7]
[4 5 6 7 8]

これで、 shift引数が各ウィンドウの移動量を制御していることがわかります。

これをまとめると、次の関数を作成できます。

def make_window_dataset(ds, window_size=5, shift=1, stride=1):
  windows = ds.window(window_size, shift=shift, stride=stride)

  def sub_to_batch(sub):
    return sub.batch(window_size, drop_remainder=True)

  windows = windows.flat_map(sub_to_batch)
  return windows

ds = make_window_dataset(range_ds, window_size=10, shift = 5, stride=3)

for example in ds.take(10):
  print(example.numpy())
[ 0  3  6  9 12 15 18 21 24 27]
[ 5  8 11 14 17 20 23 26 29 32]
[10 13 16 19 22 25 28 31 34 37]
[15 18 21 24 27 30 33 36 39 42]
[20 23 26 29 32 35 38 41 44 47]
[25 28 31 34 37 40 43 46 49 52]
[30 33 36 39 42 45 48 51 54 57]
[35 38 41 44 47 50 53 56 59 62]
[40 43 46 49 52 55 58 61 64 67]
[45 48 51 54 57 60 63 66 69 72]

次に、前と同じように、ラベルを簡単に抽出できます。

dense_labels_ds = ds.map(dense_1_step)

for inputs,labels in dense_labels_ds.take(3):
  print(inputs.numpy(), "=>", labels.numpy())
[ 0  3  6  9 12 15 18 21 24] => [ 3  6  9 12 15 18 21 24 27]
[ 5  8 11 14 17 20 23 26 29] => [ 8 11 14 17 20 23 26 29 32]
[10 13 16 19 22 25 28 31 34] => [13 16 19 22 25 28 31 34 37]

リサンプリング

クラスが非常に不均衡なデータセットを操作する場合は、データセットをリサンプリングすることをお勧めします。 tf.dataは、これを行うための2つの方法を提供します。クレジットカード詐欺データセットは、この種の問題の良い例です。

zip_path = tf.keras.utils.get_file(
    origin='https://storage.googleapis.com/download.tensorflow.org/data/creditcard.zip',
    fname='creditcard.zip',
    extract=True)

csv_path = zip_path.replace('.zip', '.csv')
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/creditcard.zip
69156864/69155632 [==============================] - 3s 0us/step

creditcard_ds = tf.data.experimental.make_csv_dataset(
    csv_path, batch_size=1024, label_name="Class",
    # Set the column types: 30 floats and an int.
    column_defaults=[float()]*30+[int()])

ここで、クラスの分布を確認してください。非常に偏っています。

def count(counts, batch):
  features, labels = batch
  class_1 = labels == 1
  class_1 = tf.cast(class_1, tf.int32)

  class_0 = labels == 0
  class_0 = tf.cast(class_0, tf.int32)

  counts['class_0'] += tf.reduce_sum(class_0)
  counts['class_1'] += tf.reduce_sum(class_1)

  return counts
counts = creditcard_ds.take(10).reduce(
    initial_state={'class_0': 0, 'class_1': 0},
    reduce_func = count)

counts = np.array([counts['class_0'].numpy(),
                   counts['class_1'].numpy()]).astype(np.float32)

fractions = counts/counts.sum()
print(fractions)
[0.9957 0.0043]

不均衡なデータセットを使用したトレーニングの一般的なアプローチは、バランスを取ることです。 tf.dataは、このワークフローを可能にするいくつかのメソッドが含まれています。

データセットのサンプリング

データセットをリサンプリングする1つの方法は、 sample_from_datasetsを使用することsample_from_datasets 。これは、クラスごとに個別のdata.Datasetがある場合により適しています。

ここでは、フィルターを使用して、クレジットカードの不正データからそれらを生成します。

negative_ds = (
  creditcard_ds
    .unbatch()
    .filter(lambda features, label: label==0)
    .repeat())
positive_ds = (
  creditcard_ds
    .unbatch()
    .filter(lambda features, label: label==1)
    .repeat())
WARNING:tensorflow:AutoGraph could not transform <function <lambda> at 0x7f79e43ce510> and will run it as-is.
Cause: could not parse the source code:

    .filter(lambda features, label: label==0)

This error may be avoided by creating the lambda in a standalone statement.

To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
WARNING: AutoGraph could not transform <function <lambda> at 0x7f79e43ce510> and will run it as-is.
Cause: could not parse the source code:

    .filter(lambda features, label: label==0)

This error may be avoided by creating the lambda in a standalone statement.

To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
WARNING:tensorflow:AutoGraph could not transform <function <lambda> at 0x7f79e43ceae8> and will run it as-is.
Cause: could not parse the source code:

    .filter(lambda features, label: label==1)

This error may be avoided by creating the lambda in a standalone statement.

To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
WARNING: AutoGraph could not transform <function <lambda> at 0x7f79e43ceae8> and will run it as-is.
Cause: could not parse the source code:

    .filter(lambda features, label: label==1)

This error may be avoided by creating the lambda in a standalone statement.

To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert

for features, label in positive_ds.batch(10).take(1):
  print(label.numpy())
[1 1 1 1 1 1 1 1 1 1]

tf.data.experimental.sample_from_datasetsを使用するには、データセットとそれぞれの重みを渡します。

balanced_ds = tf.data.experimental.sample_from_datasets(
    [negative_ds, positive_ds], [0.5, 0.5]).batch(10)

これで、データセットは50/50の確率で各クラスの例を生成します。

for features, labels in balanced_ds.take(10):
  print(labels.numpy())
[0 1 1 1 1 0 0 1 0 1]
[0 0 0 0 1 1 1 0 0 1]
[1 0 1 0 1 1 0 0 1 0]
[0 0 0 0 1 0 1 1 0 0]
[0 1 1 1 0 1 1 1 1 1]
[0 1 0 1 1 0 1 1 0 1]
[0 0 1 1 1 0 1 1 1 1]
[1 1 0 0 1 1 0 1 1 0]
[0 1 0 0 1 0 0 0 0 0]
[0 1 1 1 0 1 0 1 1 0]

拒否のリサンプリング

上記のexperimental.sample_from_datasetsアプローチの問題の1つは、クラスごとに個別のtf.data.Datasetが必要なことです。 Dataset.filter使用はDataset.filterますが、すべてのデータが2回読み込まれます。

data.experimental.rejection_resample関数をデータセットに適用して、データセットを1回だけロードしながらリバランスすることができます。バランスをとるために、要素はデータセットから削除されます。

data.experimental.rejection_resampleclass_func引数を取ります。このclass_funcは各データセット要素に適用され、バランスをとるために例がどのクラスに属するかを決定するために使用されます。

creditcard_dsの要素は、すでに(features, label)ペアです。したがって、 class_funcこれらのラベルを返す必要があります。

def class_func(features, label):
  return label

リサンプラーには、ターゲット分布と、オプションで初期分布推定も必要です。

resampler = tf.data.experimental.rejection_resample(
    class_func, target_dist=[0.5, 0.5], initial_dist=fractions)

リサンプラーは個々の例を扱うため、リサンプラーを適用する前にデータセットのunbatchする必要があります。

resample_ds = creditcard_ds.unbatch().apply(resampler).batch(10)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/data/experimental/ops/resampling.py:156: Print (from tensorflow.python.ops.logging_ops) is deprecated and will be removed after 2018-08-20.
Instructions for updating:
Use tf.print instead of tf.Print. Note that tf.print returns a no-output operator that directly prints the output. Outside of defuns or eager mode, this operator will not be executed unless it is directly specified in session.run or used as a control dependency for other operators. This is only a concern in graph mode. Below is an example of how to ensure tf.print executes in graph mode:


リサンプラーは、 class_func出力からcreates (class, example)ペアを返します。この場合、 exampleはすでに(feature, label)ペアであったため、 mapを使用してラベルの余分なコピーを削除します。

balanced_ds = resample_ds.map(lambda extra_label, features_and_label: features_and_label)

これで、データセットは50/50の確率で各クラスの例を生成します。

for features, labels in balanced_ds.take(10):
  print(labels.numpy())
[1 0 1 1 0 0 0 0 0 0]
[1 0 1 0 1 1 0 1 0 1]
[0 1 1 0 0 1 0 0 0 0]
[1 0 1 0 0 1 0 1 0 0]
[1 1 1 1 0 0 0 0 0 1]
[1 0 1 1 1 0 1 0 1 1]
[0 0 1 1 1 1 0 0 1 1]
[0 0 1 0 0 1 1 1 0 1]
[1 0 0 0 1 0 0 0 0 0]
[0 0 0 0 1 0 0 0 1 1]

イテレータチェックポイント

Tensorflowはチェックポイントの取得をサポートしているため、トレーニングプロセスが再開すると、最新のチェックポイントを復元して、進行状況のほとんどを回復できます。モデル変数のチェックポイントに加えて、データセットイテレーターの進行状況をチェックポイントすることもできます。これは、データセットが大きく、再起動するたびにデータセットを最初から開始したくない場合に役立ちます。ただし、 shuffleprefetchなどの変換にはイテレータ内のバッファリング要素が必要なため、イテレータのチェックポイントが大きくなる可能性があることに注意してください。

イテレータをチェックポイントに含めるには、イテレータをtf.train.Checkpointコンストラクタに渡します。

range_ds = tf.data.Dataset.range(20)

iterator = iter(range_ds)
ckpt = tf.train.Checkpoint(step=tf.Variable(0), iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, '/tmp/my_ckpt', max_to_keep=3)

print([next(iterator).numpy() for _ in range(5)])

save_path = manager.save()

print([next(iterator).numpy() for _ in range(5)])

ckpt.restore(manager.latest_checkpoint)

print([next(iterator).numpy() for _ in range(5)])
[0, 1, 2, 3, 4]
[5, 6, 7, 8, 9]
[5, 6, 7, 8, 9]

tf.kerasでtf.dataを使用する

tf.keras APIは、機械学習モデルの作成と実行の多くの側面を簡素化します。その.fit()および.evaluate()および.predict() APIは、入力としてデータセットをサポートします。簡単なデータセットとモデルの設定は次のとおりです。

train, test = tf.keras.datasets.fashion_mnist.load_data()

images, labels = train
images = images/255.0
labels = labels.astype(np.int32)
fmnist_train_ds = tf.data.Dataset.from_tensor_slices((images, labels))
fmnist_train_ds = fmnist_train_ds.shuffle(5000).batch(32)

model = tf.keras.Sequential([
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(10)
])

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), 
              metrics=['accuracy'])

Model.fitModel.evaluate必要なのは(feature, label)ペアのデータセットを渡すことModel.evaluateです。

model.fit(fmnist_train_ds, epochs=2)
Epoch 1/2
WARNING:tensorflow:Layer flatten is casting an input tensor from dtype float64 to the layer's dtype of float32, which is new behavior in TensorFlow 2.  The layer has dtype float32 because its dtype defaults to floatx.

If you intended to run this layer in float32, you can safely ignore this warning. If in doubt, this warning is likely only an issue if you are porting a TensorFlow 1.X model to TensorFlow 2.

To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.

1875/1875 [==============================] - 3s 2ms/step - loss: 0.5973 - accuracy: 0.7969
Epoch 2/2
1875/1875 [==============================] - 3s 2ms/step - loss: 0.4629 - accuracy: 0.8407

<tensorflow.python.keras.callbacks.History at 0x7f79e43d8710>

たとえばDataset.repeat()呼び出すことによって無限のデータセットを渡す場合は、 steps_per_epoch引数も渡す必要があります。

model.fit(fmnist_train_ds.repeat(), epochs=2, steps_per_epoch=20)
Epoch 1/2
20/20 [==============================] - 0s 2ms/step - loss: 0.4800 - accuracy: 0.8594
Epoch 2/2
20/20 [==============================] - 0s 2ms/step - loss: 0.4529 - accuracy: 0.8391

<tensorflow.python.keras.callbacks.History at 0x7f79e43d84a8>

評価のために、いくつかの評価ステップを渡すことができます。

loss, accuracy = model.evaluate(fmnist_train_ds)
print("Loss :", loss)
print("Accuracy :", accuracy)
1875/1875 [==============================] - 3s 2ms/step - loss: 0.4386 - accuracy: 0.8508
Loss : 0.4385797381401062
Accuracy : 0.8507500290870667

長いデータセットの場合、評価するステップ数を設定します。

loss, accuracy = model.evaluate(fmnist_train_ds.repeat(), steps=10)
print("Loss :", loss)
print("Accuracy :", accuracy)
10/10 [==============================] - 0s 2ms/step - loss: 0.3737 - accuracy: 0.8656
Loss : 0.37370163202285767
Accuracy : 0.8656250238418579

Model.predict呼び出す場合、ラベルは必要ありません。

predict_ds = tf.data.Dataset.from_tensor_slices(images).batch(32)
result = model.predict(predict_ds, steps = 10)
print(result.shape)
(320, 10)

ただし、ラベルを含むデータセットを渡すと、ラベルは無視されます。

result = model.predict(fmnist_train_ds, steps = 10)
print(result.shape)
(320, 10)