Сегментация изображения

Посмотреть на TensorFlow.org Запустить в Google Colab Посмотреть исходный код на GitHub Скачать блокнот

В этом руководстве основное внимание уделяется задаче сегментации изображений с использованием модифицированного U-Net .

Что такое сегментация изображений?

До сих пор вы видели классификацию изображений, в которой задача сети - присвоить метку или класс входному изображению. Однако предположим, что вы хотите знать, где находится объект на изображении, форму этого объекта, какой пиксель какому объекту принадлежит и т. Д. В этом случае вам нужно сегментировать изображение, т. Е. Каждый пиксель изображения является дали ярлык. Таким образом, задача сегментации изображения - обучить нейронную сеть выводить пиксельную маску изображения. Это помогает понять изображение на гораздо более низком уровне, то есть на уровне пикселей. Сегментация изображений имеет множество применений в медицинской визуализации, автономных автомобилях и спутниковой съемке, и это лишь некоторые из них.

Набор данных, который будет использоваться в этом руководстве, - это набор данных Oxford-IIIT Pet Dataset , созданный Parkhi et al . Набор данных состоит из изображений, соответствующих им меток и пиксельных масок. Маски - это в основном метки для каждого пикселя. Каждому пикселю дается одна из трех категорий:

  • Класс 1: пиксель, принадлежащий питомцу.
  • Класс 2: Пиксель, граничащий с домашним животным.
  • Класс 3: Ни один из вышеперечисленных / Окружающий пиксель.
pip install git+https://github.com/tensorflow/examples.git
import tensorflow as tf
from tensorflow_examples.models.pix2pix import pix2pix

import tensorflow_datasets as tfds

from IPython.display import clear_output
import matplotlib.pyplot as plt

Загрузите набор данных Oxford-IIIT Pets

Набор данных уже включен в наборы данных TensorFlow, все, что нужно сделать, это загрузить его. Маски сегментации включены в версию 3+.

dataset, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True)

Следующий код выполняет простое увеличение переворачивания изображения. Кроме того, изображение нормализуется до [0,1]. Наконец, как упоминалось выше, пиксели в маске сегментации помечены как {1, 2, 3}. Для удобства вычтем 1 из маски сегментации, в результате чего получатся следующие метки: {0, 1, 2}.

def normalize(input_image, input_mask):
  input_image = tf.cast(input_image, tf.float32) / 255.0
  input_mask -= 1
  return input_image, input_mask
@tf.function
def load_image_train(datapoint):
  input_image = tf.image.resize(datapoint['image'], (128, 128))
  input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128))

  if tf.random.uniform(()) > 0.5:
    input_image = tf.image.flip_left_right(input_image)
    input_mask = tf.image.flip_left_right(input_mask)

  input_image, input_mask = normalize(input_image, input_mask)

  return input_image, input_mask
def load_image_test(datapoint):
  input_image = tf.image.resize(datapoint['image'], (128, 128))
  input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128))

  input_image, input_mask = normalize(input_image, input_mask)

  return input_image, input_mask

Набор данных уже содержит необходимые разбиения теста и обучения, поэтому давайте продолжим использовать то же разбиение.

TRAIN_LENGTH = info.splits['train'].num_examples
BATCH_SIZE = 64
BUFFER_SIZE = 1000
STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE
train = dataset['train'].map(load_image_train, num_parallel_calls=tf.data.AUTOTUNE)
test = dataset['test'].map(load_image_test)
train_dataset = train.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE).repeat()
train_dataset = train_dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
test_dataset = test.batch(BATCH_SIZE)

Давайте посмотрим на пример изображения и соответствующую маску из набора данных.

def display(display_list):
  plt.figure(figsize=(15, 15))

  title = ['Input Image', 'True Mask', 'Predicted Mask']

  for i in range(len(display_list)):
    plt.subplot(1, len(display_list), i+1)
    plt.title(title[i])
    plt.imshow(tf.keras.preprocessing.image.array_to_img(display_list[i]))
    plt.axis('off')
  plt.show()
for image, mask in train.take(1):
  sample_image, sample_mask = image, mask
display([sample_image, sample_mask])

PNG

Определите модель

Используемая здесь модель представляет собой модифицированную U-Net. U-Net состоит из кодировщика (субдискретизатора) и декодера (повышающего дискретизатора). Чтобы изучить надежные функции и уменьшить количество обучаемых параметров, предварительно обученную модель можно использовать в качестве кодировщика. Таким образом, кодировщиком для этой задачи будет предварительно обученная модель MobileNetV2, промежуточные выходы которой будут использоваться, а декодером будет блок повышающей дискретизации, уже реализованный в примерах TensorFlow в учебнике Pix2pix .

Причина вывода трех каналов заключается в том, что для каждого пикселя есть три возможных метки. Думайте об этом как о мультиклассификации, когда каждый пиксель классифицируется по трем классам.

OUTPUT_CHANNELS = 3

Как уже упоминалось, кодировщиком будет предварительно обученная модель MobileNetV2, подготовленная и готовая к использованию в tf.keras.applications . Кодировщик состоит из определенных выходных данных промежуточных слоев модели. Учтите, что кодировщик не будет обучаться в процессе обучения.

base_model = tf.keras.applications.MobileNetV2(input_shape=[128, 128, 3], include_top=False)

# Use the activations of these layers
layer_names = [
    'block_1_expand_relu',   # 64x64
    'block_3_expand_relu',   # 32x32
    'block_6_expand_relu',   # 16x16
    'block_13_expand_relu',  # 8x8
    'block_16_project',      # 4x4
]
base_model_outputs = [base_model.get_layer(name).output for name in layer_names]

# Create the feature extraction model
down_stack = tf.keras.Model(inputs=base_model.input, outputs=base_model_outputs)

down_stack.trainable = False
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_128_no_top.h5
9412608/9406464 [==============================] - 0s 0us/step

Декодер / повышающий дискретизатор - это просто серия блоков повышающей дискретизации, реализованная в примерах TensorFlow.

up_stack = [
    pix2pix.upsample(512, 3),  # 4x4 -> 8x8
    pix2pix.upsample(256, 3),  # 8x8 -> 16x16
    pix2pix.upsample(128, 3),  # 16x16 -> 32x32
    pix2pix.upsample(64, 3),   # 32x32 -> 64x64
]
def unet_model(output_channels):
  inputs = tf.keras.layers.Input(shape=[128, 128, 3])

  # Downsampling through the model
  skips = down_stack(inputs)
  x = skips[-1]
  skips = reversed(skips[:-1])

  # Upsampling and establishing the skip connections
  for up, skip in zip(up_stack, skips):
    x = up(x)
    concat = tf.keras.layers.Concatenate()
    x = concat([x, skip])

  # This is the last layer of the model
  last = tf.keras.layers.Conv2DTranspose(
      output_channels, 3, strides=2,
      padding='same')  #64x64 -> 128x128

  x = last(x)

  return tf.keras.Model(inputs=inputs, outputs=x)

Обучите модель

Теперь осталось только скомпилировать и обучить модель. Используемые здесь потери - это losses.SparseCategoricalCrossentropy(from_logits=True) . Причина использования этой функции потерь заключается в том, что сеть пытается присвоить каждому пикселю метку, точно так же, как прогнозирование нескольких классов. В истинной маске сегментации каждый пиксель имеет либо {0,1,2}. Сеть здесь выводит три канала. По сути, каждый канал пытается научиться предсказывать класс и losses.SparseCategoricalCrossentropy(from_logits=True) - это рекомендуемые потери для такого сценария. Используя выходной сигнал сети, метка, присвоенная пикселю, представляет собой канал с наивысшим значением. Это то, что делает функция create_mask.

model = unet_model(OUTPUT_CHANNELS)
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

Взглянем на получившуюся архитектуру модели:

tf.keras.utils.plot_model(model, show_shapes=True)

PNG

Давайте попробуем модель, чтобы увидеть, что она предсказывает, перед тренировкой.

def create_mask(pred_mask):
  pred_mask = tf.argmax(pred_mask, axis=-1)
  pred_mask = pred_mask[..., tf.newaxis]
  return pred_mask[0]
def show_predictions(dataset=None, num=1):
  if dataset:
    for image, mask in dataset.take(num):
      pred_mask = model.predict(image)
      display([image[0], mask[0], create_mask(pred_mask)])
  else:
    display([sample_image, sample_mask,
             create_mask(model.predict(sample_image[tf.newaxis, ...]))])
show_predictions()

PNG

Давайте посмотрим, как модель улучшается в процессе обучения. Для выполнения этой задачи ниже определена функция обратного вызова.

class DisplayCallback(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    clear_output(wait=True)
    show_predictions()
    print ('\nSample Prediction after epoch {}\n'.format(epoch+1))
EPOCHS = 20
VAL_SUBSPLITS = 5
VALIDATION_STEPS = info.splits['test'].num_examples//BATCH_SIZE//VAL_SUBSPLITS

model_history = model.fit(train_dataset, epochs=EPOCHS,
                          steps_per_epoch=STEPS_PER_EPOCH,
                          validation_steps=VALIDATION_STEPS,
                          validation_data=test_dataset,
                          callbacks=[DisplayCallback()])

PNG

Sample Prediction after epoch 20
loss = model_history.history['loss']
val_loss = model_history.history['val_loss']

plt.figure()
plt.plot(model_history.epoch, loss, 'r', label='Training loss')
plt.plot(model_history.epoch, val_loss, 'bo', label='Validation loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss Value')
plt.ylim([0, 1])
plt.legend()
plt.show()

PNG

Делать предсказания

Сделаем некоторые прогнозы. В целях экономии времени количество эпох оставлено небольшим, но вы можете установить его больше, чтобы получить более точные результаты.

show_predictions(test_dataset, 3)

PNG

PNG

PNG

Необязательно: несбалансированные классы и веса классов

Наборы данных семантической сегментации могут быть сильно несбалансированными, что означает, что пиксели определенного класса могут присутствовать внутри изображений больше, чем пиксели других классов. Поскольку проблемы сегментации можно рассматривать как проблемы попиксельной классификации, мы можем решить проблему дисбаланса, взвешивая функцию потерь, чтобы учесть это. Это простой и элегантный способ справиться с этой проблемой. См. Учебник по несбалансированным классам .

Чтобы избежать двусмысленности , Model.fit не поддерживает аргумент class_weight для входов с 3+ размерами.

try:
  model_history = model.fit(train_dataset, epochs=EPOCHS,
                            steps_per_epoch=STEPS_PER_EPOCH,
                            class_weight = {0:2.0, 1:2.0, 2:1.0})
  assert False
except Exception as e:
  print(f"{type(e).__name__}: {e}")
ValueError: `class_weight` not supported for 3+ dimensional targets.

Так что в этом случае вам нужно выполнить взвешивание самостоятельно. Вы сделаете это, используя образцы весов: в дополнение к парам (data, label) , Model.fit также принимает Model.fit (data, label, sample_weight) .

Model.fit распространяет sample_weight на потери и метрики, которые также принимают аргумент sample_weight . Вес образца умножается на значение образца перед этапом уменьшения. Например:

label = [0,0]
prediction = [[-3., 0], [-3, 0]] 
sample_weight = [1, 10] 

loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True,
                                               reduction=tf.losses.Reduction.NONE)
loss(label, prediction, sample_weight).numpy()
array([ 3.0485873, 30.485874 ], dtype=float32)

Итак, чтобы сделать образцы весов для этого руководства, вам понадобится функция, которая принимает пару (data, label) и возвращает тройку (data, label, sample_weight) . Где sample_weight - это одноканальное изображение, содержащее вес класса для каждого пикселя.

Простейшая возможная реализация - использовать метку в качестве индекса в списке class_weight :

def add_sample_weights(image, label):
  # The weights for each class, with the constraint that:
  #     sum(class_weights) == 1.0
  class_weights = tf.constant([2.0, 2.0, 1.0])
  class_weights = class_weights/tf.reduce_sum(class_weights)

  # Create an image of `sample_weights` by using the label at each pixel as an 
  # index into the `class weights` .
  sample_weights = tf.gather(class_weights, indices=tf.cast(label, tf.int32))

  return image, label, sample_weights

В результате элементы набора данных содержат по 3 изображения каждый:

train_dataset.map(add_sample_weights).element_spec
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/array_ops.py:5049: calling gather (from tensorflow.python.ops.array_ops) with validate_indices is deprecated and will be removed in a future version.
Instructions for updating:
The `validate_indices` argument has no effect. Indices are always validated on CPU and never validated on GPU.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/array_ops.py:5049: calling gather (from tensorflow.python.ops.array_ops) with validate_indices is deprecated and will be removed in a future version.
Instructions for updating:
The `validate_indices` argument has no effect. Indices are always validated on CPU and never validated on GPU.
(TensorSpec(shape=(None, 128, 128, 3), dtype=tf.float32, name=None),
 TensorSpec(shape=(None, 128, 128, 1), dtype=tf.float32, name=None),
 TensorSpec(shape=(None, 128, 128, 1), dtype=tf.float32, name=None))

Теперь вы можете обучить модель на этом взвешенном наборе данных:

weighted_model = unet_model(OUTPUT_CHANNELS)
weighted_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])
weighted_model.fit(
    train_dataset.map(add_sample_weights),
    epochs=1,
    steps_per_epoch=10)
10/10 [==============================] - 3s 41ms/step - loss: 0.2712 - accuracy: 0.6973
<tensorflow.python.keras.callbacks.History at 0x7fedb45d17d0>

Следующие шаги

Теперь, когда вы понимаете, что такое сегментация изображения и как она работает, вы можете попробовать это руководство с различными выходными данными промежуточного слоя или даже с другой предварительно обученной моделью. Вы также можете испытать себя, попробовав задачу по маскировке изображений Carvana, размещенную на Kaggle.

Вы также можете увидеть API обнаружения объектов Tensorflow для другой модели, которую вы можете переобучить на своих собственных данных.