このページは Cloud Translation API によって翻訳されました。
Switch to English

tf.data:TensorFlow入力パイプラインを構築する

TensorFlow.orgで見る Google Colabで実行 GitHubでソースを表示 ノートブックをダウンロード

tf.data APIを使用すると、単純で再利用可能な部分から複雑な入力パイプラインを構築できます。たとえば、画像モデルのパイプラインは、分散ファイルシステムのファイルからデータを集約し、ランダムな摂動を各画像に適用し、ランダムに選択された画像をトレーニング用のバッチにマージします。テキストモデルのパイプラインには、生のテキストデータからシンボルを抽出し、それらをルックアップテーブルを使用して埋め込み識別子に変換し、異なる長さのシーケンスをバッチ処理することが含まれます。 tf.data APIを使用すると、大量のデータを処理し、さまざまなデータ形式から読み取り、複雑な変換を実行できます。

tf.data APIは、各要素が1つ以上のコンポーネントで構成される要素のシーケンスを表すtf.data.Dataset抽象化を導入します。たとえば、画像パイプラインでは、要素は画像とそのラベルを表す1組のテンソルコンポーネントを含む単一のトレーニングの例です。

データセットを作成するには、2つの異なる方法があります。

  • データソースは、メモリまたは1つ以上のファイルに格納されたDatasetからデータDatasetを構築します。

  • データ変換は 、1つ以上のtf.data.Datasetオブジェクトからデータセットを構築します。

 import tensorflow as tf
 
 import pathlib
import os
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

np.set_printoptions(precision=4)
 

基本的な力学

入力パイプラインを作成するには、データソースから始める必要があります。たとえば、メモリ内のDatasetからデータDatasetを構築するには、 tf.data.Dataset.from_tensors()またはtf.data.Dataset.from_tensor_slices()使用できます。または、入力データが推奨されるTFRecord形式のファイルに保存されている場合は、 tf.data.TFRecordDataset()使用できます。

Datasetオブジェクトをtf.data.Datasettf.data.Datasetオブジェクトのメソッド呼び出しをチェーンすることで、新しいDataset 変換できます。たとえば、次のような要素毎の変換を適用することができるDataset.map()およびなどの多要素変換Dataset.batch()変換の完全なリストについては、 tf.data.Datasetのドキュメントを参照してください。

DatasetオブジェクトはPythonの反復可能オブジェクトです。これにより、forループを使用してその要素を使用できます。

 dataset = tf.data.Dataset.from_tensor_slices([8, 3, 0, 8, 2, 1])
dataset
 
<TensorSliceDataset shapes: (), types: tf.int32>
 for elem in dataset:
  print(elem.numpy())
 
8
3
0
8
2
1

または、 iterを使用して明示的にPythonイテレータを作成し、 nextを使用してその要素を消費することによって:

 it = iter(dataset)

print(next(it).numpy())
 
8

または、 reduce変換を使用してデータセット要素を消費することもできます。これにより、すべての要素が削減され、単一の結果が生成されます。次の例は、 reduce変換を使用して整数のデータセットの合計を計算する方法を示しています。

 print(dataset.reduce(0, lambda state, value: state + value).numpy())
 
22

データセットの構造

データセットは、それぞれが同一の(ネストされた)構造を有し、構造体の個々の成分により、任意のタイプの表現であることができる要素含まtf.TypeSpec含む、 tf.Tensortf.sparse.SparseTensortf.RaggedTensortf.TensorArray 、またはtf.data.Dataset

Dataset.element_specプロパティを使用すると、各要素コンポーネントのタイプを検査できます。プロパティは、要素の構造と一致するtf.TypeSpecオブジェクトのネストされた構造を返します。これは、単一のコンポーネント、コンポーネントのタプル、またはコンポーネントのネストされたタプルです。例えば:

 dataset1 = tf.data.Dataset.from_tensor_slices(tf.random.uniform([4, 10]))

dataset1.element_spec
 
TensorSpec(shape=(10,), dtype=tf.float32, name=None)
 dataset2 = tf.data.Dataset.from_tensor_slices(
   (tf.random.uniform([4]),
    tf.random.uniform([4, 100], maxval=100, dtype=tf.int32)))

dataset2.element_spec
 
(TensorSpec(shape=(), dtype=tf.float32, name=None),
 TensorSpec(shape=(100,), dtype=tf.int32, name=None))
 dataset3 = tf.data.Dataset.zip((dataset1, dataset2))

dataset3.element_spec
 
(TensorSpec(shape=(10,), dtype=tf.float32, name=None),
 (TensorSpec(shape=(), dtype=tf.float32, name=None),
  TensorSpec(shape=(100,), dtype=tf.int32, name=None)))
 # Dataset containing a sparse tensor.
dataset4 = tf.data.Dataset.from_tensors(tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4]))

dataset4.element_spec
 
SparseTensorSpec(TensorShape([3, 4]), tf.int32)
 # Use value_type to see the type of value represented by the element spec
dataset4.element_spec.value_type
 
tensorflow.python.framework.sparse_tensor.SparseTensor

Dataset変換は、あらゆる構造のデータセットをサポートします。使用する場合Dataset.map()及びDataset.filter()各要素に関数を適用する変換を、素子構造は、関数の引数を決定します。

 dataset1 = tf.data.Dataset.from_tensor_slices(
    tf.random.uniform([4, 10], minval=1, maxval=10, dtype=tf.int32))

dataset1
 
<TensorSliceDataset shapes: (10,), types: tf.int32>
 for z in dataset1:
  print(z.numpy())
 
[8 1 2 6 1 7 2 6 1 3]
[6 5 6 5 3 5 2 5 3 6]
[5 8 4 8 3 1 4 6 4 8]
[2 4 5 8 3 5 7 9 4 2]

 dataset2 = tf.data.Dataset.from_tensor_slices(
   (tf.random.uniform([4]),
    tf.random.uniform([4, 100], maxval=100, dtype=tf.int32)))

dataset2
 
<TensorSliceDataset shapes: ((), (100,)), types: (tf.float32, tf.int32)>
 dataset3 = tf.data.Dataset.zip((dataset1, dataset2))

dataset3
 
<ZipDataset shapes: ((10,), ((), (100,))), types: (tf.int32, (tf.float32, tf.int32))>
 for a, (b,c) in dataset3:
  print('shapes: {a.shape}, {b.shape}, {c.shape}'.format(a=a, b=b, c=c))
 
shapes: (10,), (), (100,)
shapes: (10,), (), (100,)
shapes: (10,), (), (100,)
shapes: (10,), (), (100,)

入力データの読み取り

NumPy配列の使用

その他の例については、 NumPy配列のロードを参照してください。

すべての入力データがメモリに収まる場合、それらからDatasetを作成する最も簡単な方法は、それらをtf.Tensorオブジェクトに変換し、 Dataset.from_tensor_slices()を使用することDataset.from_tensor_slices()

 train, test = tf.keras.datasets.fashion_mnist.load_data()
 
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
32768/29515 [=================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
26427392/26421880 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
8192/5148 [===============================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz
4423680/4422102 [==============================] - 0s 0us/step

 images, labels = train
images = images/255

dataset = tf.data.Dataset.from_tensor_slices((images, labels))
dataset
 
<TensorSliceDataset shapes: ((28, 28), ()), types: (tf.float64, tf.uint8)>

Pythonジェネレーターの使用

tf.data.Datasetとして簡単に取り込むことができるもう1つの一般的なデータソースは、Pythonジェネレーターです。

 def count(stop):
  i = 0
  while i<stop:
    yield i
    i += 1
 
 for n in count(5):
  print(n)
 
0
1
2
3
4

Dataset.from_generatorコンストラクターは、Pythonジェネレーターを完全に機能するtf.data.Datasetます。

コンストラクターは、イテレーターではなく呼び出し可能オブジェクトを入力として受け取ります。これにより、ジェネレータが最後に達したときにジェネレータを再起動できます。これはオプションのargs引数を取り、それは呼び出し可能オブジェクトの引数として渡されます。

output_typesための引数が必要であるtf.data構築tf.Graph内部、およびグラフのエッジは必要tf.dtype

 ds_counter = tf.data.Dataset.from_generator(count, args=[25], output_types=tf.int32, output_shapes = (), )
 
 for count_batch in ds_counter.repeat().batch(10).take(10):
  print(count_batch.numpy())
 
[0 1 2 3 4 5 6 7 8 9]
[10 11 12 13 14 15 16 17 18 19]
[20 21 22 23 24  0  1  2  3  4]
[ 5  6  7  8  9 10 11 12 13 14]
[15 16 17 18 19 20 21 22 23 24]
[0 1 2 3 4 5 6 7 8 9]
[10 11 12 13 14 15 16 17 18 19]
[20 21 22 23 24  0  1  2  3  4]
[ 5  6  7  8  9 10 11 12 13 14]
[15 16 17 18 19 20 21 22 23 24]

output_shapes引数は必須ではありませが、多くのテンソルフロー操作が不明なランクのテンソルをサポートしていないため、強くお勧めします。特定の軸の長さが不明または可変の場合は、 output_shapes Noneに設定します。

また、 output_shapesoutput_typesは、他のデータセットメソッドと同じネストルールに従うことに注意することも重要です。

これは両方の側面を示すサンプルジェネレーターです。これは配列のタプルを返します。2番目の配列は長さが不明なベクトルです。

 def gen_series():
  i = 0
  while True:
    size = np.random.randint(0, 10)
    yield i, np.random.normal(size=(size,))
    i += 1
 
 for i, series in gen_series():
  print(i, ":", str(series))
  if i > 5:
    break
 
0 : [-1.978  -1.0531  0.1959 -2.1618]
1 : [ 1.9185 -0.1874  0.5084]
2 : [0.1441 0.3987 0.7737 0.9266 1.5057 0.9151]
3 : [ 0.681  -0.6155 -0.1231 -0.2429  0.6892  1.2571 -1.7588 -1.6575 -0.5375]
4 : [-0.5567  1.5298  0.7242  0.2213]
5 : [ 1.5572 -0.6856]
6 : [-1.0965 -0.336   1.2405  0.6006]

最初の出力はint32 、2番目はfloat32です。

最初の項目はスカラーの形状()で、2番目の項目は長さ、形状が不明のベクトル(None,)

 ds_series = tf.data.Dataset.from_generator(
    gen_series, 
    output_types=(tf.int32, tf.float32), 
    output_shapes=((), (None,)))

ds_series
 
<FlatMapDataset shapes: ((), (None,)), types: (tf.int32, tf.float32)>

これで、通常のtf.data.Datasetように使用できます。可変形状のデータセットをバッチ処理する場合は、 Dataset.padded_batchを使用する必要があることにDataset.padded_batch

 ds_series_batch = ds_series.shuffle(20).padded_batch(10)

ids, sequence_batch = next(iter(ds_series_batch))
print(ids.numpy())
print()
print(sequence_batch.numpy())
 
[ 2 10 18  3  6 15 25 23  0  4]

[[ 1.2665 -0.6274  0.4076  1.0146  0.      0.      0.      0.    ]
 [ 0.8091 -0.0683 -0.1464  0.2734  0.7461 -0.1009  0.      0.    ]
 [-0.9381  1.5075  0.      0.      0.      0.      0.      0.    ]
 [ 1.5705  0.4438  0.      0.      0.      0.      0.      0.    ]
 [-0.4692 -1.8328 -2.2838  0.7418  0.0172 -0.3547 -1.4502 -1.2786]
 [-1.574   0.      0.      0.      0.      0.      0.      0.    ]
 [-0.9274  1.4758  0.      0.      0.      0.      0.      0.    ]
 [-0.5043  0.7066  0.9599 -1.2986  0.      0.      0.      0.    ]
 [ 0.      0.      0.      0.      0.      0.      0.      0.    ]
 [-0.4893 -0.6937  0.      0.      0.      0.      0.      0.    ]]

より現実的な例については、ラッピングしてみてくださいpreprocessing.image.ImageDataGenerator通りtf.data.Dataset

最初にデータをダウンロードします。

 flowers = tf.keras.utils.get_file(
    'flower_photos',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True)
 
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
228818944/228813984 [==============================] - 2s 0us/step

画像を作成しますimage.ImageDataGenerator

 img_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255, rotation_range=20)
 
 images, labels = next(img_gen.flow_from_directory(flowers))
 
Found 3670 images belonging to 5 classes.

 print(images.dtype, images.shape)
print(labels.dtype, labels.shape)
 
float32 (32, 256, 256, 3)
float32 (32, 5)

 ds = tf.data.Dataset.from_generator(
    img_gen.flow_from_directory, args=[flowers], 
    output_types=(tf.float32, tf.float32), 
    output_shapes=([32,256,256,3], [32,5])
)

ds
 
<FlatMapDataset shapes: ((32, 256, 256, 3), (32, 5)), types: (tf.float32, tf.float32)>

TFRecordデータの使用

エンドツーエンドの例については、 TFRecordのロードを参照してください。

tf.data APIはさまざまなファイル形式をサポートしているため、メモリに収まらない大きなデータセットを処理できます。たとえば、TFRecordファイル形式は、多くのTensorFlowアプリケーションがデータのトレーニングに使用する単純なレコード指向のバイナリ形式です。 tf.data.TFRecordDatasetクラスを使用すると、入力パイプラインの一部として1つ以上のTFRecordファイルのコンテンツをストリーミングできます。

以下は、French Street Name Signs(FSNS)のテストファイルを使用した例です。

 # Creates a dataset that reads all of the examples from two files.
fsns_test_file = tf.keras.utils.get_file("fsns.tfrec", "https://storage.googleapis.com/download.tensorflow.org/data/fsns-20160927/testdata/fsns-00000-of-00001")
 
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/fsns-20160927/testdata/fsns-00000-of-00001
7905280/7904079 [==============================] - 0s 0us/step

filenames引数TFRecordDatasetいずれかの文字列、文字列のリスト、またはすることができイニシャライザtf.Tensor文字列の。したがって、トレーニングと検証のために2つのファイルセットがある場合は、ファイル名を入力引数として、データセットを生成するファクトリメソッドを作成できます。

 dataset = tf.data.TFRecordDataset(filenames = [fsns_test_file])
dataset
 
<TFRecordDatasetV2 shapes: (), types: tf.string>

多くのTensorFlowプロジェクトは、TFRecordファイルでシリアル化されたtf.train.Exampleレコードを使用します。これらは、検査する前にデコードする必要があります。

 raw_example = next(iter(dataset))
parsed = tf.train.Example.FromString(raw_example.numpy())

parsed.features.feature['image/text']
 
bytes_list {
  value: "Rue Perreyon"
}

テキストデータの使用

エンドツーエンドの例については、 テキストのロードを参照してください。

多くのデータセットは、1つ以上のテキストファイルとして配布されます。 tf.data.TextLineDatasetは、1つ以上のテキストファイルから行を抽出する簡単な方法を提供します。 1つ以上のファイル名を指定すると、 TextLineDatasetはそれらのファイルの行ごとに1つの文字列値要素を生成します。

 directory_url = 'https://storage.googleapis.com/download.tensorflow.org/data/illiad/'
file_names = ['cowper.txt', 'derby.txt', 'butler.txt']

file_paths = [
    tf.keras.utils.get_file(file_name, directory_url + file_name)
    for file_name in file_names
]
 
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/illiad/cowper.txt
819200/815980 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/illiad/derby.txt
811008/809730 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/illiad/butler.txt
811008/807992 [==============================] - 0s 0us/step

 dataset = tf.data.TextLineDataset(file_paths)
 

最初のファイルの最初の数行は次のとおりです。

 for line in dataset.take(5):
  print(line.numpy())
 
b"\xef\xbb\xbfAchilles sing, O Goddess! Peleus' son;"
b'His wrath pernicious, who ten thousand woes'
b"Caused to Achaia's host, sent many a soul"
b'Illustrious into Ades premature,'
b'And Heroes gave (so stood the will of Jove)'

ファイル間で行をDataset.interleave使用しDataset.interleave 。これにより、ファイルのシャッフルが簡単になります。以下は、各翻訳の1行目、2行目、3行目です。

 files_ds = tf.data.Dataset.from_tensor_slices(file_paths)
lines_ds = files_ds.interleave(tf.data.TextLineDataset, cycle_length=3)

for i, line in enumerate(lines_ds.take(9)):
  if i % 3 == 0:
    print()
  print(line.numpy())
 

b"\xef\xbb\xbfAchilles sing, O Goddess! Peleus' son;"
b"\xef\xbb\xbfOf Peleus' son, Achilles, sing, O Muse,"
b'\xef\xbb\xbfSing, O goddess, the anger of Achilles son of Peleus, that brought'

b'His wrath pernicious, who ten thousand woes'
b'The vengeance, deep and deadly; whence to Greece'
b'countless ills upon the Achaeans. Many a brave soul did it send'

b"Caused to Achaia's host, sent many a soul"
b'Unnumbered ills arose; which many a soul'
b'hurrying down to Hades, and many a hero did it yield a prey to dogs and'

デフォルトでは、 TextLineDatasetは各ファイルのすべての行を生成します。たとえば、ファイルがヘッダー行で始まっている場合やコメントが含まれている場合、これは望ましくない場合があります。これらの行は、 Dataset.skip()またはDataset.filter()変換を使用して削除できます。ここでは、最初の行をスキップしてから、生存者のみを見つけるようにフィルタリングします。

 titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
titanic_lines = tf.data.TextLineDataset(titanic_file)
 
Downloading data from https://storage.googleapis.com/tf-datasets/titanic/train.csv
32768/30874 [===============================] - 0s 0us/step

 for line in titanic_lines.take(10):
  print(line.numpy())
 
b'survived,sex,age,n_siblings_spouses,parch,fare,class,deck,embark_town,alone'
b'0,male,22.0,1,0,7.25,Third,unknown,Southampton,n'
b'1,female,38.0,1,0,71.2833,First,C,Cherbourg,n'
b'1,female,26.0,0,0,7.925,Third,unknown,Southampton,y'
b'1,female,35.0,1,0,53.1,First,C,Southampton,n'
b'0,male,28.0,0,0,8.4583,Third,unknown,Queenstown,y'
b'0,male,2.0,3,1,21.075,Third,unknown,Southampton,n'
b'1,female,27.0,0,2,11.1333,Third,unknown,Southampton,n'
b'1,female,14.0,1,0,30.0708,Second,unknown,Cherbourg,n'
b'1,female,4.0,1,1,16.7,Third,G,Southampton,n'

 def survived(line):
  return tf.not_equal(tf.strings.substr(line, 0, 1), "0")

survivors = titanic_lines.skip(1).filter(survived)
 
 for line in survivors.take(10):
  print(line.numpy())
 
b'1,female,38.0,1,0,71.2833,First,C,Cherbourg,n'
b'1,female,26.0,0,0,7.925,Third,unknown,Southampton,y'
b'1,female,35.0,1,0,53.1,First,C,Southampton,n'
b'1,female,27.0,0,2,11.1333,Third,unknown,Southampton,n'
b'1,female,14.0,1,0,30.0708,Second,unknown,Cherbourg,n'
b'1,female,4.0,1,1,16.7,Third,G,Southampton,n'
b'1,male,28.0,0,0,13.0,Second,unknown,Southampton,y'
b'1,female,28.0,0,0,7.225,Third,unknown,Cherbourg,y'
b'1,male,28.0,0,0,35.5,First,A,Southampton,y'
b'1,female,38.0,1,5,31.3875,Third,unknown,Southampton,n'

CSVデータの使用

その他の例については、 CSVファイルのロードおよびPandas DataFramesのロードを参照してください。

CSVファイル形式は、表形式のデータをプレーンテキストで格納するための一般的な形式です。

例えば:

 titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
 
 df = pd.read_csv(titanic_file, index_col=None)
df.head()
 

データがメモリに収まる場合、同じDataset.from_tensor_slicesメソッドが辞書で機能するため、このデータを簡単にインポートできます。

 titanic_slices = tf.data.Dataset.from_tensor_slices(dict(df))

for feature_batch in titanic_slices.take(1):
  for key, value in feature_batch.items():
    print("  {!r:20s}: {}".format(key, value))
 
  'survived'          : 0
  'sex'               : b'male'
  'age'               : 22.0
  'n_siblings_spouses': 1
  'parch'             : 0
  'fare'              : 7.25
  'class'             : b'Third'
  'deck'              : b'unknown'
  'embark_town'       : b'Southampton'
  'alone'             : b'n'

よりスケーラブルなアプローチは、必要に応じてディスクからロードすることです。

tf.dataモジュールは、 RFC 4180に準拠する1つ以上のCSVファイルからレコードを抽出するメソッドを提供します

experimental.make_csv_dataset関数は、csvファイルのセットを読み取るための高レベルのインターフェイスです。列タイプの推論や、バッチ処理やシャッフルなど、他の多くの機能をサポートしており、使用が簡単です。

 titanic_batches = tf.data.experimental.make_csv_dataset(
    titanic_file, batch_size=4,
    label_name="survived")
 
 for feature_batch, label_batch in titanic_batches.take(1):
  print("'survived': {}".format(label_batch))
  print("features:")
  for key, value in feature_batch.items():
    print("  {!r:20s}: {}".format(key, value))
 
'survived': [0 1 0 0]
features:
  'sex'               : [b'male' b'female' b'male' b'male']
  'age'               : [28. 42. 43. 21.]
  'n_siblings_spouses': [0 0 0 0]
  'parch'             : [0 0 0 0]
  'fare'              : [47.1    13.      8.05    8.6625]
  'class'             : [b'First' b'Second' b'Third' b'Third']
  'deck'              : [b'unknown' b'unknown' b'unknown' b'unknown']
  'embark_town'       : [b'Southampton' b'Southampton' b'Southampton' b'Southampton']
  'alone'             : [b'y' b'y' b'y' b'y']

列のサブセットのみが必要な場合は、 select_columns引数を使用できます。

 titanic_batches = tf.data.experimental.make_csv_dataset(
    titanic_file, batch_size=4,
    label_name="survived", select_columns=['class', 'fare', 'survived'])
 
 for feature_batch, label_batch in titanic_batches.take(1):
  print("'survived': {}".format(label_batch))
  for key, value in feature_batch.items():
    print("  {!r:20s}: {}".format(key, value))
 
'survived': [1 0 1 1]
  'fare'              : [24.15    0.     13.8583 53.1   ]
  'class'             : [b'Third' b'Second' b'Second' b'First']

より細かい制御を提供する下位レベルのexperimental.CsvDatasetクラスもあります。列タイプの推論はサポートしていません。代わりに、各列のタイプを指定する必要があります。

 titanic_types  = [tf.int32, tf.string, tf.float32, tf.int32, tf.int32, tf.float32, tf.string, tf.string, tf.string, tf.string] 
dataset = tf.data.experimental.CsvDataset(titanic_file, titanic_types , header=True)

for line in dataset.take(10):
  print([item.numpy() for item in line])
 
[0, b'male', 22.0, 1, 0, 7.25, b'Third', b'unknown', b'Southampton', b'n']
[1, b'female', 38.0, 1, 0, 71.2833, b'First', b'C', b'Cherbourg', b'n']
[1, b'female', 26.0, 0, 0, 7.925, b'Third', b'unknown', b'Southampton', b'y']
[1, b'female', 35.0, 1, 0, 53.1, b'First', b'C', b'Southampton', b'n']
[0, b'male', 28.0, 0, 0, 8.4583, b'Third', b'unknown', b'Queenstown', b'y']
[0, b'male', 2.0, 3, 1, 21.075, b'Third', b'unknown', b'Southampton', b'n']
[1, b'female', 27.0, 0, 2, 11.1333, b'Third', b'unknown', b'Southampton', b'n']
[1, b'female', 14.0, 1, 0, 30.0708, b'Second', b'unknown', b'Cherbourg', b'n']
[1, b'female', 4.0, 1, 1, 16.7, b'Third', b'G', b'Southampton', b'n']
[0, b'male', 20.0, 0, 0, 8.05, b'Third', b'unknown', b'Southampton', b'y']

一部の列が空の場合、この低レベルのインターフェースを使用すると、列タイプの代わりにデフォルト値を提供できます。

 %%writefile missing.csv
1,2,3,4
,2,3,4
1,,3,4
1,2,,4
1,2,3,
,,,
 
Writing missing.csv

 # Creates a dataset that reads all of the records from two CSV files, each with
# four float columns which may have missing values.

record_defaults = [999,999,999,999]
dataset = tf.data.experimental.CsvDataset("missing.csv", record_defaults)
dataset = dataset.map(lambda *items: tf.stack(items))
dataset
 
<MapDataset shapes: (4,), types: tf.int32>
 for line in dataset:
  print(line.numpy())
 
[1 2 3 4]
[999   2   3   4]
[  1 999   3   4]
[  1   2 999   4]
[  1   2   3 999]
[999 999 999 999]

デフォルトでは、 CsvDatasetはファイルのすべての行のすべての列を生成します。これは、たとえば、ファイルが無視されるべきヘッダー行で始まっている場合や、一部の列が入力に不要な場合など、望ましくない場合があります。これらのラインおよびフィールドを用いて除去することができheaderselect_colsそれぞれ引数。

 # Creates a dataset that reads all of the records from two CSV files with
# headers, extracting float data from columns 2 and 4.
record_defaults = [999, 999] # Only provide defaults for the selected columns
dataset = tf.data.experimental.CsvDataset("missing.csv", record_defaults, select_cols=[1, 3])
dataset = dataset.map(lambda *items: tf.stack(items))
dataset
 
<MapDataset shapes: (2,), types: tf.int32>
 for line in dataset:
  print(line.numpy())
 
[2 4]
[2 4]
[999   4]
[2 4]
[  2 999]
[999 999]

ファイルのセットを消費する

ファイルのセットとして配布される多くのデータセットがあり、各ファイルは例です。

 flowers_root = tf.keras.utils.get_file(
    'flower_photos',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True)
flowers_root = pathlib.Path(flowers_root)

 

ルートディレクトリには、各クラスのディレクトリが含まれています。

 for item in flowers_root.glob("*"):
  print(item.name)
 
sunflowers
daisy
LICENSE.txt
roses
tulips
dandelion

各クラスディレクトリ内のファイルは例です。

 list_ds = tf.data.Dataset.list_files(str(flowers_root/'*/*'))

for f in list_ds.take(5):
  print(f.numpy())
 
b'/home/kbuilder/.keras/datasets/flower_photos/dandelion/7243478942_30bf542a2d_m.jpg'
b'/home/kbuilder/.keras/datasets/flower_photos/tulips/4525067924_177ea3bfb4.jpg'
b'/home/kbuilder/.keras/datasets/flower_photos/tulips/7002703410_3e97b29da5_n.jpg'
b'/home/kbuilder/.keras/datasets/flower_photos/daisy/6299910262_336309ffa5_n.jpg'
b'/home/kbuilder/.keras/datasets/flower_photos/sunflowers/6140661443_bb48344226.jpg'

tf.io.read_file関数を使用してデータを読み取り、パスからラベルを抽出し、 (image, label)ペアを返します。

 def process_path(file_path):
  label = tf.strings.split(file_path, os.sep)[-2]
  return tf.io.read_file(file_path), label

labeled_ds = list_ds.map(process_path)
 
 for image_raw, label_text in labeled_ds.take(1):
  print(repr(image_raw.numpy()[:100]))
  print()
  print(label_text.numpy())
 
b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x01\x00H\x00H\x00\x00\xff\xe2\x0cXICC_PROFILE\x00\x01\x01\x00\x00\x0cHLino\x02\x10\x00\x00mntrRGB XYZ \x07\xce\x00\x02\x00\t\x00\x06\x001\x00\x00acspMSFT\x00\x00\x00\x00IEC sRGB\x00\x00\x00\x00\x00\x00'

b'tulips'

データセット要素のバッチ処理

単純なバッチ処理

バッチ処理の最も単純な形式は、データセットのn連続した要素を単一の要素にスタックします。 Dataset.batch()変換は、これを正確に実行し、 tf.stack()演算子と同じ制約を要素の各コンポーネントに適用します。つまり、各コンポーネントiについて 、すべての要素はまったく同じ形状のテンソルを持つ必要があります。

 inc_dataset = tf.data.Dataset.range(100)
dec_dataset = tf.data.Dataset.range(0, -100, -1)
dataset = tf.data.Dataset.zip((inc_dataset, dec_dataset))
batched_dataset = dataset.batch(4)

for batch in batched_dataset.take(4):
  print([arr.numpy() for arr in batch])
 
[array([0, 1, 2, 3]), array([ 0, -1, -2, -3])]
[array([4, 5, 6, 7]), array([-4, -5, -6, -7])]
[array([ 8,  9, 10, 11]), array([ -8,  -9, -10, -11])]
[array([12, 13, 14, 15]), array([-12, -13, -14, -15])]

一方でtf.data試みは、形状情報を伝播するために、デフォルトの設定Dataset.batch未知のバッチサイズでの結果は、最後のバッチが満杯ではないかもしれないので。形状のNone注意してください。

 batched_dataset
 
<BatchDataset shapes: ((None,), (None,)), types: (tf.int64, tf.int64)>

drop_remainder引数を使用して最後のバッチを無視し、完全な形状の伝播を取得します。

 batched_dataset = dataset.batch(7, drop_remainder=True)
batched_dataset
 
<BatchDataset shapes: ((7,), (7,)), types: (tf.int64, tf.int64)>

パディング付きテンソルのバッチ処理

上記のレシピは、すべて同じサイズのテンソルで機能します。ただし、多くのモデル(たとえば、シーケンスモデル)は、さまざまなサイズ(たとえば、異なる長さのシーケンス)を持つ可能性のある入力データを処理します。このケースを処理するために、 Dataset.padded_batch変換を使用すると、パディングされる可能性のある1つ以上の次元を指定することにより、異なる形状のテンソルをバッチ処理できます。

 dataset = tf.data.Dataset.range(100)
dataset = dataset.map(lambda x: tf.fill([tf.cast(x, tf.int32)], x))
dataset = dataset.padded_batch(4, padded_shapes=(None,))

for batch in dataset.take(2):
  print(batch.numpy())
  print()

 
[[0 0 0]
 [1 0 0]
 [2 2 0]
 [3 3 3]]

[[4 4 4 4 0 0 0]
 [5 5 5 5 5 0 0]
 [6 6 6 6 6 6 0]
 [7 7 7 7 7 7 7]]


Dataset.padded_batch変換を使用すると、各コンポーネントの各次元に異なるパディングを設定できます。これは、可変長(上記の例ではNoneで示されます)または一定長です。パディング値をオーバーライドすることもできます。デフォルトは0です。

トレーニングワークフロー

複数のエポックを処理する

tf.data APIは、同じデータの複数のエポックを処理する2つの主要な方法を提供します。

複数のエポックでデータセットを反復処理する最も簡単な方法は、 Dataset.repeat()変換を使用することです。最初に、タイタニックデータのデータセットを作成します。

 titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
titanic_lines = tf.data.TextLineDataset(titanic_file)
 
 def plot_batch_sizes(ds):
  batch_sizes = [batch.shape[0] for batch in ds]
  plt.bar(range(len(batch_sizes)), batch_sizes)
  plt.xlabel('Batch number')
  plt.ylabel('Batch size')
 

引数なしでDataset.repeat()変換を適用すると、入力が無期限に繰り返されます。

Dataset.repeat変換は、1つのエポックの終わりと次のエポックの始まりを通知することなく、引数を連結します。このため、 Dataset.batch後に適用されるDataset.repeatは、エポック境界にまたがるバッチを生成します。

 titanic_batches = titanic_lines.repeat(3).batch(128)
plot_batch_sizes(titanic_batches)
 

png

明確なエポック分離が必要な場合は、繰り返しの前にDataset.batch配置しDataset.batch

 titanic_batches = titanic_lines.batch(128).repeat(3)

plot_batch_sizes(titanic_batches)
 

png

各エポックの最後にカスタム計算を実行する(たとえば、統計を収集する)場合は、各エポックでデータセットの反復を再開するのが最も簡単です。

 epochs = 3
dataset = titanic_lines.batch(128)

for epoch in range(epochs):
  for batch in dataset:
    print(batch.shape)
  print("End of epoch: ", epoch)
 
(128,)
(128,)
(128,)
(128,)
(116,)
End of epoch:  0
(128,)
(128,)
(128,)
(128,)
(116,)
End of epoch:  1
(128,)
(128,)
(128,)
(128,)
(116,)
End of epoch:  2

入力データをランダムにシャッフルする

Dataset.shuffle()トランスフォーメーションは、固定サイズのバッファーを維持し、そのバッファーからランダムに均一に次の要素を選択します。

効果を確認できるように、データセットにインデックスを追加します。

 lines = tf.data.TextLineDataset(titanic_file)
counter = tf.data.experimental.Counter()

dataset = tf.data.Dataset.zip((counter, lines))
dataset = dataset.shuffle(buffer_size=100)
dataset = dataset.batch(20)
dataset
 
<BatchDataset shapes: ((None,), (None,)), types: (tf.int64, tf.string)>

buffer_sizeが100で、バッチサイズが20であるため、最初のバッチには、インデックスが120を超える要素が含まれていません。

 n,line_batch = next(iter(dataset))
print(n.numpy())
 
[ 63  48 101  12 103  52   6  39   4   9  93  91   5  86  79  64  95  33
 102  50]

同様にDataset.batchへの相対的な順序Dataset.repeat事項。

Dataset.shuffleは、シャッフルバッファーが空になるまでエポックの終了をDataset.shuffleしません。したがって、繰り返しの前に配置されたシャッフルは、次のエポックに移動する前に、あるエポックのすべての要素を表示します。

 dataset = tf.data.Dataset.zip((counter, lines))
shuffled = dataset.shuffle(buffer_size=100).batch(10).repeat(2)

print("Here are the item ID's near the epoch boundary:\n")
for n, line_batch in shuffled.skip(60).take(5):
  print(n.numpy())
 
Here are the item ID's near the epoch boundary:

[613 609 624 553 608 583 493 617 611 610]
[217 508 579 601 319 616 606 549 618 623]
[416 567 404 622 283 458 503 602]
[ 87  68  56  16   6  62   1  89  58 106]
[98 80 43 10 67 44 19 34 13 57]

 shuffle_repeat = [n.numpy().mean() for n, line_batch in shuffled]
plt.plot(shuffle_repeat, label="shuffle().repeat()")
plt.ylabel("Mean item ID")
plt.legend()
 
<matplotlib.legend.Legend at 0x7fe0a00a1d68>

png

しかし、シャッフルの前に繰り返して、エポック境界を混ぜ合わせます。

 dataset = tf.data.Dataset.zip((counter, lines))
shuffled = dataset.repeat(2).shuffle(buffer_size=100).batch(10)

print("Here are the item ID's near the epoch boundary:\n")
for n, line_batch in shuffled.skip(55).take(15):
  print(n.numpy())
 
Here are the item ID's near the epoch boundary:

[440  15   8 599 567  18 550   5  19  17]
[ 12 501 571 473 466  21 531 596 580 555]
[  3 573  38 563  25 416 595  29  46 602]
[485 566 561  16 331 615 386  28 609  41]
[611 622 575  10 589  61 598 527  52  35]
[ 55 597  42  23  13  47  11 505  68 582]
[612 613  75  43   7 392  74 452  82 509]
[  9  44  62 491  71 343  51 590  60  98]
[  6  95 619  86 625 537 617  85 465   0]
[ 88  27  92 101 109 111 104  24  36 113]
[103 118  79  53  70  40 121 100  65  33]
[562 588 124 125  64  84  83  67 610 130]
[  4 142 131  90 518 129 143 112   2 551]
[377  91 140  76  50  48 526 553 156 591]
[105 128  69 114  93 520 154  56 145 115]

 repeat_shuffle = [n.numpy().mean() for n, line_batch in shuffled]

plt.plot(shuffle_repeat, label="shuffle().repeat()")
plt.plot(repeat_shuffle, label="repeat().shuffle()")
plt.ylabel("Mean item ID")
plt.legend()
 
<matplotlib.legend.Legend at 0x7fe0582fbb70>

png

データの前処理

Dataset.map(f)変換は、指定された関数fを入力データセットの各要素に適用することにより、新しいデータセットを生成します。これは、関数型プログラミング言語のリスト(およびその他の構造)に一般的に適用されるmap()関数に基づいていmap() 。関数fかかりtf.Tensor入力における単一の要素を表すオブジェクトを、そして返しtf.Tensorそれが新しいデータセット内の単一の要素を表現するオブジェクト。その実装では、標準のTensorFlow演算を使用して、ある要素を別の要素に変換します。

このセクションでは、 Dataset.map()一般的な使用例について説明します。

画像データのデコードとサイズ変更

実際の画像データでニューラルネットワークをトレーニングする場合、固定サイズにバッチ処理できるように、異なるサイズの画像を共通のサイズに変換する必要があることがよくあります。

花のファイル名のデータセットを再構築します。

 list_ds = tf.data.Dataset.list_files(str(flowers_root/'*/*'))
 

データセット要素を操作する関数を記述します。

 # Reads an image from a file, decodes it into a dense tensor, and resizes it
# to a fixed shape.
def parse_image(filename):
  parts = tf.strings.split(filename, os.sep)
  label = parts[-2]

  image = tf.io.read_file(filename)
  image = tf.image.decode_jpeg(image)
  image = tf.image.convert_image_dtype(image, tf.float32)
  image = tf.image.resize(image, [128, 128])
  return image, label
 

それが機能することをテストします。

 file_path = next(iter(list_ds))
image, label = parse_image(file_path)

def show(image, label):
  plt.figure()
  plt.imshow(image)
  plt.title(label.numpy().decode('utf-8'))
  plt.axis('off')

show(image, label)
 

png

データセットにマッピングします。

 images_ds = list_ds.map(parse_image)

for image, label in images_ds.take(2):
  show(image, label)
 

png

png

任意のPythonロジックを適用する

パフォーマンス上の理由から、可能な限りデータの前処理にはTensorFlowオペレーションを使用してください。ただし、入力データを解析するときに外部Pythonライブラリを呼び出すと便利な場合があります。 Dataset.map()トランスフォーメーションでtf.py_function()オペレーションを使用できます。

たとえば、ランダムな回転を適用する場合、 tf.imageモジュールにはtf.image.rot90しかないため、画像の拡張にはあまり役立ちません。

tf.py_functionを示すtf.py_function 、代わりにscipy.ndimage.rotate関数を使用してみてください:

 import scipy.ndimage as ndimage

def random_rotate_image(image):
  image = ndimage.rotate(image, np.random.uniform(-30, 30), reshape=False)
  return image
 
 image, label = next(iter(images_ds))
image = random_rotate_image(image)
show(image, label)
 
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

png

で、この機能を使用するにはDataset.map同じ警告を持つように適用Dataset.from_generator 、あなたが関数を適用する場合、リターン形状や種類を記述する必要があります。

 def tf_random_rotate_image(image, label):
  im_shape = image.shape
  [image,] = tf.py_function(random_rotate_image, [image], [tf.float32])
  image.set_shape(im_shape)
  return image, label
 
 rot_ds = images_ds.map(tf_random_rotate_image)

for image, label in rot_ds.take(2):
  show(image, label)
 
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

png

png

tf.Exampleプロトコルバッファメッセージの解析

多くの入力パイプラインは、TFRecord形式からtf.train.Exampleプロトコルバッファーメッセージを抽出します。各tf.train.Exampleレコードには1つ以上の「特徴」が含まれており、入力パイプラインは通常、これらの特徴をテンソルに変換します。

 fsns_test_file = tf.keras.utils.get_file("fsns.tfrec", "https://storage.googleapis.com/download.tensorflow.org/data/fsns-20160927/testdata/fsns-00000-of-00001")
dataset = tf.data.TFRecordDataset(filenames = [fsns_test_file])
dataset
 
<TFRecordDatasetV2 shapes: (), types: tf.string>

tf.train.Example外でtf.data.Datasetを使用して、データを理解できます。

 raw_example = next(iter(dataset))
parsed = tf.train.Example.FromString(raw_example.numpy())

feature = parsed.features.feature
raw_img = feature['image/encoded'].bytes_list.value[0]
img = tf.image.decode_png(raw_img)
plt.imshow(img)
plt.axis('off')
_ = plt.title(feature["image/text"].bytes_list.value[0])
 

png

 raw_example = next(iter(dataset))
 
 def tf_parse(eg):
  example = tf.io.parse_example(
      eg[tf.newaxis], {
          'image/encoded': tf.io.FixedLenFeature(shape=(), dtype=tf.string),
          'image/text': tf.io.FixedLenFeature(shape=(), dtype=tf.string)
      })
  return example['image/encoded'][0], example['image/text'][0]
 
 img, txt = tf_parse(raw_example)
print(txt.numpy())
print(repr(img.numpy()[:20]), "...")
 
b'Rue Perreyon'
b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x02X' ...

 decoded = dataset.map(tf_parse)
decoded
 
<MapDataset shapes: ((), ()), types: (tf.string, tf.string)>
 image_batch, text_batch = next(iter(decoded.batch(10)))
image_batch.shape
 
TensorShape([10])

時系列ウィンドウ

エンドツーエンドの時系列の例については、 時系列予測を参照してください。

時系列データは、時間軸をそのままにして編成されることがよくあります。

簡単なDataset.rangeを使用して、 Dataset.rangeことを示します。

 range_ds = tf.data.Dataset.range(100000)
 

通常、この種のデータに基づくモデルでは、連続したタイムスライスが必要になります。

最も簡単な方法は、データをバッチ処理することです。

batchを使用する

 batches = range_ds.batch(10, drop_remainder=True)

for batch in batches.take(5):
  print(batch.numpy())
 
[0 1 2 3 4 5 6 7 8 9]
[10 11 12 13 14 15 16 17 18 19]
[20 21 22 23 24 25 26 27 28 29]
[30 31 32 33 34 35 36 37 38 39]
[40 41 42 43 44 45 46 47 48 49]

または、密度の高い予測を1ステップ先に行うには、フィーチャとラベルを相互に1ステップずつシフトします。

 def dense_1_step(batch):
  # Shift features and labels one step relative to each other.
  return batch[:-1], batch[1:]

predict_dense_1_step = batches.map(dense_1_step)

for features, label in predict_dense_1_step.take(3):
  print(features.numpy(), " => ", label.numpy())
 
[0 1 2 3 4 5 6 7 8]  =>  [1 2 3 4 5 6 7 8 9]
[10 11 12 13 14 15 16 17 18]  =>  [11 12 13 14 15 16 17 18 19]
[20 21 22 23 24 25 26 27 28]  =>  [21 22 23 24 25 26 27 28 29]

固定オフセットの代わりにウィンドウ全体を予測するには、バッチを2つの部分に分割できます。

 batches = range_ds.batch(15, drop_remainder=True)

def label_next_5_steps(batch):
  return (batch[:-5],   # Take the first 5 steps
          batch[-5:])   # take the remainder

predict_5_steps = batches.map(label_next_5_steps)

for features, label in predict_5_steps.take(3):
  print(features.numpy(), " => ", label.numpy())
 
[0 1 2 3 4 5 6 7 8 9]  =>  [10 11 12 13 14]
[15 16 17 18 19 20 21 22 23 24]  =>  [25 26 27 28 29]
[30 31 32 33 34 35 36 37 38 39]  =>  [40 41 42 43 44]

1つのバッチの機能と別のバッチのラベル​​の間の重複を許可するには、 Dataset.zip使用しDataset.zip

 feature_length = 10
label_length = 5

features = range_ds.batch(feature_length, drop_remainder=True)
labels = range_ds.batch(feature_length).skip(1).map(lambda labels: labels[:-5])

predict_5_steps = tf.data.Dataset.zip((features, labels))

for features, label in predict_5_steps.take(3):
  print(features.numpy(), " => ", label.numpy())
 
[0 1 2 3 4 5 6 7 8 9]  =>  [10 11 12 13 14]
[10 11 12 13 14 15 16 17 18 19]  =>  [20 21 22 23 24]
[20 21 22 23 24 25 26 27 28 29]  =>  [30 31 32 33 34]

windowを使用する

Dataset.batchを使用するとDataset.batchますが、より細かい制御が必要になる場合があります。 Dataset.windowメソッドは完全な制御を提供しますが、いくつかの注意が必要です: Datasets Datasetを返します。詳細については、 データセットの構造を参照してください。

 window_size = 5

windows = range_ds.window(window_size, shift=1)
for sub_ds in windows.take(5):
  print(sub_ds)
 
<_VariantDataset shapes: (), types: tf.int64>
<_VariantDataset shapes: (), types: tf.int64>
<_VariantDataset shapes: (), types: tf.int64>
<_VariantDataset shapes: (), types: tf.int64>
<_VariantDataset shapes: (), types: tf.int64>

Dataset.flat_mapメソッドは、データセットのデータセットを受け取り、それを単一のデータセットにフラット化できます。

  for x in windows.flat_map(lambda x: x).take(30):
   print(x.numpy(), end=' ')
 
WARNING:tensorflow:AutoGraph could not transform <function <lambda> at 0x7fe0582dbbf8> and will run it as-is.
Cause: could not parse the source code:

for x in windows.flat_map(lambda x: x).take(30):

This error may be avoided by creating the lambda in a standalone statement.

To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
WARNING: AutoGraph could not transform <function <lambda> at 0x7fe0582dbbf8> and will run it as-is.
Cause: could not parse the source code:

for x in windows.flat_map(lambda x: x).take(30):

This error may be avoided by creating the lambda in a standalone statement.

To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
0 1 2 3 4 1 2 3 4 5 2 3 4 5 6 3 4 5 6 7 4 5 6 7 8 5 6 7 8 9 

ほとんどすべての場合、最初にデータセットを.batchする必要があります。

 def sub_to_batch(sub):
  return sub.batch(window_size, drop_remainder=True)

for example in windows.flat_map(sub_to_batch).take(5):
  print(example.numpy())
 
[0 1 2 3 4]
[1 2 3 4 5]
[2 3 4 5 6]
[3 4 5 6 7]
[4 5 6 7 8]

これで、 shift引数が各ウィンドウの移動量を制御していることがわかります。

これをまとめると、次の関数を書くことができます。

 def make_window_dataset(ds, window_size=5, shift=1, stride=1):
  windows = ds.window(window_size, shift=shift, stride=stride)

  def sub_to_batch(sub):
    return sub.batch(window_size, drop_remainder=True)

  windows = windows.flat_map(sub_to_batch)
  return windows

 
 ds = make_window_dataset(range_ds, window_size=10, shift = 5, stride=3)

for example in ds.take(10):
  print(example.numpy())
 
[ 0  3  6  9 12 15 18 21 24 27]
[ 5  8 11 14 17 20 23 26 29 32]
[10 13 16 19 22 25 28 31 34 37]
[15 18 21 24 27 30 33 36 39 42]
[20 23 26 29 32 35 38 41 44 47]
[25 28 31 34 37 40 43 46 49 52]
[30 33 36 39 42 45 48 51 54 57]
[35 38 41 44 47 50 53 56 59 62]
[40 43 46 49 52 55 58 61 64 67]
[45 48 51 54 57 60 63 66 69 72]

その後、以前と同様にラベルを抽出するのは簡単です。

 dense_labels_ds = ds.map(dense_1_step)

for inputs,labels in dense_labels_ds.take(3):
  print(inputs.numpy(), "=>", labels.numpy())
 
[ 0  3  6  9 12 15 18 21 24] => [ 3  6  9 12 15 18 21 24 27]
[ 5  8 11 14 17 20 23 26 29] => [ 8 11 14 17 20 23 26 29 32]
[10 13 16 19 22 25 28 31 34] => [13 16 19 22 25 28 31 34 37]

リサンプリング

クラスが非常に不均衡なデータセットを使用する場合は、データセットをリサンプリングすることができます。 tf.dataには、これを行うための2つのメソッドが用意されています。クレジットカード詐欺のデータセットは、この種の問題の良い例です。

 zip_path = tf.keras.utils.get_file(
    origin='https://storage.googleapis.com/download.tensorflow.org/data/creditcard.zip',
    fname='creditcard.zip',
    extract=True)

csv_path = zip_path.replace('.zip', '.csv')
 
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/creditcard.zip
69156864/69155632 [==============================] - 10s 0us/step

 creditcard_ds = tf.data.experimental.make_csv_dataset(
    csv_path, batch_size=1024, label_name="Class",
    # Set the column types: 30 floats and an int.
    column_defaults=[float()]*30+[int()])
 

今、クラスの分布をチェックしてください、それは非常に歪んでいます:

 def count(counts, batch):
  features, labels = batch
  class_1 = labels == 1
  class_1 = tf.cast(class_1, tf.int32)

  class_0 = labels == 0
  class_0 = tf.cast(class_0, tf.int32)

  counts['class_0'] += tf.reduce_sum(class_0)
  counts['class_1'] += tf.reduce_sum(class_1)

  return counts
 
 counts = creditcard_ds.take(10).reduce(
    initial_state={'class_0': 0, 'class_1': 0},
    reduce_func = count)

counts = np.array([counts['class_0'].numpy(),
                   counts['class_1'].numpy()]).astype(np.float32)

fractions = counts/counts.sum()
print(fractions)
 
[0.996 0.004]

不均衡なデータセットを使用したトレーニングの一般的なアプローチは、データセットのバランスを取ることです。 tf.dataは、このワークフローを可能にするいくつかのメソッドが含まれています。

データセットのサンプリング

データセットをリサンプリングする1つの方法は、 sample_from_datasetsを使用することsample_from_datasets 。これは、クラスごとに個別のdata.Datasetがある場合により適しています。

ここでは、フィルターを使用してクレジットカード詐欺データからそれらを生成します。

 negative_ds = (
  creditcard_ds
    .unbatch()
    .filter(lambda features, label: label==0)
    .repeat())
positive_ds = (
  creditcard_ds
    .unbatch()
    .filter(lambda features, label: label==1)
    .repeat())
 
WARNING:tensorflow:AutoGraph could not transform <function <lambda> at 0x7fe0a01fd1e0> and will run it as-is.
Cause: could not parse the source code:

    .filter(lambda features, label: label==0)

This error may be avoided by creating the lambda in a standalone statement.

To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
WARNING: AutoGraph could not transform <function <lambda> at 0x7fe0a01fd1e0> and will run it as-is.
Cause: could not parse the source code:

    .filter(lambda features, label: label==0)

This error may be avoided by creating the lambda in a standalone statement.

To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
WARNING:tensorflow:AutoGraph could not transform <function <lambda> at 0x7fe058159620> and will run it as-is.
Cause: could not parse the source code:

    .filter(lambda features, label: label==1)

This error may be avoided by creating the lambda in a standalone statement.

To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
WARNING: AutoGraph could not transform <function <lambda> at 0x7fe058159620> and will run it as-is.
Cause: could not parse the source code:

    .filter(lambda features, label: label==1)

This error may be avoided by creating the lambda in a standalone statement.

To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert

 for features, label in positive_ds.batch(10).take(1):
  print(label.numpy())
 
[1 1 1 1 1 1 1 1 1 1]

tf.data.experimental.sample_from_datasetsを使用するには、データセットとそれぞれの重みを渡します。

 balanced_ds = tf.data.experimental.sample_from_datasets(
    [negative_ds, positive_ds], [0.5, 0.5]).batch(10)
 

これで、データセットは50/50の確率で各クラスの例を生成します。

 for features, labels in balanced_ds.take(10):
  print(labels.numpy())
 
[0 0 0 1 0 0 1 1 0 0]
[1 1 0 0 1 1 0 0 0 1]
[1 1 0 0 0 0 0 0 1 1]
[1 0 0 1 1 0 0 0 1 0]
[1 1 0 0 0 1 1 0 0 1]
[0 0 1 1 1 0 0 1 1 0]
[0 0 0 1 0 0 0 1 1 1]
[1 1 0 1 0 0 1 0 1 1]
[0 1 0 0 1 1 0 0 0 1]
[0 1 0 0 1 0 0 1 1 0]

拒否のリサンプリング

上記のexperimental.sample_from_datasetsアプローチの1つの問題は、クラスごとに個別のtf.data.Datasetが必要なことです。 Dataset.filterを使用するとDataset.filterますが、すべてのデータが2回ロードされます。

data.experimental.rejection_resample関数をデータセットに適用すると、データを1回だけロードしながら、データのバランスを再調整できます。要素は、バランスをとるためにデータセットから削除されます。

data.experimental.rejection_resampleclass_func引数を取ります。このclass_funcは各データセット要素に適用され、バランスをとるために例がどのクラスに属するかを決定するために使用されます。

creditcard_dsの要素はすでに(features, label)ペアです。したがって、 class_funcこれらのラベルを返す必要があるだけです。

 def class_func(features, label):
  return label
 

リサンプラーには、ターゲット分布と、オプションで初期分布推定も必要です。

 resampler = tf.data.experimental.rejection_resample(
    class_func, target_dist=[0.5, 0.5], initial_dist=fractions)
 

リサンプラーは個々の例を扱うため、リサンプラーを適用する前にデータセットのunbatchする必要があります。

 resample_ds = creditcard_ds.unbatch().apply(resampler).batch(10)
 
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/data/experimental/ops/resampling.py:156: Print (from tensorflow.python.ops.logging_ops) is deprecated and will be removed after 2018-08-20.
Instructions for updating:
Use tf.print instead of tf.Print. Note that tf.print returns a no-output operator that directly prints the output. Outside of defuns or eager mode, this operator will not be executed unless it is directly specified in session.run or used as a control dependency for other operators. This is only a concern in graph mode. Below is an example of how to ensure tf.print executes in graph mode:


リサンプラーは、 class_func出力からclass_func (class, example)ペアを返します。この場合、 exampleはすでに(feature, label)ペアだったので、 mapを使用してラベルの余分なコピーをドロップします。

 balanced_ds = resample_ds.map(lambda extra_label, features_and_label: features_and_label)
 

これで、データセットは50/50の確率で各クラスの例を生成します。

 for features, labels in balanced_ds.take(10):
  print(labels.numpy())
 
[1 0 1 1 1 1 0 1 1 1]
[0 0 1 1 1 0 1 0 1 1]
[1 0 0 1 0 0 0 0 0 1]
[1 1 1 1 1 1 0 1 1 1]
[1 1 0 1 0 0 0 0 1 0]
[1 0 0 0 1 0 1 0 1 0]
[0 0 0 0 0 0 1 0 0 0]
[0 0 0 1 1 0 0 1 0 1]
[0 0 1 1 1 1 0 0 1 1]
[0 0 1 1 0 1 0 1 1 0]

イテレータチェックポイント

Tensorflowはチェックポイントの取得をサポートしているため、トレーニングプロセスが再開すると、最新のチェックポイントを復元して進行状況のほとんどを回復できます。モデル変数のチェックポイント設定に加えて、データセットイテレータの進行状況をチェックポイント設定することもできます。これは、大きなデータセットがあり、再起動するたびに最初からデータセットを開始したくない場合に役立ちます。ただし、 shuffleprefetchなどの変換では、反復子内の要素をバッファリングする必要があるため、反復子のチェックポイントが大きくなる可能性があることに注意してください。

イテレーターをチェックポイントに含めるには、イテレーターをtf.train.Checkpointコンストラクターに渡します。

 range_ds = tf.data.Dataset.range(20)

iterator = iter(range_ds)
ckpt = tf.train.Checkpoint(step=tf.Variable(0), iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, '/tmp/my_ckpt', max_to_keep=3)

print([next(iterator).numpy() for _ in range(5)])

save_path = manager.save()

print([next(iterator).numpy() for _ in range(5)])

ckpt.restore(manager.latest_checkpoint)

print([next(iterator).numpy() for _ in range(5)])
 
[0, 1, 2, 3, 4]
[5, 6, 7, 8, 9]
[5, 6, 7, 8, 9]

高レベルAPIの使用

tf.keras

tf.keras APIは、機械学習モデルの作成と実行の多くの側面を簡素化します。その.fit()および.evaluate()および.predict() APIは、入力としてデータセットをサポートします。以下は、簡単なデータセットとモデルの設定です。

 train, test = tf.keras.datasets.fashion_mnist.load_data()

images, labels = train
images = images/255.0
labels = labels.astype(np.int32)
 
 fmnist_train_ds = tf.data.Dataset.from_tensor_slices((images, labels))
fmnist_train_ds = fmnist_train_ds.shuffle(5000).batch(32)

model = tf.keras.Sequential([
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(10)
])

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), 
              metrics=['accuracy'])
 

(feature, label)ペアのデータセットを渡すだけで、 Model.fitModel.evaluate必要なことはすべてModel.evaluateです。

 model.fit(fmnist_train_ds, epochs=2)
 
Epoch 1/2
WARNING:tensorflow:Layer flatten is casting an input tensor from dtype float64 to the layer's dtype of float32, which is new behavior in TensorFlow 2.  The layer has dtype float32 because it's dtype defaults to floatx.

If you intended to run this layer in float32, you can safely ignore this warning. If in doubt, this warning is likely only an issue if you are porting a TensorFlow 1.X model to TensorFlow 2.

To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.

1875/1875 [==============================] - 4s 2ms/step - loss: 0.6031 - accuracy: 0.7937
Epoch 2/2
1875/1875 [==============================] - 3s 2ms/step - loss: 0.4620 - accuracy: 0.8416

<tensorflow.python.keras.callbacks.History at 0x7fe13f9ed3c8>

たとえばDataset.repeat()呼び出して無限データセットを渡す場合は、 steps_per_epoch引数も渡す必要があります。

 model.fit(fmnist_train_ds.repeat(), epochs=2, steps_per_epoch=20)
 
Epoch 1/2
20/20 [==============================] - 0s 2ms/step - loss: 0.4050 - accuracy: 0.8672
Epoch 2/2
20/20 [==============================] - 0s 2ms/step - loss: 0.4077 - accuracy: 0.8703

<tensorflow.python.keras.callbacks.History at 0x7fe0ca13edd8>

評価の場合、評価ステップの数を渡すことができます。

 loss, accuracy = model.evaluate(fmnist_train_ds)
print("Loss :", loss)
print("Accuracy :", accuracy)
 
1875/1875 [==============================] - 3s 2ms/step - loss: 0.4474 - accuracy: 0.8439
Loss : 0.4474281072616577
Accuracy : 0.843916654586792

長いデータセットの場合、評価するステップ数を設定します。

 loss, accuracy = model.evaluate(fmnist_train_ds.repeat(), steps=10)
print("Loss :", loss)
print("Accuracy :", accuracy)
 
10/10 [==============================] - 0s 2ms/step - loss: 0.5262 - accuracy: 0.8156
Loss : 0.5262183547019958
Accuracy : 0.815625011920929

Model.predict呼び出す場合、ラベルは必要ありません。

 predict_ds = tf.data.Dataset.from_tensor_slices(images).batch(32)
result = model.predict(predict_ds, steps = 10)
print(result.shape)
 
(320, 10)

ただし、ラベルを含むデータセットを渡した場合、ラベルは無視されます。

 result = model.predict(fmnist_train_ds, steps = 10)
print(result.shape)
 
(320, 10)

tf.estimator

使用するにはDataset中でinput_fntf.estimator.Estimator 、単に返すDatasetからinput_fnとフレームワークはあなたのためにその要素を消費するの世話をします。例えば:

 import tensorflow_datasets as tfds

def train_input_fn():
  titanic = tf.data.experimental.make_csv_dataset(
      titanic_file, batch_size=32,
      label_name="survived")
  titanic_batches = (
      titanic.cache().repeat().shuffle(500)
      .prefetch(tf.data.experimental.AUTOTUNE))
  return titanic_batches
 
 embark = tf.feature_column.categorical_column_with_hash_bucket('embark_town', 32)
cls = tf.feature_column.categorical_column_with_vocabulary_list('class', ['First', 'Second', 'Third']) 
age = tf.feature_column.numeric_column('age')
 
 import tempfile
model_dir = tempfile.mkdtemp()
model = tf.estimator.LinearClassifier(
    model_dir=model_dir,
    feature_columns=[embark, cls, age],
    n_classes=2
)
 
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpefmfuc4o', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

 model = model.train(input_fn=train_input_fn, steps=100)
 
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/resource_variable_ops.py:1666: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/feature_column/feature_column_v2.py:540: Layer.add_variable (from tensorflow.python.keras.engine.base_layer_v1) is deprecated and will be removed in a future version.
Instructions for updating:
Please use `layer.add_weight` method instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/optimizer_v2/ftrl.py:144: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpefmfuc4o/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.6931472, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 100...
INFO:tensorflow:Saving checkpoints for 100 into /tmp/tmpefmfuc4o/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 100...
INFO:tensorflow:Loss for final step: 0.58668363.

 result = model.evaluate(train_input_fn, steps=10)

for key, value in result.items():
  print(key, ":", value)
 
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2020-07-23T01:23:29Z
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpefmfuc4o/model.ckpt-100
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 0.83507s
INFO:tensorflow:Finished evaluation at 2020-07-23-01:23:30
INFO:tensorflow:Saving dict for global step 100: accuracy = 0.675, accuracy_baseline = 0.58125, auc = 0.71750116, auc_precision_recall = 0.6480325, average_loss = 0.64111984, global_step = 100, label/mean = 0.41875, loss = 0.64111984, precision = 0.85714287, prediction/mean = 0.30204886, recall = 0.26865673
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 100: /tmp/tmpefmfuc4o/model.ckpt-100
accuracy : 0.675
accuracy_baseline : 0.58125
auc : 0.71750116
auc_precision_recall : 0.6480325
average_loss : 0.64111984
label/mean : 0.41875
loss : 0.64111984
precision : 0.85714287
prediction/mean : 0.30204886
recall : 0.26865673
global_step : 100

 for pred in model.predict(train_input_fn):
  for key, value in pred.items():
    print(key, ":", value)
  break
 
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpefmfuc4o/model.ckpt-100
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
logits : [-0.5965]
logistic : [0.3551]
probabilities : [0.6449 0.3551]
class_ids : [0]
classes : [b'0']
all_class_ids : [0 1]
all_classes : [b'0' b'1']