![]() | ![]() | ![]() | ![]() |
概要
このチュートリアルでは、データ拡張について説明します。これは、画像の回転などのランダムな(ただし現実的な)変換を適用することにより、トレーニングセットの多様性を高める手法です。
次の2つの方法でデータ拡張を適用する方法を学習します。
- tf.keras.layers.Resizing、
tf.keras.layers.Rescaling
、tf.keras.layers.RandomFlip
、tf.keras.layers.RandomRotation
などのtf.keras.layers.Resizing
前処理レイヤーを使用します。 -
tf.image.flip_left_right
、tf.image.rgb_to_grayscale
、tf.image.adjust_brightness
、tf.image.central_crop
、tf.image.stateless_random*
などのtf.image
メソッドを使用します。
設定
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras import layers
データセットをダウンロードする
このチュートリアルでは、 tf_flowersデータセットを使用します。便宜上、 TensorFlowデータセットを使用してデータセットをダウンロードします。データをインポートする他の方法について知りたい場合は、画像の読み込みチュートリアルをご覧ください。
(train_ds, val_ds, test_ds), metadata = tfds.load(
'tf_flowers',
split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],
with_info=True,
as_supervised=True,
)
花のデータセットには5つのクラスがあります。
num_classes = metadata.features['label'].num_classes
print(num_classes)
5
データセットから画像を取得し、それを使用してデータ拡張を示しましょう。
get_label_name = metadata.features['label'].int2str
image, label = next(iter(train_ds))
_ = plt.imshow(image)
_ = plt.title(get_label_name(label))
2022-01-26 05:09:18.712477: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Keras前処理レイヤーを使用する
サイズ変更と再スケーリング
Keras前処理レイヤーを使用して、画像のサイズを一貫した形状に変更したり( tf.keras.layers.Resizing
を使用)、ピクセル値を再スケーリングしたり( tf.keras.layers.Rescaling
を使用)できます。
IMG_SIZE = 180
resize_and_rescale = tf.keras.Sequential([
layers.Resizing(IMG_SIZE, IMG_SIZE),
layers.Rescaling(1./255)
])
これらのレイヤーを画像に適用した結果を視覚化できます。
result = resize_and_rescale(image)
_ = plt.imshow(result)
ピクセルが[0, 1]
の範囲にあることを確認します。
print("Min and max pixel values:", result.numpy().min(), result.numpy().max())
Min and max pixel values: 0.0 1.0
データ拡張
tf.keras.layers.RandomFlip
やtf.keras.layers.RandomRotation
などのKeras前処理レイヤーをデータ拡張にも使用できます。
いくつかの前処理レイヤーを作成し、それらを同じ画像に繰り返し適用してみましょう。
data_augmentation = tf.keras.Sequential([
layers.RandomFlip("horizontal_and_vertical"),
layers.RandomRotation(0.2),
])
# Add the image to a batch.
image = tf.expand_dims(image, 0)
plt.figure(figsize=(10, 10))
for i in range(9):
augmented_image = data_augmentation(image)
ax = plt.subplot(3, 3, i + 1)
plt.imshow(augmented_image[0])
plt.axis("off")
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
tf.keras.layers.RandomContrast
、 tf.keras.layers.RandomCrop
、 tf.keras.layers.RandomZoom
など、データ拡張に使用できるさまざまな前処理レイヤーがあります。
Keras前処理レイヤーを使用するための2つのオプション
これらの前処理レイヤーを使用するには2つの方法がありますが、重要なトレードオフがあります。
オプション1:前処理レイヤーをモデルの一部にする
model = tf.keras.Sequential([
# Add the preprocessing layers you created earlier.
resize_and_rescale,
data_augmentation,
layers.Conv2D(16, 3, padding='same', activation='relu'),
layers.MaxPooling2D(),
# Rest of your model.
])
この場合、注意すべき2つの重要なポイントがあります。
データ拡張はデバイス上で実行され、残りのレイヤーと同期して実行され、GPUアクセラレーションの恩恵を受けます。
model.save
を使用してモデルをエクスポートすると、前処理レイヤーがモデルの残りの部分と一緒に保存されます。後でこのモデルをデプロイすると、(レイヤーの構成に従って)イメージが自動的に標準化されます。これにより、そのロジックサーバー側を再実装する手間を省くことができます。
オプション2:前処理レイヤーをデータセットに適用する
aug_ds = train_ds.map(
lambda x, y: (resize_and_rescale(x, training=True), y))
このアプローチでは、 Dataset.map
を使用して、拡張画像のバッチを生成するデータセットを作成します。この場合:
- データ拡張はCPU上で非同期的に行われ、非ブロッキングです。以下に示す
Dataset.prefetch
を使用して、GPUでのモデルのトレーニングをデータ前処理とオーバーラップさせることができます。 - この場合、
Model.save
を呼び出すと、前処理レイヤーはモデルとともにエクスポートされません。モデルを保存する前にモデルにアタッチするか、サーバー側で再実装する必要があります。トレーニング後、エクスポートする前に前処理レイヤーをアタッチできます。
最初のオプションの例は、画像分類チュートリアルにあります。ここで2番目のオプションを示しましょう。
前処理レイヤーをデータセットに適用します
以前に作成したKeras前処理レイヤーを使用して、トレーニング、検証、およびテストのデータセットを構成します。また、並列読み取りとバッファー付きプリフェッチを使用してデータセットのパフォーマンスを構成し、I / Oがブロックされることなくディスクからバッチを生成します。 ( tf.data APIガイドを使用したパフォーマンスの向上でデータセットのパフォーマンスの詳細を確認してください。)
batch_size = 32
AUTOTUNE = tf.data.AUTOTUNE
def prepare(ds, shuffle=False, augment=False):
# Resize and rescale all datasets.
ds = ds.map(lambda x, y: (resize_and_rescale(x), y),
num_parallel_calls=AUTOTUNE)
if shuffle:
ds = ds.shuffle(1000)
# Batch all datasets.
ds = ds.batch(batch_size)
# Use data augmentation only on the training set.
if augment:
ds = ds.map(lambda x, y: (data_augmentation(x, training=True), y),
num_parallel_calls=AUTOTUNE)
# Use buffered prefetching on all datasets.
return ds.prefetch(buffer_size=AUTOTUNE)
train_ds = prepare(train_ds, shuffle=True, augment=True)
val_ds = prepare(val_ds)
test_ds = prepare(test_ds)
プレースホルダー18モデルをトレーニングする
完全を期すために、準備したデータセットを使用してモデルをトレーニングします。
シーケンシャルモデルは、3つの畳み込みブロック( tf.keras.layers.Conv2D
)で構成され、それぞれに最大プーリングレイヤー( tf.keras.layers.MaxPooling2D
)があります。完全に接続されたレイヤー( tf.keras.layers.Dense
)があり、その上に128ユニットがあり、ReLUアクティブ化関数( 'relu'
)によってアクティブ化されます。このモデルは精度が調整されていません(目標はメカニズムを示すことです)。
model = tf.keras.Sequential([
layers.Conv2D(16, 3, padding='same', activation='relu'),
layers.MaxPooling2D(),
layers.Conv2D(32, 3, padding='same', activation='relu'),
layers.MaxPooling2D(),
layers.Conv2D(64, 3, padding='same', activation='relu'),
layers.MaxPooling2D(),
layers.Flatten(),
layers.Dense(128, activation='relu'),
layers.Dense(num_classes)
])
tf.keras.optimizers.Adam
オプティマイザーとtf.keras.losses.SparseCategoricalCrossentropy
損失関数を選択します。各トレーニングエポックのトレーニングと検証の精度を表示するには、 metrics
引数をModel.compile
に渡します。
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
いくつかのエポックのトレーニング:
epochs=5
history = model.fit(
train_ds,
validation_data=val_ds,
epochs=epochs
)
Epoch 1/5 92/92 [==============================] - 13s 110ms/step - loss: 1.2768 - accuracy: 0.4622 - val_loss: 1.0929 - val_accuracy: 0.5640 Epoch 2/5 92/92 [==============================] - 3s 25ms/step - loss: 1.0579 - accuracy: 0.5749 - val_loss: 0.9711 - val_accuracy: 0.6349 Epoch 3/5 92/92 [==============================] - 3s 26ms/step - loss: 0.9677 - accuracy: 0.6291 - val_loss: 0.9764 - val_accuracy: 0.6431 Epoch 4/5 92/92 [==============================] - 3s 25ms/step - loss: 0.9150 - accuracy: 0.6468 - val_loss: 0.8906 - val_accuracy: 0.6431 Epoch 5/5 92/92 [==============================] - 3s 25ms/step - loss: 0.8636 - accuracy: 0.6604 - val_loss: 0.8233 - val_accuracy: 0.6730
loss, acc = model.evaluate(test_ds)
print("Accuracy", acc)
プレースホルダー23l10n-プレースホルダー12/12 [==============================] - 5s 14ms/step - loss: 0.7922 - accuracy: 0.6948 Accuracy 0.6948229074478149
カスタムデータ拡張
カスタムデータ拡張レイヤーを作成することもできます。
チュートリアルのこのセクションでは、そのための2つの方法を示します。
- まず、
tf.keras.layers.Lambda
レイヤーを作成します。これは、簡潔なコードを書くための良い方法です。 - 次に、サブクラス化を介して新しいレイヤーを作成します。これにより、より詳細な制御が可能になります。
両方のレイヤーは、ある程度の確率に従って、画像の色をランダムに反転させます。
def random_invert_img(x, p=0.5):
if tf.random.uniform([]) < p:
x = (255-x)
else:
x
return x
def random_invert(factor=0.5):
return layers.Lambda(lambda x: random_invert_img(x, factor))
random_invert = random_invert()
plt.figure(figsize=(10, 10))
for i in range(9):
augmented_image = random_invert(image)
ax = plt.subplot(3, 3, i + 1)
plt.imshow(augmented_image[0].numpy().astype("uint8"))
plt.axis("off")
プレースホルダー27l10n-プレースホルダー2022-01-26 05:09:53.045204: W tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc:399] target triple not found in the module 2022-01-26 05:09:53.045264: W tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc:399] target triple not found in the module 2022-01-26 05:09:53.045312: W tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc:399] target triple not found in the module 2022-01-26 05:09:53.045369: W tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc:399] target triple not found in the module 2022-01-26 05:09:53.045418: W tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc:399] target triple not found in the module 2022-01-26 05:09:53.045467: W tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc:399] target triple not found in the module 2022-01-26 05:09:53.045511: W tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc:399] target triple not found in the module 2022-01-26 05:09:53.047630: W tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc:399] target triple not found in the module
次に、サブクラス化してカスタムレイヤーを実装します。
class RandomInvert(layers.Layer):
def __init__(self, factor=0.5, **kwargs):
super().__init__(**kwargs)
self.factor = factor
def call(self, x):
return random_invert_img(x)
_ = plt.imshow(RandomInvert()(image)[0])
これらのレイヤーは両方とも、上記のオプション1と2で説明したように使用できます。
tf.imageを使用する
上記のKeras前処理ユーティリティは便利です。ただし、より細かく制御するために、 tf.data
とtf.image
を使用して独自のデータ拡張パイプラインまたはレイヤーを作成できます。 ( TensorFlowアドオンイメージ:操作とTensorFlow I / O:色空間変換も確認することをお勧めします。)
花のデータセットは以前にデータ拡張を使用して構成されていたため、再インポートして最初からやり直してみましょう。
(train_ds, val_ds, test_ds), metadata = tfds.load(
'tf_flowers',
split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],
with_info=True,
as_supervised=True,
)
使用する画像を取得します。
image, label = next(iter(train_ds))
_ = plt.imshow(image)
_ = plt.title(get_label_name(label))
2022-01-26 05:09:59.918847: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
次の関数を使用して、元の画像と拡張画像を並べて視覚化して比較してみましょう。
def visualize(original, augmented):
fig = plt.figure()
plt.subplot(1,2,1)
plt.title('Original image')
plt.imshow(original)
plt.subplot(1,2,2)
plt.title('Augmented image')
plt.imshow(augmented)
データ拡張
画像を反転する
tf.image.flip_left_right
を使用して、画像を垂直または水平に反転します。
flipped = tf.image.flip_left_right(image)
visualize(image, flipped)
画像をグレースケール
tf.image.rgb_to_grayscale
を使用して画像をグレースケールできます。
grayscaled = tf.image.rgb_to_grayscale(image)
visualize(image, tf.squeeze(grayscaled))
_ = plt.colorbar()
画像を飽和させる
彩度係数を指定して、 tf.image.adjust_saturation
で画像を彩度します。
saturated = tf.image.adjust_saturation(image, 3)
visualize(image, saturated)
画像の明るさを変更する
明るさ係数を指定して、 tf.image.adjust_brightness
で画像の明るさを変更します。
bright = tf.image.adjust_brightness(image, 0.4)
visualize(image, bright)
画像を中央でトリミング
tf.image.central_crop
を使用して、画像を中央から目的の画像部分まで切り抜きます。
cropped = tf.image.central_crop(image, central_fraction=0.5)
visualize(image, cropped)
画像を回転させる
tf.image.rot90
を使用して画像を90度回転します。
rotated = tf.image.rot90(image)
visualize(image, rotated)
ランダム変換
画像にランダムな変換を適用すると、データセットの一般化と拡張にさらに役立ちます。現在のtf.image
は、このようなランダムな画像操作(ops)を8つ提供しています。
-
tf.image.stateless_random_brightness
-
tf.image.stateless_random_contrast
-
tf.image.stateless_random_crop
-
tf.image.stateless_random_flip_left_right
-
tf.image.stateless_random_flip_up_down
-
tf.image.stateless_random_hue
-
tf.image.stateless_random_jpeg_quality
-
tf.image.stateless_random_saturation
これらのランダムイメージ操作は純粋に機能的です。出力は入力にのみ依存します。これにより、高性能で決定論的な入力パイプラインでの使用が簡単になります。各ステップでseed
値を入力する必要があります。同じseed
が与えられると、呼び出された回数に関係なく、同じ結果が返されます。
次のセクションでは、次のことを行います。
- ランダムな画像操作を使用して画像を変換する例を確認してください。
- ランダム変換をトレーニングデータセットに適用する方法を示します。
画像の明るさをランダムに変更します
明るさ係数とseed
を指定して、 tf.image.stateless_random_brightness
を使用してimage
の明るさをランダムに変更します。輝度係数は[-max_delta, max_delta)
の範囲でランダムに選択され、指定されたseed
に関連付けられます。
for i in range(3):
seed = (i, 0) # tuple of size (2,)
stateless_random_brightness = tf.image.stateless_random_brightness(
image, max_delta=0.95, seed=seed)
visualize(image, stateless_random_brightness)
画像のコントラストをランダムに変更します
コントラスト範囲とseed
を指定して、 tf.image.stateless_random_contrast
を使用してimage
のコントラストをランダムに変更します。コントラスト範囲は[lower, upper]
の間隔でランダムに選択され、指定されたseed
に関連付けられます。
for i in range(3):
seed = (i, 0) # tuple of size (2,)
stateless_random_contrast = tf.image.stateless_random_contrast(
image, lower=0.1, upper=0.9, seed=seed)
visualize(image, stateless_random_contrast)
画像をランダムにトリミング
ターゲットsize
とseed
を指定して、 tf.image.stateless_random_crop
を使用してimage
をランダムにトリミングします。 image
から切り抜かれる部分は、ランダムに選択されたオフセットにあり、指定されたseed
に関連付けられています。
for i in range(3):
seed = (i, 0) # tuple of size (2,)
stateless_random_crop = tf.image.stateless_random_crop(
image, size=[210, 300, 3], seed=seed)
visualize(image, stateless_random_crop)
データセットに拡張を適用する
前のセクションで変更された場合に備えて、最初に画像データセットを再度ダウンロードしてみましょう。
(train_datasets, val_ds, test_ds), metadata = tfds.load(
'tf_flowers',
split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],
with_info=True,
as_supervised=True,
)
次に、画像のサイズ変更と再スケーリングのためのユーティリティ関数を定義します。この関数は、データセット内の画像のサイズとスケールを統一するために使用されます。
def resize_and_rescale(image, label):
image = tf.cast(image, tf.float32)
image = tf.image.resize(image, [IMG_SIZE, IMG_SIZE])
image = (image / 255.0)
return image, label
また、画像にランダム変換を適用できるaugment
関数を定義しましょう。この関数は、次のステップでデータセットで使用されます。
def augment(image_label, seed):
image, label = image_label
image, label = resize_and_rescale(image, label)
image = tf.image.resize_with_crop_or_pad(image, IMG_SIZE + 6, IMG_SIZE + 6)
# Make a new seed.
new_seed = tf.random.experimental.stateless_split(seed, num=1)[0, :]
# Random crop back to the original size.
image = tf.image.stateless_random_crop(
image, size=[IMG_SIZE, IMG_SIZE, 3], seed=seed)
# Random brightness.
image = tf.image.stateless_random_brightness(
image, max_delta=0.5, seed=new_seed)
image = tf.clip_by_value(image, 0, 1)
return image, label
オプション1:tf.data.experimental.Counterを使用する
tf.data.experimental.Counter
オブジェクト( counter
と呼びましょう)を作成し、 Dataset.zip
を使用してデータセットを(counter, counter)
作成します。これにより、データセット内の各画像が、 counter
に基づいて(形状(2,)
の)一意の値に関連付けられ、後でランダム変換のseed
値としてaugment
関数に渡されるようになります。
# Create a `Counter` object and `Dataset.zip` it together with the training set.
counter = tf.data.experimental.Counter()
train_ds = tf.data.Dataset.zip((train_datasets, (counter, counter)))
augment
関数をトレーニングデータセットにマッピングします。
train_ds = (
train_ds
.shuffle(1000)
.map(augment, num_parallel_calls=AUTOTUNE)
.batch(batch_size)
.prefetch(AUTOTUNE)
)
val_ds = (
val_ds
.map(resize_and_rescale, num_parallel_calls=AUTOTUNE)
.batch(batch_size)
.prefetch(AUTOTUNE)
)
test_ds = (
test_ds
.map(resize_and_rescale, num_parallel_calls=AUTOTUNE)
.batch(batch_size)
.prefetch(AUTOTUNE)
)
プレースホルダー50オプション2:tf.random.Generatorを使用する
- 初期
seed
値を使用してtf.random.Generator
オブジェクトを作成します。同じジェネレータオブジェクトでmake_seeds
関数を呼び出すと、常に新しい一意のseed
値が返されます。 - 次のラッパー関数を定義します。1)
make_seeds
関数を呼び出します。 2)新しく生成されたseed
値をランダム変換のaugment
関数に渡します。
# Create a generator.
rng = tf.random.Generator.from_seed(123, alg='philox')
# Create a wrapper function for updating seeds.
def f(x, y):
seed = rng.make_seeds(2)[0]
image, label = augment((x, y), seed)
return image, label
ラッパー関数f
をトレーニングデータセットにマップし、 resize_and_rescale
関数を検証セットとテストセットにマップします。
train_ds = (
train_datasets
.shuffle(1000)
.map(f, num_parallel_calls=AUTOTUNE)
.batch(batch_size)
.prefetch(AUTOTUNE)
)
val_ds = (
val_ds
.map(resize_and_rescale, num_parallel_calls=AUTOTUNE)
.batch(batch_size)
.prefetch(AUTOTUNE)
)
test_ds = (
test_ds
.map(resize_and_rescale, num_parallel_calls=AUTOTUNE)
.batch(batch_size)
.prefetch(AUTOTUNE)
)
これらのデータセットを使用して、前に示したようにモデルをトレーニングできるようになりました。
次のステップ
このチュートリアルでは、Keras前処理レイヤーとtf.image
を使用したデータ拡張について説明しました。
- モデル内に前処理レイヤーを含める方法については、画像分類チュートリアルを参照してください。
- 基本的なテキスト分類チュートリアルに示されているように、前処理レイヤーがテキストの分類にどのように役立つかを学ぶことにも興味があるかもしれません。
- このガイドで
tf.data
の詳細を学ぶことができ、ここでパフォーマンスのために入力パイプラインを構成する方法を学ぶことができます。