MLコミュニティデーは11月9日です! TensorFlow、JAXからの更新のために私たちに参加し、より多くの詳細をご覧ください

画像セグメンテーション

TensorFlow.orgで表示 GoogleColabで実行 GitHubでソースを表示 ノートブックをダウンロード

このチュートリアルは、修飾用い、画像分割のタスクに焦点を当ててU-ネット

画像セグメンテーションとは何ですか?

画像分類タスクでは、ネットワークは各入力画像にラベル(またはクラス)を割り当てます。ただし、そのオブジェクトの形状、どのピクセルがどのオブジェクトに属しているかなどを知りたいとします。この場合、画像の各ピクセルにクラスを割り当てます。このタスクはセグメンテーションとして知られています。セグメンテーションモデルは、画像に関するより詳細な情報を返します。画像セグメンテーションは、いくつか例を挙げると、医療画像、自動運転車、衛星画像に多くの用途があります。

このチュートリアル用途オックスフォード、IIITペットデータセットParkhiら、2012 )。データセットは37のペットの品種の画像で構成され、品種ごとに200の画像があります(トレーニングとテストの分割でそれぞれ最大100)。各画像には、対応するラベルとピクセル単位のマスクが含まれています。マスクは、各ピクセルのクラスラベルです。各ピクセルには、次の3つのカテゴリのいずれかが与えられます。

  • クラス1:ペットに属するピクセル。
  • クラス2:ペットに隣接するピクセル。
  • クラス3:上記のいずれでもない/周囲のピクセル。
pip install git+https://github.com/tensorflow/examples.git
import tensorflow as tf
from tensorflow.keras.layers.experimental import preprocessing 

import tensorflow_datasets as tfds
from tensorflow_examples.models.pix2pix import pix2pix

from IPython.display import clear_output
import matplotlib.pyplot as plt

Oxford-IIITPetsデータセットをダウンロードする

データセットがあるTensorFlowデータセットから入手できます。セグメンテーションマスクはバージョン3以降に含まれています。

dataset, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True)

また、画像の色の値はに正規化されている[0,1]の範囲。最後に、前述のように、セグメンテーションマスクのピクセルには{1、2、3}のいずれかのラベルが付けられます。便宜上、セグメンテーションマスクから1を引くと、ラベルは{0、1、2}になります。

def normalize(input_image, input_mask):
  input_image = tf.cast(input_image, tf.float32) / 255.0
  input_mask -= 1
  return input_image, input_mask
def load_image(datapoint):
  input_image = tf.image.resize(datapoint['image'], (128, 128))
  input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128))

  input_image, input_mask = normalize(input_image, input_mask)

  return input_image, input_mask

データセットにはすでに必要なトレーニングとテストの分割が含まれているため、引き続き同じ分割を使用してください。

TRAIN_LENGTH = info.splits['train'].num_examples
BATCH_SIZE = 64
BUFFER_SIZE = 1000
STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE
train_images = dataset['train'].map(load_image, num_parallel_calls=tf.data.AUTOTUNE)
test_images = dataset['test'].map(load_image, num_parallel_calls=tf.data.AUTOTUNE)

次のクラスは、画像をランダムに反転することによって単純な拡張を実行します。行くイメージ増強もっと学ぶためにチュートリアル。

class Augment(tf.keras.layers.Layer):
  def __init__(self, seed=42):
    super().__init__()
    # both use the same seed, so they'll make the same random changes.
    self.augment_inputs = preprocessing.RandomFlip(mode="horizontal", seed=seed)
    self.augment_labels = preprocessing.RandomFlip(mode="horizontal", seed=seed)

  def call(self, inputs, labels):
    inputs = self.augment_inputs(inputs)
    labels = self.augment_labels(labels)
    return inputs, labels

入力をバッチ処理した後に拡張を適用して、入力パイプラインを構築します。

train_batches = (
    train_images
    .cache()
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .repeat()
    .map(Augment())
    .prefetch(buffer_size=tf.data.AUTOTUNE))

test_batches = test_images.batch(BATCH_SIZE)

データセットから画像の例とそれに対応するマスクを視覚化します。

def display(display_list):
  plt.figure(figsize=(15, 15))

  title = ['Input Image', 'True Mask', 'Predicted Mask']

  for i in range(len(display_list)):
    plt.subplot(1, len(display_list), i+1)
    plt.title(title[i])
    plt.imshow(tf.keras.preprocessing.image.array_to_img(display_list[i]))
    plt.axis('off')
  plt.show()
for images, masks in train_batches.take(2):
  sample_image, sample_mask = images[0], masks[0]
  display([sample_image, sample_mask])
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment

png

png

2021-10-12 01:25:33.460499: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

モデルを定義する

ここで使用されているモデルが変更されているU-Netの。 U-Netは、エンコーダー(ダウンサンプラー)とデコーダー(アップサンプラー)で構成されます。堅牢な機能を学習し、トレーニング可能なパラメーターの数を減らすために、事前にトレーニングされたモデルであるMobileNetV2-をエンコーダーとして使用します。デコーダのために、あなたはすでに実装されてアップサンプリングブロック、使用するpix2pixの例としては、レポTensorFlowの例を。 (チェックアウトpix2pix:条件付きGANを持つ画像間の変換。ノートブックでチュートリアルを)

上述したように、エンコーダを調製し、で使用する準備ができているpretrained MobileNetV2モデルなりtf.keras.applications 。エンコーダーは、モデルの中間層からの特定の出力で構成されます。エンコーダはトレーニングプロセス中にトレーニングされないことに注意してください。

base_model = tf.keras.applications.MobileNetV2(input_shape=[128, 128, 3], include_top=False)

# Use the activations of these layers
layer_names = [
    'block_1_expand_relu',   # 64x64
    'block_3_expand_relu',   # 32x32
    'block_6_expand_relu',   # 16x16
    'block_13_expand_relu',  # 8x8
    'block_16_project',      # 4x4
]
base_model_outputs = [base_model.get_layer(name).output for name in layer_names]

# Create the feature extraction model
down_stack = tf.keras.Model(inputs=base_model.input, outputs=base_model_outputs)

down_stack.trainable = False
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_128_no_top.h5
9412608/9406464 [==============================] - 0s 0us/step
9420800/9406464 [==============================] - 0s 0us/step

デコーダー/アップサンプラーは、TensorFlowの例で実装された一連のアップサンプルブロックです。

up_stack = [
    pix2pix.upsample(512, 3),  # 4x4 -> 8x8
    pix2pix.upsample(256, 3),  # 8x8 -> 16x16
    pix2pix.upsample(128, 3),  # 16x16 -> 32x32
    pix2pix.upsample(64, 3),   # 32x32 -> 64x64
]
def unet_model(output_channels:int):
  inputs = tf.keras.layers.Input(shape=[128, 128, 3])

  # Downsampling through the model
  skips = down_stack(inputs)
  x = skips[-1]
  skips = reversed(skips[:-1])

  # Upsampling and establishing the skip connections
  for up, skip in zip(up_stack, skips):
    x = up(x)
    concat = tf.keras.layers.Concatenate()
    x = concat([x, skip])

  # This is the last layer of the model
  last = tf.keras.layers.Conv2DTranspose(
      filters=output_channels, kernel_size=3, strides=2,
      padding='same')  #64x64 -> 128x128

  x = last(x)

  return tf.keras.Model(inputs=inputs, outputs=x)

最後の層上のフィルタの数は数に設定されていることに注意してくださいoutput_channels 。これは、クラスごとに1つの出力チャネルになります。

モデルをトレーニングする

あとは、モデルをコンパイルしてトレーニングするだけです。

これは、マルチクラス分類問題であるので、使用tf.keras.losses.CategoricalCrossentropyで損失関数をfrom_logitsへの引数セットTrueラベルがスカラ整数の代わりに、各クラスの各画素のスコアのベクトルであるため、。

推論を実行する場合、ピクセルに割り当てられたラベルは、最も高い値を持つチャネルです。これは何であるcreate_mask機能が行っています。

OUTPUT_CLASSES = 3

model = unet_model(output_channels=OUTPUT_CLASSES)
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

結果のモデルアーキテクチャをざっと見てみましょう。

tf.keras.utils.plot_model(model, show_shapes=True)

png

モデルを試して、トレーニングの前にモデルが何を予測するかを確認してください。

def create_mask(pred_mask):
  pred_mask = tf.argmax(pred_mask, axis=-1)
  pred_mask = pred_mask[..., tf.newaxis]
  return pred_mask[0]
def show_predictions(dataset=None, num=1):
  if dataset:
    for image, mask in dataset.take(num):
      pred_mask = model.predict(image)
      display([image[0], mask[0], create_mask(pred_mask)])
  else:
    display([sample_image, sample_mask,
             create_mask(model.predict(sample_image[tf.newaxis, ...]))])
show_predictions()

png

以下に定義されているコールバックは、トレーニング中にモデルがどのように改善されるかを観察するために使用されます。

class DisplayCallback(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    clear_output(wait=True)
    show_predictions()
    print ('\nSample Prediction after epoch {}\n'.format(epoch+1))
EPOCHS = 20
VAL_SUBSPLITS = 5
VALIDATION_STEPS = info.splits['test'].num_examples//BATCH_SIZE//VAL_SUBSPLITS

model_history = model.fit(train_batches, epochs=EPOCHS,
                          steps_per_epoch=STEPS_PER_EPOCH,
                          validation_steps=VALIDATION_STEPS,
                          validation_data=test_batches,
                          callbacks=[DisplayCallback()])

png

Sample Prediction after epoch 20
loss = model_history.history['loss']
val_loss = model_history.history['val_loss']

plt.figure()
plt.plot(model_history.epoch, loss, 'r', label='Training loss')
plt.plot(model_history.epoch, val_loss, 'bo', label='Validation loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss Value')
plt.ylim([0, 1])
plt.legend()
plt.show()

png

予測を行います

ここで、いくつかの予測を行います。時間を節約するために、エポックの数は少なく保たれましたが、より正確な結果を得るには、これを高く設定することができます。

show_predictions(test_batches, 3)

png

png

png

オプション:不均衡なクラスとクラスの重み

セマンティックセグメンテーションデータセットは非常に不均衡である可能性があります。つまり、特定のクラスのピクセルが他のクラスのピクセルよりも画像の内部に多く存在する可能性があります。セグメンテーションの問題はピクセルごとの分類の問題として扱うことができるため、これを説明するために損失関数を重み付けすることで不均衡の問題に対処できます。これは、この問題に対処するためのシンプルでエレガントな方法です。参照してください不均衡データに分類もっと学ぶためにチュートリアル。

するにはあいまいさを避けるためModel.fitサポートしていませんclass_weight 3+寸法の入力の引数を。

try:
  model_history = model.fit(train_batches, epochs=EPOCHS,
                            steps_per_epoch=STEPS_PER_EPOCH,
                            class_weight = {0:2.0, 1:2.0, 2:1.0})
  assert False
except Exception as e:
  print(f"Expected {type(e).__name__}: {e}")
Expected ValueError: `class_weight` not supported for 3+ dimensional targets.

したがって、この場合、自分で均等化を実装​​する必要があります。あなたは、サンプルの重みを使用して、これをやる:に加えて(data, label)のペア、 Model.fitまた受け入れ(data, label, sample_weight)トリプルを。

Model.fit伝播sample_weightも受け入れ損失および評価指標にsample_weight引数を。サンプルの重量は、削減ステップの前にサンプルの値で乗算されます。例えば:

label = [0,0]
prediction = [[-3., 0], [-3, 0]] 
sample_weight = [1, 10] 

loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True,
                                               reduction=tf.losses.Reduction.NONE)
loss(label, prediction, sample_weight).numpy()
array([ 3.0485873, 30.485874 ], dtype=float32)

だから、このチュートリアルのサンプルの重みを作るためにあなたが取る関数必要(data, label)のペアとリターンを(data, label, sample_weight)トリプルを。ここsample_weight各画素のクラスの重みを含む1チャネル画像です。

最も単純な実装では、インデックスとしてラベルを使用することですclass_weightリスト:

def add_sample_weights(image, label):
  # The weights for each class, with the constraint that:
  #     sum(class_weights) == 1.0
  class_weights = tf.constant([2.0, 2.0, 1.0])
  class_weights = class_weights/tf.reduce_sum(class_weights)

  # Create an image of `sample_weights` by using the label at each pixel as an 
  # index into the `class weights` .
  sample_weights = tf.gather(class_weights, indices=tf.cast(label, tf.int32))

  return image, label, sample_weights

結果のデータセット要素には、それぞれ3つの画像が含まれます。

train_batches.map(add_sample_weights).element_spec
(TensorSpec(shape=(None, 128, 128, 3), dtype=tf.float32, name=None),
 TensorSpec(shape=(None, 128, 128, 1), dtype=tf.float32, name=None),
 TensorSpec(shape=(None, 128, 128, 1), dtype=tf.float32, name=None))

これで、この重み付けされたデータセットでモデルをトレーニングできます。

weighted_model = unet_model(OUTPUT_CLASSES)
weighted_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])
weighted_model.fit(
    train_batches.map(add_sample_weights),
    epochs=1,
    steps_per_epoch=10)
10/10 [==============================] - 3s 42ms/step - loss: 0.3370 - accuracy: 0.5909
<keras.callbacks.History at 0x7f2898263390>

次のステップ

画像セグメンテーションとは何か、そしてそれがどのように機能するかを理解したので、このチュートリアルをさまざまな中間層出力、またはさまざまな事前トレーニング済みモデルで試すことができます。あなたも試してみることによって自分自身に挑戦することがCarvana Kaggleでホストされている画像のマスキングチャレンジを。

あなたも見てみたいことがありTensorflow物体検出APIを使用すると、独自のデータに再教育することができます別のモデルのために。 Pretrainedモデルがで利用可能であるTensorFlowハブ