Сохраните дату! Google I / O возвращается 18-20 мая Зарегистрируйтесь сейчас
Эта страница переведена с помощью Cloud Translation API.
Switch to English

Сегментация изображения

Посмотреть на TensorFlow.org Запускаем в Google Colab Посмотреть исходный код на GitHub Скачать блокнот

В этом руководстве основное внимание уделяется задаче сегментации изображений с использованием модифицированного U-Net .

Что такое сегментация изображений?

До сих пор вы видели классификацию изображений, в которой задача сети - присвоить метку или класс входному изображению. Однако предположим, что вы хотите знать, где находится объект на изображении, форму этого объекта, какой пиксель принадлежит какому объекту и т. Д. В этом случае вам нужно сегментировать изображение, т. Е. Каждый пиксель изображения дали ярлык. Таким образом, задача сегментации изображения - обучить нейронную сеть выводить пиксельную маску изображения. Это помогает понять изображение на гораздо более низком уровне, то есть на уровне пикселей. Сегментация изображений имеет множество применений в медицинской визуализации, беспилотных автомобилях и спутниковой съемке, и это лишь некоторые из них.

Набор данных, который будет использоваться в этом руководстве, - это набор данных Oxford-IIIT Pet Dataset , созданный Parkhi et al . Набор данных состоит из изображений, соответствующих им меток и пиксельных масок. Маски - это в основном метки для каждого пикселя. Каждому пикселю дается одна из трех категорий:

  • Класс 1: пиксель, принадлежащий питомцу.
  • Класс 2: Пиксель, граничащий с домашним животным.
  • Класс 3: Ничего из вышеперечисленного / Окружающий пиксель.
pip install -q git+https://github.com/tensorflow/examples.git
import tensorflow as tf
from tensorflow_examples.models.pix2pix import pix2pix

import tensorflow_datasets as tfds

from IPython.display import clear_output
import matplotlib.pyplot as plt

Загрузите набор данных Oxford-IIIT Pets

Набор данных уже включен в наборы данных TensorFlow, все, что нужно сделать, это загрузить его. Маски сегментации включены в версию 3+.

dataset, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True)

Следующий код выполняет простое увеличение переворачивания изображения. Кроме того, изображение нормализуется до [0,1]. Наконец, как упоминалось выше, пиксели в маске сегментации помечены как {1, 2, 3}. Для удобства вычтем 1 из маски сегментации, в результате чего получатся следующие метки: {0, 1, 2}.

def normalize(input_image, input_mask):
  input_image = tf.cast(input_image, tf.float32) / 255.0
  input_mask -= 1
  return input_image, input_mask
@tf.function
def load_image_train(datapoint):
  input_image = tf.image.resize(datapoint['image'], (128, 128))
  input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128))

  if tf.random.uniform(()) > 0.5:
    input_image = tf.image.flip_left_right(input_image)
    input_mask = tf.image.flip_left_right(input_mask)

  input_image, input_mask = normalize(input_image, input_mask)

  return input_image, input_mask
def load_image_test(datapoint):
  input_image = tf.image.resize(datapoint['image'], (128, 128))
  input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128))

  input_image, input_mask = normalize(input_image, input_mask)

  return input_image, input_mask

Набор данных уже содержит необходимые разбиения теста и обучения, поэтому давайте продолжим использовать то же разбиение.

TRAIN_LENGTH = info.splits['train'].num_examples
BATCH_SIZE = 64
BUFFER_SIZE = 1000
STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE
train = dataset['train'].map(load_image_train, num_parallel_calls=tf.data.AUTOTUNE)
test = dataset['test'].map(load_image_test)
train_dataset = train.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE).repeat()
train_dataset = train_dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
test_dataset = test.batch(BATCH_SIZE)

Давайте посмотрим на пример изображения и соответствующую маску из набора данных.

def display(display_list):
  plt.figure(figsize=(15, 15))

  title = ['Input Image', 'True Mask', 'Predicted Mask']

  for i in range(len(display_list)):
    plt.subplot(1, len(display_list), i+1)
    plt.title(title[i])
    plt.imshow(tf.keras.preprocessing.image.array_to_img(display_list[i]))
    plt.axis('off')
  plt.show()
for image, mask in train.take(1):
  sample_image, sample_mask = image, mask
display([sample_image, sample_mask])

PNG

Определите модель

Используемая здесь модель представляет собой модифицированную U-Net. U-Net состоит из кодировщика (понижающий дискретизатор) и декодера (повышающий дискретизатор). Чтобы изучить надежные функции и уменьшить количество обучаемых параметров, предварительно обученную модель можно использовать в качестве кодировщика. Таким образом, кодировщиком для этой задачи будет предварительно обученная модель MobileNetV2, промежуточные выходы которой будут использоваться, а декодером будет блок повышающей дискретизации, уже реализованный в примерах TensorFlow в учебнике Pix2pix .

Причина вывода трех каналов заключается в том, что для каждого пикселя есть три возможных метки. Думайте об этом как о мультиклассификации, когда каждый пиксель классифицируется на три класса.

OUTPUT_CHANNELS = 3

Как уже упоминалось, кодировщиком будет предварительно обученная модель MobileNetV2, подготовленная и готовая к использованию в tf.keras.applications . Кодировщик состоит из определенных выходных данных промежуточных слоев модели. Учтите, что кодировщик не будет обучаться в процессе обучения.

base_model = tf.keras.applications.MobileNetV2(input_shape=[128, 128, 3], include_top=False)

# Use the activations of these layers
layer_names = [
    'block_1_expand_relu',   # 64x64
    'block_3_expand_relu',   # 32x32
    'block_6_expand_relu',   # 16x16
    'block_13_expand_relu',  # 8x8
    'block_16_project',      # 4x4
]
base_model_outputs = [base_model.get_layer(name).output for name in layer_names]

# Create the feature extraction model
down_stack = tf.keras.Model(inputs=base_model.input, outputs=base_model_outputs)

down_stack.trainable = False
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_128_no_top.h5
9412608/9406464 [==============================] - 0s 0us/step

Декодер / повышающий дискретизатор - это просто серия блоков повышающей дискретизации, реализованная в примерах TensorFlow.

up_stack = [
    pix2pix.upsample(512, 3),  # 4x4 -> 8x8
    pix2pix.upsample(256, 3),  # 8x8 -> 16x16
    pix2pix.upsample(128, 3),  # 16x16 -> 32x32
    pix2pix.upsample(64, 3),   # 32x32 -> 64x64
]
def unet_model(output_channels):
  inputs = tf.keras.layers.Input(shape=[128, 128, 3])

  # Downsampling through the model
  skips = down_stack(inputs)
  x = skips[-1]
  skips = reversed(skips[:-1])

  # Upsampling and establishing the skip connections
  for up, skip in zip(up_stack, skips):
    x = up(x)
    concat = tf.keras.layers.Concatenate()
    x = concat([x, skip])

  # This is the last layer of the model
  last = tf.keras.layers.Conv2DTranspose(
      output_channels, 3, strides=2,
      padding='same')  #64x64 -> 128x128

  x = last(x)

  return tf.keras.Model(inputs=inputs, outputs=x)

Обучите модель

Теперь осталось только скомпилировать и обучить модель. Используемые здесь потери - это losses.SparseCategoricalCrossentropy(from_logits=True) . Причина использования этой функции потерь заключается в том, что сеть пытается присвоить каждому пикселю метку, точно так же, как прогнозирование нескольких классов. В истинной маске сегментации каждый пиксель имеет либо {0,1,2}. Сеть здесь выводит три канала. По сути, каждый канал пытается научиться предсказывать класс и losses.SparseCategoricalCrossentropy(from_logits=True) - это рекомендуемые потери для такого сценария. Используя выходной сигнал сети, метка, присвоенная пикселю, представляет собой канал с наивысшим значением. Это то, что делает функция create_mask.

model = unet_model(OUTPUT_CHANNELS)
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

Взглянем на получившуюся архитектуру модели:

tf.keras.utils.plot_model(model, show_shapes=True)

PNG

Давайте попробуем модель, чтобы увидеть, что она предсказывает, перед тренировкой.

def create_mask(pred_mask):
  pred_mask = tf.argmax(pred_mask, axis=-1)
  pred_mask = pred_mask[..., tf.newaxis]
  return pred_mask[0]
def show_predictions(dataset=None, num=1):
  if dataset:
    for image, mask in dataset.take(num):
      pred_mask = model.predict(image)
      display([image[0], mask[0], create_mask(pred_mask)])
  else:
    display([sample_image, sample_mask,
             create_mask(model.predict(sample_image[tf.newaxis, ...]))])
show_predictions()

PNG

Давайте посмотрим, как модель улучшается в процессе обучения. Для выполнения этой задачи ниже определена функция обратного вызова.

class DisplayCallback(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    clear_output(wait=True)
    show_predictions()
    print ('\nSample Prediction after epoch {}\n'.format(epoch+1))
EPOCHS = 20
VAL_SUBSPLITS = 5
VALIDATION_STEPS = info.splits['test'].num_examples//BATCH_SIZE//VAL_SUBSPLITS

model_history = model.fit(train_dataset, epochs=EPOCHS,
                          steps_per_epoch=STEPS_PER_EPOCH,
                          validation_steps=VALIDATION_STEPS,
                          validation_data=test_dataset,
                          callbacks=[DisplayCallback()])

PNG

Sample Prediction after epoch 20
loss = model_history.history['loss']
val_loss = model_history.history['val_loss']

epochs = range(EPOCHS)

plt.figure()
plt.plot(epochs, loss, 'r', label='Training loss')
plt.plot(epochs, val_loss, 'bo', label='Validation loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss Value')
plt.ylim([0, 1])
plt.legend()
plt.show()

PNG

Делать предсказания

Сделаем некоторые прогнозы. В целях экономии времени количество эпох оставлено небольшим, но вы можете установить его больше, чтобы получить более точные результаты.

show_predictions(test_dataset, 3)

PNG

PNG

PNG

Следующие шаги

Теперь, когда вы понимаете, что такое сегментация изображения и как она работает, вы можете попробовать это руководство с различными выходными данными промежуточного слоя или даже с другой предварительно обученной моделью. Вы также можете испытать себя, попробовав задачу по маскировке изображений Carvana, размещенную на Kaggle.

Вы также можете увидеть API обнаружения объектов Tensorflow для другой модели, которую вы можете переобучить на своих собственных данных.