tf.dataを使って画像をロードする

View on TensorFlow.org Run in Google Colab View source on GitHub Download notebook

このチュートリアルでは、'tf.data' を使って画像データセットをロードする簡単な例を示します。

このチュートリアルで使用するデータセットは、クラスごとに別々のディレクトリに別れた形で配布されています。

設定

import tensorflow as tf
AUTOTUNE = tf.data.experimental.AUTOTUNE

データセットのダウンロードと検査

画像の取得

訓練を始める前に、ネットワークに認識すべき新しいクラスを教えるために画像のセットが必要です。最初に使うためのクリエイティブ・コモンズでライセンスされた花の画像のアーカイブを作成してあります。

import pathlib
data_root_orig = tf.keras.utils.get_file(origin='https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
                                         fname='flower_photos', untar=True)
data_root = pathlib.Path(data_root_orig)
print(data_root)
/home/kbuilder/.keras/datasets/flower_photos

218MB をダウンロードすると、花の画像のコピーが使えるようになっているはずです。

for item in data_root.iterdir():
  print(item)
/home/kbuilder/.keras/datasets/flower_photos/sunflowers
/home/kbuilder/.keras/datasets/flower_photos/daisy
/home/kbuilder/.keras/datasets/flower_photos/LICENSE.txt
/home/kbuilder/.keras/datasets/flower_photos/roses
/home/kbuilder/.keras/datasets/flower_photos/tulips
/home/kbuilder/.keras/datasets/flower_photos/dandelion

import random
all_image_paths = list(data_root.glob('*/*'))
all_image_paths = [str(path) for path in all_image_paths]
random.shuffle(all_image_paths)

image_count = len(all_image_paths)
image_count
3670
all_image_paths[:10]
['/home/kbuilder/.keras/datasets/flower_photos/roses/6950609394_c53b8c6ac0_m.jpg',
 '/home/kbuilder/.keras/datasets/flower_photos/sunflowers/15042911059_b6153d94e7_n.jpg',
 '/home/kbuilder/.keras/datasets/flower_photos/tulips/16862349256_0a1f91ab53.jpg',
 '/home/kbuilder/.keras/datasets/flower_photos/roses/8462246855_1bdfee7478.jpg',
 '/home/kbuilder/.keras/datasets/flower_photos/dandelion/16970837587_4a9d8500d7.jpg',
 '/home/kbuilder/.keras/datasets/flower_photos/sunflowers/4746643626_02b2d056a2_n.jpg',
 '/home/kbuilder/.keras/datasets/flower_photos/roses/8671682526_7058143c99.jpg',
 '/home/kbuilder/.keras/datasets/flower_photos/daisy/5739768868_9f982684f9_n.jpg',
 '/home/kbuilder/.keras/datasets/flower_photos/sunflowers/22405887122_75eda1872f_m.jpg',
 '/home/kbuilder/.keras/datasets/flower_photos/roses/5223191368_01aedb6547_n.jpg']

画像の検査

扱っている画像について知るために、画像のいくつかを見てみましょう。

import os
attributions = (data_root/"LICENSE.txt").open(encoding='utf-8').readlines()[4:]
attributions = [line.split(' CC-BY') for line in attributions]
attributions = dict(attributions)
import IPython.display as display

def caption_image(image_path):
  image_rel = pathlib.Path(image_path).relative_to(data_root)
  return "Image (CC BY 2.0) " + ' - '.join(attributions[str(image_rel)].split(' - ')[:-1])

for n in range(3):
  image_path = random.choice(all_image_paths)
  display.display(display.Image(image_path))
  print(caption_image(image_path))
  print()

jpeg

Image (CC BY 2.0)  by katrien berckmoes


jpeg

Image (CC BY 2.0)  by DncnH


jpeg

Image (CC BY 2.0)  by Don Graham


各画像のラベルの決定

ラベルを一覧してみます。

label_names = sorted(item.name for item in data_root.glob('*/') if item.is_dir())
label_names
['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']

ラベルにインデックスを割り当てます。

label_to_index = dict((name, index) for index,name in enumerate(label_names))
label_to_index
{'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}

ファイルとラベルのインデックスの一覧を作成します。

all_image_labels = [label_to_index[pathlib.Path(path).parent.name]
                    for path in all_image_paths]

print("First 10 labels indices: ", all_image_labels[:10])
First 10 labels indices:  [2, 3, 4, 2, 1, 3, 2, 0, 3, 2]

画像の読み込みと整形

TensorFlow には画像を読み込んで処理するために必要なツールが備わっています。

img_path = all_image_paths[0]
img_path
'/home/kbuilder/.keras/datasets/flower_photos/roses/6950609394_c53b8c6ac0_m.jpg'

以下は生のデータです。

img_raw = tf.io.read_file(img_path)
print(repr(img_raw)[:100]+"...")
<tf.Tensor: shape=(), dtype=string, numpy=b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00...

画像のテンソルにデコードします。

img_tensor = tf.image.decode_image(img_raw)

print(img_tensor.shape)
print(img_tensor.dtype)
(228, 240, 3)
<dtype: 'uint8'>

モデルに合わせてリサイズします。

img_final = tf.image.resize(img_tensor, [192, 192])
img_final = img_final/255.0
print(img_final.shape)
print(img_final.numpy().min())
print(img_final.numpy().max())

(192, 192, 3)
0.0
1.0

このあと使用するために、簡単な関数にまとめます。

def preprocess_image(image):
  image = tf.image.decode_jpeg(image, channels=3)
  image = tf.image.resize(image, [192, 192])
  image /= 255.0  # normalize to [0,1] range

  return image
def load_and_preprocess_image(path):
  image = tf.io.read_file(path)
  return preprocess_image(image)
import matplotlib.pyplot as plt

image_path = all_image_paths[0]
label = all_image_labels[0]

plt.imshow(load_and_preprocess_image(img_path))
plt.grid(False)
plt.xlabel(caption_image(img_path))
plt.title(label_names[label].title())
print()


png

tf.data.Datasetの構築

画像のデータセット

tf.data.Dataset を構築するもっとも簡単な方法は、from_tensor_slices メソッドを使うことです。

文字列の配列をスライスすると、文字列のデータセットが出来上がります。

path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)

shapestypes は、データセット中のそれぞれのアイテムの内容を示しています。この場合には、バイナリ文字列のスカラーのセットです。

print(path_ds)
<TensorSliceDataset shapes: (), types: tf.string>

preprocess_image をファイルパスのデータセットにマップすることで、画像を実行時にロードし整形する新しいデータセットを作成します。

image_ds = path_ds.map(load_and_preprocess_image, num_parallel_calls=AUTOTUNE)
import matplotlib.pyplot as plt

plt.figure(figsize=(8,8))
for n,image in enumerate(image_ds.take(4)):
  plt.subplot(2,2,n+1)
  plt.imshow(image)
  plt.grid(False)
  plt.xticks([])
  plt.yticks([])
  plt.xlabel(caption_image(all_image_paths[n]))
  plt.show()

png

png

png

png

(image, label)のペアのデータセット

おなじ from_tensor_slices メソッドを使ってラベルのデータセットを作ることができます。

label_ds = tf.data.Dataset.from_tensor_slices(tf.cast(all_image_labels, tf.int64))
for label in label_ds.take(10):
  print(label_names[label.numpy()])
roses
sunflowers
tulips
roses
dandelion
sunflowers
roses
daisy
sunflowers
roses

これらのデータセットはおなじ順番なので、zip することで (image, label) というペアのデータセットができます。

image_label_ds = tf.data.Dataset.zip((image_ds, label_ds))

新しいデータセットの shapestypes は、それぞれのフィールドを示すシェイプと型のタプルです。

print(image_label_ds)
<ZipDataset shapes: ((192, 192, 3), ()), types: (tf.float32, tf.int64)>

注: all_image_labelsall_image_paths のような配列がある場合、 tf.data.dataset.Dataset.zip メソッドの代わりとなるのは、配列のペアをスライスすることです。

ds = tf.data.Dataset.from_tensor_slices((all_image_paths, all_image_labels))

# The tuples are unpacked into the positional arguments of the mapped function
# タプルは展開され、マップ関数の位置引数に割り当てられます
def load_and_preprocess_from_path_label(path, label):
  return load_and_preprocess_image(path), label

image_label_ds = ds.map(load_and_preprocess_from_path_label)
image_label_ds
<MapDataset shapes: ((192, 192, 3), ()), types: (tf.float32, tf.int32)>

基本的な訓練手法

このデータセットを使ってモデルの訓練を行うには、データが

  • よくシャッフルされ
  • バッチ化され
  • 限りなく繰り返され
  • バッチが出来るだけ早く利用できる

ことが必要です。

これらの特性は tf.data APIを使えば簡単に付け加えることができます。

BATCH_SIZE = 32

# シャッフルバッファのサイズをデータセットとおなじに設定することで、データが完全にシャッフルされる
# ようにできます。
ds = image_label_ds.shuffle(buffer_size=image_count)
ds = ds.repeat()
ds = ds.batch(BATCH_SIZE)
# `prefetch`を使うことで、モデルの訓練中にバックグラウンドでデータセットがバッチを取得できます。
ds = ds.prefetch(buffer_size=AUTOTUNE)
ds
<PrefetchDataset shapes: ((None, 192, 192, 3), (None,)), types: (tf.float32, tf.int32)>

注意すべきことがいくつかあります。

  1. 順番が重要です。

    • .repeat の前に .shuffle すると、エポックの境界を越えて要素がシャッフルされます。(ほかの要素がすべて出現する前に2回出現する要素があるかもしれません)
    • .batch の後に .shuffle すると、バッチの順番がシャッフルされますが、要素がバッチを越えてシャッフルされることはありません。
  2. 完全なシャッフルのため、 buffer_size をデータセットとおなじサイズに設定しています。データセットのサイズ未満の場合、値が大きいほど良くランダム化されますが、より多くのメモリーを使用します。

  3. シャッフルバッファがいっぱいになってから要素が取り出されます。そのため、大きな buffer_sizeDataset を使い始める際の遅延の原因になります。

  4. シャッフルされたデータセットは、シャッフルバッファが完全に空になるまでデータセットが終わりであることを伝えません。 .repeat によって Dataset が再起動されると、シャッフルバッファが一杯になるまでもう一つの待ち時間が発生します。

最後の問題は、 tf.data.Dataset.apply メソッドを、融合された tf.data.experimental.shuffle_and_repeat 関数と組み合わせることで対処できます。

ds = image_label_ds.apply(
  tf.data.experimental.shuffle_and_repeat(buffer_size=image_count))
ds = ds.batch(BATCH_SIZE)
ds = ds.prefetch(buffer_size=AUTOTUNE)
ds
WARNING:tensorflow:From <ipython-input-31-4dc713bd4d84>:2: shuffle_and_repeat (from tensorflow.python.data.experimental.ops.shuffle_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Dataset.shuffle(buffer_size, seed)` followed by `tf.data.Dataset.repeat(count)`. Static tf.data optimizations will take care of using the fused implementation.

<PrefetchDataset shapes: ((None, 192, 192, 3), (None,)), types: (tf.float32, tf.int32)>

データセットをモデルにつなぐ

tf.keras.applicationsからMobileNet v2のコピーを取得します。

これを簡単な転移学習のサンプルに使用します。

MobileNetの重みを訓練不可に設定します。

mobile_net = tf.keras.applications.MobileNetV2(input_shape=(192, 192, 3), include_top=False)
mobile_net.trainable=False
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_192_no_top.h5
9412608/9406464 [==============================] - 1s 0us/step

このモデルは、入力が [-1,1] の範囲に正規化されていることを想定しています。

help(keras_applications.mobilenet_v2.preprocess_input)
...
This function applies the "Inception" preprocessing which converts
the RGB values from [0, 255] to [-1, 1] 
...

このため、データをMobileNetモデルに渡す前に、入力を[0,1]の範囲から[-1,1]の範囲に変換する必要があります。

def change_range(image,label):
    return 2*image-1, label

keras_ds = ds.map(change_range)

MobileNetは画像ごとに 6x6 の特徴量の空間を返します。

バッチを1つ渡してみましょう。

# シャッフルバッファがいっぱいになるまで、データセットは何秒かかかります。
image_batch, label_batch = next(iter(keras_ds))
feature_map_batch = mobile_net(image_batch)
print(feature_map_batch.shape)
(32, 6, 6, 1280)

MobileNet をラップしたモデルを作り、出力層である tf.keras.layers.Dense の前に、tf.keras.layers.GlobalAveragePooling2D で空間の軸にそって平均値を求めます。

model = tf.keras.Sequential([
    mobile_net,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(len(label_names))])

期待したとおりの形状の出力が得られます。

logit_batch = model(image_batch).numpy()

print("min logit:", logit_batch.min())
print("max logit:", logit_batch.max())
print()

print("Shape:", logit_batch.shape)
min logit: -3.0309844
max logit: 2.1899369

Shape: (32, 5)

訓練手法を記述するためにモデルをコンパイルします。

model.compile(optimizer=tf.keras.optimizers.Adam(), 
              loss='sparse_categorical_crossentropy',
              metrics=["accuracy"])

訓練可能な変数は2つ、全結合層の weightsbias です。

len(model.trainable_variables) 
2
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
mobilenetv2_1.00_192 (Functi (None, 6, 6, 1280)        2257984   
_________________________________________________________________
global_average_pooling2d (Gl (None, 1280)              0         
_________________________________________________________________
dense (Dense)                (None, 5)                 6405      
=================================================================
Total params: 2,264,389
Trainable params: 6,405
Non-trainable params: 2,257,984
_________________________________________________________________

モデルを訓練します。

普通は、エポックごとの本当のステップ数を指定しますが、ここではデモの目的なので3ステップだけとします。

steps_per_epoch=tf.math.ceil(len(all_image_paths)/BATCH_SIZE).numpy()
steps_per_epoch
115.0
model.fit(ds, epochs=1, steps_per_epoch=3)
3/3 [==============================] - 0s 29ms/step - loss: 8.9401 - accuracy: 0.2083

<tensorflow.python.keras.callbacks.History at 0x7fce761a32b0>

性能

注:このセクションでは性能の向上に役立ちそうな簡単なトリックをいくつか紹介します。詳しくは、Input Pipeline Performance を参照してください。

上記の単純なパイプラインは、エポックごとにそれぞれのファイルを一つずつ読み込みます。これは、CPU を使ったローカルでの訓練では問題になりませんが、GPU を使った訓練では十分ではなく、いかなる分散訓練でも使うべきではありません。

調査のため、まず、データセットの性能をチェックする簡単な関数を定義します。

import time
default_timeit_steps = 2*steps_per_epoch+1

def timeit(ds, steps=default_timeit_steps):
  overall_start = time.time()
  # Fetch a single batch to prime the pipeline (fill the shuffle buffer),
  # before starting the timer
  it = iter(ds.take(steps+1))
  next(it)

  start = time.time()
  for i,(images,labels) in enumerate(it):
    if i%10 == 0:
      print('.',end='')
  print()
  end = time.time()

  duration = end-start
  print("{} batches: {} s".format(steps, duration))
  print("{:0.5f} Images/s".format(BATCH_SIZE*steps/duration))
  print("Total time: {}s".format(end-overall_start))

現在のデータセットの性能は次のとおりです。

ds = image_label_ds.apply(
    tf.data.experimental.shuffle_and_repeat(buffer_size=image_count))
ds = ds.batch(BATCH_SIZE).prefetch(buffer_size=AUTOTUNE)
ds
<PrefetchDataset shapes: ((None, 192, 192, 3), (None,)), types: (tf.float32, tf.int32)>
timeit(ds)
........................
231.0 batches: 14.232337236404419 s
519.38061 Images/s
Total time: 21.03696084022522s

キャッシュ

tf.data.Dataset.cache を使うと、エポックを越えて計算結果を簡単にキャッシュできます。特に、データがメモリに収まるときには効果的です。

ここでは、画像が前処理(デコードとリサイズ)された後でキャッシュされます。

ds = image_label_ds.cache()
ds = ds.apply(
    tf.data.experimental.shuffle_and_repeat(buffer_size=image_count))
ds = ds.batch(BATCH_SIZE).prefetch(buffer_size=AUTOTUNE)
ds
<PrefetchDataset shapes: ((None, 192, 192, 3), (None,)), types: (tf.float32, tf.int32)>
timeit(ds)
........................
231.0 batches: 0.6558306217193604 s
11271.20289 Images/s
Total time: 7.143255949020386s

メモリキャッシュを使う際の欠点のひとつは、実行の都度キャッシュを再構築しなければならないことです。このため、データセットがスタートするたびにおなじだけ起動のための遅延が発生します。

timeit(ds)
........................
231.0 batches: 0.600313663482666 s
12313.56281 Images/s
Total time: 0.6126508712768555s

データがメモリに収まらない場合には、キャッシュファイルを使用します。

ds = image_label_ds.cache(filename='./cache.tf-data')
ds = ds.apply(
    tf.data.experimental.shuffle_and_repeat(buffer_size=image_count))
ds = ds.batch(BATCH_SIZE).prefetch(1)
ds
<PrefetchDataset shapes: ((None, 192, 192, 3), (None,)), types: (tf.float32, tf.int32)>
timeit(ds)
........................
231.0 batches: 2.9384658336639404 s
2515.59842 Images/s
Total time: 11.71722674369812s

キャッシュファイルには、キャッシュを再構築することなくデータセットを再起動できるという利点もあります。2回めがどれほど早いか見てみましょう。

timeit(ds)
........................
231.0 batches: 2.3221585750579834 s
3183.24514 Images/s
Total time: 3.611947774887085s

TFRecord ファイル

生の画像データ

TFRecord ファイルは、バイナリの大きなオブジェクトのシーケンスを保存するための単純なフォーマットです。複数のサンプルをおなじファイルに詰め込むことで、TensorFlow は複数のサンプルを一度に読み込むことができます。これは、特に GCS のようなリモートストレージサービスを使用する際の性能にとって重要です。

最初に、生の画像データから TFRecord ファイルを構築します。

image_ds = tf.data.Dataset.from_tensor_slices(all_image_paths).map(tf.io.read_file)
tfrec = tf.data.experimental.TFRecordWriter('images.tfrec')
tfrec.write(image_ds)

次に、TFRecord ファイルを読み込み、以前定義した preprocess_image 関数を使って画像のデコード/リフォーマットを行うデータセットを構築します。

image_ds = tf.data.TFRecordDataset('images.tfrec').map(preprocess_image)

これを、前に定義済みのラベルデータセットと zip し、期待どおりの (image,label) のペアを得ます。

ds = tf.data.Dataset.zip((image_ds, label_ds))
ds = ds.apply(
    tf.data.experimental.shuffle_and_repeat(buffer_size=image_count))
ds=ds.batch(BATCH_SIZE).prefetch(AUTOTUNE)
ds
<PrefetchDataset shapes: ((None, 192, 192, 3), (None,)), types: (tf.float32, tf.int64)>
timeit(ds)
........................
231.0 batches: 14.166187047958374 s
521.80590 Images/s
Total time: 20.94697093963623s

これは、cache バージョンよりも低速です。前処理をキャッシュしていないからです。

シリアライズしたテンソル

前処理を TFRecord ファイルに保存するには、前やったように前処理した画像のデータセットを作ります。

paths_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
image_ds = paths_ds.map(load_and_preprocess_image)
image_ds
<MapDataset shapes: (192, 192, 3), types: tf.float32>

.jpeg 文字列のデータセットではなく、これはテンソルのデータセットです。

これを TFRecord ファイルにシリアライズするには、まず、テンソルのデータセットを文字列のデータセットに変換します。

ds = image_ds.map(tf.io.serialize_tensor)
ds
<MapDataset shapes: (), types: tf.string>
tfrec = tf.data.experimental.TFRecordWriter('images.tfrec')
tfrec.write(ds)

前処理をキャッシュしたことにより、データは TFRecord ファイルから非常に効率的にロードできます。テンソルを使用する前にデシリアライズすることを忘れないでください。

ds = tf.data.TFRecordDataset('images.tfrec')

def parse(x):
  result = tf.io.parse_tensor(x, out_type=tf.float32)
  result = tf.reshape(result, [192, 192, 3])
  return result

ds = ds.map(parse, num_parallel_calls=AUTOTUNE)
ds
<ParallelMapDataset shapes: (192, 192, 3), types: tf.float32>

次にラベルを追加し、以前とおなじような標準的な処理を適用します。

ds = tf.data.Dataset.zip((ds, label_ds))
ds = ds.apply(
  tf.data.experimental.shuffle_and_repeat(buffer_size=image_count))
ds=ds.batch(BATCH_SIZE).prefetch(AUTOTUNE)
ds
<PrefetchDataset shapes: ((None, 192, 192, 3), (None,)), types: (tf.float32, tf.int64)>
timeit(ds)
........................
231.0 batches: 1.8715846538543701 s
3949.59426 Images/s
Total time: 2.6737515926361084s