Ayuda a proteger la Gran Barrera de Coral con TensorFlow en Kaggle Únete Challenge

Segmentación de imagen

Ver en TensorFlow.org Ejecutar en Google Colab Ver fuente en GitHub Descargar cuaderno

Este tutorial se centra en la tarea de segmentación de imágenes, utilizando una versión modificada T-Net .

¿Qué es la segmentación de imágenes?

En una tarea de clasificación de imágenes, la red asigna una etiqueta (o clase) a cada imagen de entrada. Sin embargo, suponga que desea conocer la forma de ese objeto, qué píxel pertenece a qué objeto, etc. En este caso, querrá asignar una clase a cada píxel de la imagen. Esta tarea se conoce como segmentación. Un modelo de segmentación devuelve información mucho más detallada sobre la imagen. La segmentación de imágenes tiene muchas aplicaciones en imágenes médicas, automóviles autónomos e imágenes satelitales, por nombrar algunas.

Este tutorial utiliza la Oxford-IIIT conjunto de datos mascotas ( Parkhi et al, 2012 ). El conjunto de datos consta de imágenes de 37 razas de mascotas, con 200 imágenes por raza (~ 100 cada una en las divisiones de entrenamiento y prueba). Cada imagen incluye las etiquetas correspondientes y máscaras de píxeles. Las máscaras son etiquetas de clase para cada píxel. A cada píxel se le asigna una de estas tres categorías:

  • Clase 1: Pixel perteneciente a la mascota.
  • Clase 2: Pixel que bordea a la mascota.
  • Clase 3: Ninguno de los anteriores / un píxel circundante.
pip install git+https://github.com/tensorflow/examples.git
import tensorflow as tf

import tensorflow_datasets as tfds
from tensorflow_examples.models.pix2pix import pix2pix

from IPython.display import clear_output
import matplotlib.pyplot as plt

Descargue el conjunto de datos de mascotas Oxford-IIIT

El conjunto de datos está disponible en TensorFlow conjuntos de datos . Las máscaras de segmentación se incluyen en la versión 3+.

dataset, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True)

Además, los valores de color de la imagen se normalizaron a la [0,1] gama. Finalmente, como se mencionó anteriormente, los píxeles de la máscara de segmentación están etiquetados como {1, 2, 3}. Por conveniencia, reste 1 de la máscara de segmentación, lo que da como resultado etiquetas que son: {0, 1, 2}.

def normalize(input_image, input_mask):
  input_image = tf.cast(input_image, tf.float32) / 255.0
  input_mask -= 1
  return input_image, input_mask
def load_image(datapoint):
  input_image = tf.image.resize(datapoint['image'], (128, 128))
  input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128))

  input_image, input_mask = normalize(input_image, input_mask)

  return input_image, input_mask

El conjunto de datos ya contiene el entrenamiento requerido y las divisiones de prueba, así que continúe usando las mismas divisiones.

TRAIN_LENGTH = info.splits['train'].num_examples
BATCH_SIZE = 64
BUFFER_SIZE = 1000
STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE
train_images = dataset['train'].map(load_image, num_parallel_calls=tf.data.AUTOTUNE)
test_images = dataset['test'].map(load_image, num_parallel_calls=tf.data.AUTOTUNE)

La siguiente clase realiza un aumento simple volteando una imagen al azar. Ir a la imagen de aumento de tutorial para aprender más.

class Augment(tf.keras.layers.Layer):
  def __init__(self, seed=42):
    super().__init__()
    # both use the same seed, so they'll make the same random changes.
    self.augment_inputs = tf.keras.layers.RandomFlip(mode="horizontal", seed=seed)
    self.augment_labels = tf.keras.layers.RandomFlip(mode="horizontal", seed=seed)

  def call(self, inputs, labels):
    inputs = self.augment_inputs(inputs)
    labels = self.augment_labels(labels)
    return inputs, labels

Construya la canalización de entrada, aplicando el Aumento después de agrupar las entradas.

train_batches = (
    train_images
    .cache()
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .repeat()
    .map(Augment())
    .prefetch(buffer_size=tf.data.AUTOTUNE))

test_batches = test_images.batch(BATCH_SIZE)

Visualice un ejemplo de imagen y su máscara correspondiente del conjunto de datos.

def display(display_list):
  plt.figure(figsize=(15, 15))

  title = ['Input Image', 'True Mask', 'Predicted Mask']

  for i in range(len(display_list)):
    plt.subplot(1, len(display_list), i+1)
    plt.title(title[i])
    plt.imshow(tf.keras.utils.array_to_img(display_list[i]))
    plt.axis('off')
  plt.show()
for images, masks in train_batches.take(2):
  sample_image, sample_mask = images[0], masks[0]
  display([sample_image, sample_mask])
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment

png

png

2021-10-27 01:31:30.714975: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Definir el modelo

El modelo que se utiliza aquí es una versión modificada T-Net . Un U-Net consta de un codificador (muestreador reducido) y un decodificador (muestreador superior). Para aprender funciones sólidas y reducir la cantidad de parámetros entrenables, utilizará un modelo previamente entrenado, MobileNetV2, como codificador. Para el decodificador, que va a utilizar el bloque upsample, que ya está implementado en el pix2pix ejemplo en el TensorFlow Ejemplos de pases. (Consulte el pix2pix: traducción de imagen a imagen con un GAN condicional tutorial en un cuaderno.)

Como se ha mencionado, el codificador será un modelo MobileNetV2 pretrained que es preparado y listo para usar en tf.keras.applications . El codificador consta de salidas específicas de capas intermedias en el modelo. Tenga en cuenta que el codificador no se entrenará durante el proceso de entrenamiento.

base_model = tf.keras.applications.MobileNetV2(input_shape=[128, 128, 3], include_top=False)

# Use the activations of these layers
layer_names = [
    'block_1_expand_relu',   # 64x64
    'block_3_expand_relu',   # 32x32
    'block_6_expand_relu',   # 16x16
    'block_13_expand_relu',  # 8x8
    'block_16_project',      # 4x4
]
base_model_outputs = [base_model.get_layer(name).output for name in layer_names]

# Create the feature extraction model
down_stack = tf.keras.Model(inputs=base_model.input, outputs=base_model_outputs)

down_stack.trainable = False
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_128_no_top.h5
9412608/9406464 [==============================] - 0s 0us/step
9420800/9406464 [==============================] - 0s 0us/step

El decodificador / upsampler es simplemente una serie de bloques de upsample implementados en ejemplos de TensorFlow.

up_stack = [
    pix2pix.upsample(512, 3),  # 4x4 -> 8x8
    pix2pix.upsample(256, 3),  # 8x8 -> 16x16
    pix2pix.upsample(128, 3),  # 16x16 -> 32x32
    pix2pix.upsample(64, 3),   # 32x32 -> 64x64
]
def unet_model(output_channels:int):
  inputs = tf.keras.layers.Input(shape=[128, 128, 3])

  # Downsampling through the model
  skips = down_stack(inputs)
  x = skips[-1]
  skips = reversed(skips[:-1])

  # Upsampling and establishing the skip connections
  for up, skip in zip(up_stack, skips):
    x = up(x)
    concat = tf.keras.layers.Concatenate()
    x = concat([x, skip])

  # This is the last layer of the model
  last = tf.keras.layers.Conv2DTranspose(
      filters=output_channels, kernel_size=3, strides=2,
      padding='same')  #64x64 -> 128x128

  x = last(x)

  return tf.keras.Model(inputs=inputs, outputs=x)

Tenga en cuenta que el número de filtros en la última capa se establece en el número de output_channels . Este será un canal de salida por clase.

Entrena el modelo

Ahora, todo lo que queda por hacer es compilar y entrenar el modelo.

Dado que este es un problema de clasificación multiclase, utilice el tf.keras.losses.CategoricalCrossentropy función de pérdida con el from_logits argumento establecido en True , ya que las etiquetas son números enteros escalares en lugar de vectores de puntuaciones para cada píxel de cada clase.

Al ejecutar la inferencia, la etiqueta asignada al píxel es el canal con el valor más alto. Esto es lo que el create_mask función está haciendo.

OUTPUT_CLASSES = 3

model = unet_model(output_channels=OUTPUT_CLASSES)
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

Eche un vistazo rápido a la arquitectura del modelo resultante:

tf.keras.utils.plot_model(model, show_shapes=True)

png

Pruebe el modelo para comprobar lo que predice antes de entrenar.

def create_mask(pred_mask):
  pred_mask = tf.argmax(pred_mask, axis=-1)
  pred_mask = pred_mask[..., tf.newaxis]
  return pred_mask[0]
def show_predictions(dataset=None, num=1):
  if dataset:
    for image, mask in dataset.take(num):
      pred_mask = model.predict(image)
      display([image[0], mask[0], create_mask(pred_mask)])
  else:
    display([sample_image, sample_mask,
             create_mask(model.predict(sample_image[tf.newaxis, ...]))])
show_predictions()

png

La devolución de llamada definida a continuación se utiliza para observar cómo mejora el modelo mientras se entrena.

class DisplayCallback(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    clear_output(wait=True)
    show_predictions()
    print ('\nSample Prediction after epoch {}\n'.format(epoch+1))
EPOCHS = 20
VAL_SUBSPLITS = 5
VALIDATION_STEPS = info.splits['test'].num_examples//BATCH_SIZE//VAL_SUBSPLITS

model_history = model.fit(train_batches, epochs=EPOCHS,
                          steps_per_epoch=STEPS_PER_EPOCH,
                          validation_steps=VALIDATION_STEPS,
                          validation_data=test_batches,
                          callbacks=[DisplayCallback()])

png

Sample Prediction after epoch 20
loss = model_history.history['loss']
val_loss = model_history.history['val_loss']

plt.figure()
plt.plot(model_history.epoch, loss, 'r', label='Training loss')
plt.plot(model_history.epoch, val_loss, 'bo', label='Validation loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss Value')
plt.ylim([0, 1])
plt.legend()
plt.show()

png

Hacer predicciones

Ahora, haz algunas predicciones. Con el fin de ahorrar tiempo, el número de épocas se mantuvo pequeño, pero puede establecerlo más alto para lograr resultados más precisos.

show_predictions(test_batches, 3)

png

png

png

Opcional: clases y pesos de clase desequilibrados

Los conjuntos de datos de segmentación semántica pueden estar muy desequilibrados, lo que significa que los píxeles de una clase particular pueden estar presentes más dentro de las imágenes que los de otras clases. Dado que los problemas de segmentación pueden tratarse como problemas de clasificación por píxel, puede abordar el problema de desequilibrio sopesando la función de pérdida para tener en cuenta esto. Es una forma simple y elegante de lidiar con este problema. Consulte la clasificación de los datos desequilibrada tutorial para aprender más.

Para evitar la ambigüedad , Model.fit no admite la class_weight argumento para las entradas con más de 3 dimensiones.

try:
  model_history = model.fit(train_batches, epochs=EPOCHS,
                            steps_per_epoch=STEPS_PER_EPOCH,
                            class_weight = {0:2.0, 1:2.0, 2:1.0})
  assert False
except Exception as e:
  print(f"Expected {type(e).__name__}: {e}")
Expected ValueError: `class_weight` not supported for 3+ dimensional targets.

Entonces, en este caso, debe implementar la ponderación usted mismo. Esto lo hará usando ponderaciones de la muestra: Además de (data, label) pares, Model.fit también acepta (data, label, sample_weight) triples.

Model.fit propaga la sample_weight a las pérdidas y las métricas, que también aceptan una sample_weight argumento. El peso de la muestra se multiplica por el valor de la muestra antes del paso de reducción. Por ejemplo:

label = [0,0]
prediction = [[-3., 0], [-3, 0]] 
sample_weight = [1, 10] 

loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True,
                                               reduction=tf.losses.Reduction.NONE)
loss(label, prediction, sample_weight).numpy()
array([ 3.0485873, 30.485874 ], dtype=float32)

Así que para hacer pesos de la muestra para este tutorial se necesita una función que toma un (data, label) par y devuelve un (data, label, sample_weight) triple. Cuando el sample_weight es una imagen 1-canal que contiene el peso de clase para cada píxel.

La implementación más simple posible es utilizar la etiqueta como un índice en una class_weight lista:

def add_sample_weights(image, label):
  # The weights for each class, with the constraint that:
  #     sum(class_weights) == 1.0
  class_weights = tf.constant([2.0, 2.0, 1.0])
  class_weights = class_weights/tf.reduce_sum(class_weights)

  # Create an image of `sample_weights` by using the label at each pixel as an 
  # index into the `class weights` .
  sample_weights = tf.gather(class_weights, indices=tf.cast(label, tf.int32))

  return image, label, sample_weights

Los elementos del conjunto de datos resultantes contienen 3 imágenes cada uno:

train_batches.map(add_sample_weights).element_spec
(TensorSpec(shape=(None, 128, 128, 3), dtype=tf.float32, name=None),
 TensorSpec(shape=(None, 128, 128, 1), dtype=tf.float32, name=None),
 TensorSpec(shape=(None, 128, 128, 1), dtype=tf.float32, name=None))

Ahora puede entrenar un modelo en este conjunto de datos ponderados:

weighted_model = unet_model(OUTPUT_CLASSES)
weighted_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])
weighted_model.fit(
    train_batches.map(add_sample_weights),
    epochs=1,
    steps_per_epoch=10)
10/10 [==============================] - 3s 42ms/step - loss: 0.3264 - accuracy: 0.5658
<keras.callbacks.History at 0x7f75d076b810>

Próximos pasos

Ahora que comprende qué es la segmentación de imágenes y cómo funciona, puede probar este tutorial con diferentes salidas de capa intermedia o incluso con diferentes modelos previamente entrenados. También puede ponerse a prueba probando la Carvana desafío imagen enmascaramiento alojado en Kaggle.

También es posible que desee ver el API de detección de objetos Tensorflow para otro modelo se puede reciclar en sus propios datos. Pretrained modelos están disponibles en TensorFlow Hub