![]() |
![]() |
![]() |
![]() |
このチュートリアルでは、修正した U-Net を使用した画像セグメンテーションのタスクに焦点を当てます。
画像セグメンテーションとは
画像分類タスクでは、ネットワークが各入力画像にラベル(またはクラス)を割り当てますが、そのオブジェクトの形状やどのピクセルがどのオブジェクトに属しているかなどを知りたい場合はどうすればよいでしょうか。この場合、画像のピクセルごとにクラスを割り当てようと考えるでしょう。このタスクはセグメンテーションとして知られています。セグメンテーションモデルは、画像に関してはるかに詳細な情報を返します。画像セグメンテーションには、医用イメージング、自動走行車、衛星撮像など、数多くの用途があります。
このチュートリアルでは Oxford-IIIT Pet Dataset(Parkhi et al)を使用します。データセットには、37 種のペット品種と、品種当たり 200 枚の画像(train と test split で約 100 枚ずつ)が含まれます。それぞれの画像には対応するラベルとピクセル方向のマスクが含まれます。マスクは各ピクセルのクラスラベルです。各ピクセルには、次のいずれかのカテゴリが指定されます。
- クラス 1 : ペットに属するピクセル。
- クラス 2 : ペットと境界のピクセル。
- クラス 3: 上記のいずれにも該当しない、または周囲のピクセル。
pip install git+https://github.com/tensorflow/examples.git
import tensorflow as tf
import tensorflow_datasets as tfds
2022-08-09 02:10:16.698809: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered 2022-08-09 02:10:17.420226: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvrtc.so.11.1: cannot open shared object file: No such file or directory 2022-08-09 02:10:17.420482: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvrtc.so.11.1: cannot open shared object file: No such file or directory 2022-08-09 02:10:17.420495: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
from tensorflow_examples.models.pix2pix import pix2pix
from IPython.display import clear_output
import matplotlib.pyplot as plt
Oxford-IIIT ペットデータセットをダウンロードする
データセットは TensorFlow Datasets から入手できます。セグメンテーションマスクはバージョン 3 以上に含まれています。
dataset, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True)
また、画像の色値は [0,1]
の範囲に正規化されています。最後に、上記で説明したとおり、セグメンテーションのマスクは {1, 2, 3} のいずれかでラベル付けされています。便宜上、セグメンテーションマスクから 1 を減算して、ラベルを {0, 1, 2} としましょう。
def normalize(input_image, input_mask):
input_image = tf.cast(input_image, tf.float32) / 255.0
input_mask -= 1
return input_image, input_mask
def load_image(datapoint):
input_image = tf.image.resize(datapoint['image'], (128, 128))
input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128))
input_image, input_mask = normalize(input_image, input_mask)
return input_image, input_mask
データセットにはすでに必要となる training と test split が含まれているため、そのまま同じ split を使用します。
TRAIN_LENGTH = info.splits['train'].num_examples
BATCH_SIZE = 64
BUFFER_SIZE = 1000
STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE
train_images = dataset['train'].map(load_image, num_parallel_calls=tf.data.AUTOTUNE)
test_images = dataset['test'].map(load_image, num_parallel_calls=tf.data.AUTOTUNE)
次のクラスは、画像をランダムにフリップする単純な拡張を実行します。詳細は、画像のデータ拡張チュートリアルをご覧ください。
class Augment(tf.keras.layers.Layer):
def __init__(self, seed=42):
super().__init__()
# both use the same seed, so they'll make the same random changes.
self.augment_inputs = tf.keras.layers.RandomFlip(mode="horizontal", seed=seed)
self.augment_labels = tf.keras.layers.RandomFlip(mode="horizontal", seed=seed)
def call(self, inputs, labels):
inputs = self.augment_inputs(inputs)
labels = self.augment_labels(labels)
return inputs, labels
入力パイプラインをビルドし、入力をバッチ処理した後に拡張を適用します。
train_batches = (
train_images
.cache()
.shuffle(BUFFER_SIZE)
.batch(BATCH_SIZE)
.repeat()
.map(Augment())
.prefetch(buffer_size=tf.data.AUTOTUNE))
test_batches = test_images.batch(BATCH_SIZE)
画像サンプルと対応するデータセットのマスクを可視化しましょう。
def display(display_list):
plt.figure(figsize=(15, 15))
title = ['Input Image', 'True Mask', 'Predicted Mask']
for i in range(len(display_list)):
plt.subplot(1, len(display_list), i+1)
plt.title(title[i])
plt.imshow(tf.keras.utils.array_to_img(display_list[i]))
plt.axis('off')
plt.show()
for images, masks in train_batches.take(2):
sample_image, sample_mask = images[0], masks[0]
display([sample_image, sample_mask])
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9 Corrupt JPEG data: premature end of data segment
2022-08-09 02:10:25.873411: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
モデルを定義する
ここで使用されるモデルは修正された U-Net です。U-Net には、エンコーダ(ダウンサンプラー)とデコーダ(アップサンプラー)が含まれます。強力な特徴量を理解してトレーニング可能なパラメータ数を減らすため、MobileNetV2 というトレーニング済みモデルをエンコーダとして使用します。デコーダについてはアップサンプルブロックを使用しますが、これは TensorFlow Examples リポジトリの pix2pix の例に実装済みです。(ノートブックの pix2pix: 条件付き GAN による画像から画像への変換チュートリアルをご覧ください。)
前述したように、エンコーダは事前トレーニング済み MobileNetV2 モデルであり、tf.keras.applications で即座に使用できるように準備されています。エンコーダはモデル内の中間レイヤーからの特定の出力で構成されています。トレーニングプロセス中にエンコーダはトレーニングされないので注意してください。
base_model = tf.keras.applications.MobileNetV2(input_shape=[128, 128, 3], include_top=False)
# Use the activations of these layers
layer_names = [
'block_1_expand_relu', # 64x64
'block_3_expand_relu', # 32x32
'block_6_expand_relu', # 16x16
'block_13_expand_relu', # 8x8
'block_16_project', # 4x4
]
base_model_outputs = [base_model.get_layer(name).output for name in layer_names]
# Create the feature extraction model
down_stack = tf.keras.Model(inputs=base_model.input, outputs=base_model_outputs)
down_stack.trainable = False
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_128_no_top.h5 9406464/9406464 [==============================] - 0s 0us/step
デコーダおよびアップサンプラは、単に TensorFlow の 例に実装されている一連のアップサンプラブロックに過ぎません。
up_stack = [
pix2pix.upsample(512, 3), # 4x4 -> 8x8
pix2pix.upsample(256, 3), # 8x8 -> 16x16
pix2pix.upsample(128, 3), # 16x16 -> 32x32
pix2pix.upsample(64, 3), # 32x32 -> 64x64
]
def unet_model(output_channels:int):
inputs = tf.keras.layers.Input(shape=[128, 128, 3])
# Downsampling through the model
skips = down_stack(inputs)
x = skips[-1]
skips = reversed(skips[:-1])
# Upsampling and establishing the skip connections
for up, skip in zip(up_stack, skips):
x = up(x)
concat = tf.keras.layers.Concatenate()
x = concat([x, skip])
# This is the last layer of the model
last = tf.keras.layers.Conv2DTranspose(
filters=output_channels, kernel_size=3, strides=2,
padding='same') #64x64 -> 128x128
x = last(x)
return tf.keras.Model(inputs=inputs, outputs=x)
最後のレイヤーのフィルタ数は output_channels
の数に設定されています。これはクラス当たり 1 つの出力チャンネルとなります。
モデルをトレーニングする
では、後は、モデルををコンパイルしてトレーニングするだけです。
これはマルチクラスの分類問題であり、ラベルがクラスごとのピクセルのスコアのベクトルではなくスカラー整数であるため、tf.keras.losses.CategoricalCrossentropy
損失関数を使用して、from_logits
を True
に設定します。
推論を実行すると、ピクセルに割り当てられたラベルが最も高い値を持つチャンネルです。これは、create_mask
関数の作用です。
OUTPUT_CLASSES = 3
model = unet_model(output_channels=OUTPUT_CLASSES)
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
結果のモデルアーキテクチャを簡単に見てみましょう。
tf.keras.utils.plot_model(model, show_shapes=True)
You must install pydot (`pip install pydot`) and install graphviz (see instructions at https://graphviz.gitlab.io/download/) for plot_model/model_to_dot to work.
トレーニングする前に、モデルが何を予測するかを試してみましょう。
def create_mask(pred_mask):
pred_mask = tf.argmax(pred_mask, axis=-1)
pred_mask = pred_mask[..., tf.newaxis]
return pred_mask[0]
def show_predictions(dataset=None, num=1):
if dataset:
for image, mask in dataset.take(num):
pred_mask = model.predict(image)
display([image[0], mask[0], create_mask(pred_mask)])
else:
display([sample_image, sample_mask,
create_mask(model.predict(sample_image[tf.newaxis, ...]))])
show_predictions()
1/1 [==============================] - 2s 2s/step
以下に定義されるコールバックは、トレーニング中にモデルがどのように改善するかを観測するために使用されます。
class DisplayCallback(tf.keras.callbacks.Callback):
def on_epoch_end(self, epoch, logs=None):
clear_output(wait=True)
show_predictions()
print ('\nSample Prediction after epoch {}\n'.format(epoch+1))
EPOCHS = 20
VAL_SUBSPLITS = 5
VALIDATION_STEPS = info.splits['test'].num_examples//BATCH_SIZE//VAL_SUBSPLITS
model_history = model.fit(train_batches, epochs=EPOCHS,
steps_per_epoch=STEPS_PER_EPOCH,
validation_steps=VALIDATION_STEPS,
validation_data=test_batches,
callbacks=[DisplayCallback()])
1/1 [==============================] - 0s 28ms/step
Sample Prediction after epoch 20 57/57 [==============================] - 5s 96ms/step - loss: 0.1837 - accuracy: 0.9186 - val_loss: 0.2724 - val_accuracy: 0.8961
loss = model_history.history['loss']
val_loss = model_history.history['val_loss']
plt.figure()
plt.plot(model_history.epoch, loss, 'r', label='Training loss')
plt.plot(model_history.epoch, val_loss, 'bo', label='Validation loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss Value')
plt.ylim([0, 1])
plt.legend()
plt.show()
予測する
いくつか予測を行ってみましょう。時間の節約重視の場合はエポック数を少なくしますが、高精度の結果重視の場合はエポック数を増やして設定します。
show_predictions(test_batches, 3)
2/2 [==============================] - 0s 19ms/step
2/2 [==============================] - 0s 19ms/step
2/2 [==============================] - 0s 16ms/step
オプション: 不均衡なクラスとクラスの重み
セマンティックセグメンテーションデータセットは非常に不均衡であり、特定のクラスピクセルが他のクラスに比べて画像の内側寄りに存在する可能性があります。セグメンテーションの問題はピクセル単位の分類問題として対応することができるため、不均衡性を考慮して損失関数を重み付けすることで、不均衡の問題に対処することができます。単純かつエレガントにこの問題に取り組むことができます。詳細は、不均衡なデータでの分類のチュートリアルをご覧ください。
あいまいさを回避するために、Model.fit
は 3 次元以上の入力の class_weight
引数をサポートしていません。
try:
model_history = model.fit(train_batches, epochs=EPOCHS,
steps_per_epoch=STEPS_PER_EPOCH,
class_weight = {0:2.0, 1:2.0, 2:1.0})
assert False
except Exception as e:
print(f"Expected {type(e).__name__}: {e}")
Expected ValueError: `class_weight` not supported for 3+ dimensional targets.
そのため、この場合、自分で重み付けを実装する必要があります。これにはサンプルの重み付けを使用します。Model.fit
は (data, label)
ペアのほかに (data, label, sample_weight)
トリプレットも受け入れます。
Model.fit
は sample_weight
を損失とメトリクスに伝搬しますが、sample_weight
引数も受け入れます。サンプル重みは、縮小ステップの前にサンプル値で乗算されます。次に例を示します。
label = [0,0]
prediction = [[-3., 0], [-3, 0]]
sample_weight = [1, 10]
loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True,
reduction=tf.losses.Reduction.NONE)
loss(label, prediction, sample_weight).numpy()
array([ 3.0485873, 30.485874 ], dtype=float32)
つまり、このチュートリアルのサンプル重みを作るには、(data, label)
ペアを取って (data, label, sample_weight)
トリプルを返す関数が必要となります。sample_weight
は各ピクセルのクラス重みを含む 1-channel の画像です。
実装を可能な限り単純にするために、ラベルをclass_weight
リストのインデックスとして使用します。
def add_sample_weights(image, label):
# The weights for each class, with the constraint that:
# sum(class_weights) == 1.0
class_weights = tf.constant([2.0, 2.0, 1.0])
class_weights = class_weights/tf.reduce_sum(class_weights)
# Create an image of `sample_weights` by using the label at each pixel as an
# index into the `class weights` .
sample_weights = tf.gather(class_weights, indices=tf.cast(label, tf.int32))
return image, label, sample_weights
この結果、データセットの各要素には、3 つの画像が含まれます。
train_batches.map(add_sample_weights).element_spec
(TensorSpec(shape=(None, 128, 128, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 128, 128, 1), dtype=tf.float32, name=None), TensorSpec(shape=(None, 128, 128, 1), dtype=tf.float32, name=None))
これで、この重み付けが付けられたデータセットでモデルをトレーニングできるようになりました。
weighted_model = unet_model(OUTPUT_CLASSES)
weighted_model.compile(
optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
weighted_model.fit(
train_batches.map(add_sample_weights),
epochs=1,
steps_per_epoch=10)
10/10 [==============================] - 4s 83ms/step - loss: 0.2885 - accuracy: 0.6721 <keras.callbacks.History at 0x7f6e08d4a0a0>
次のステップ
これで画像セグメンテーションとは何か、それがどのように機能するかについての知識が得られたはずです。このチュートリアルは、異なる中間レイヤー出力や、異なる事前トレーニング済みモデルでも試すことができます。また、Kaggle がホストしている Carvana 画像マスキングチャレンジに挑戦してみることもお勧めです。
Tensorflow Object Detection API を参照して、独自のデータで再トレーニング可能な別のモデルを確認するのも良いでしょう。トレーニング済みのモデルは、TensorFlow Hub にあります。