Ayuda a proteger la Gran Barrera de Coral con TensorFlow en Kaggle Únete Challenge

CycleGAN

Ver en TensorFlow.org Ejecutar en Google Colab Ver fuente en GitHub Descargar cuaderno

Este portátil demuestra la imagen no apareado a la traducción imagen utilizando GAN condicional, como se describe en Unpaired Traducción de imagen a imagen usando Redes Acusatorios ciclo consistente , también conocido como CycleGAN. El artículo propone un método que puede capturar las características de un dominio de imagen y descubrir cómo estas características podrían traducirse en otro dominio de imagen, todo en ausencia de ejemplos de entrenamiento emparejados.

Este portátil supone que está familiarizado con Pix2Pix, que se puede aprender sobre el tutorial Pix2Pix . El código de CycleGAN es similar, la principal diferencia es una función de pérdida adicional y el uso de datos de entrenamiento no emparejados.

CycleGAN utiliza una pérdida de consistencia de ciclo para permitir el entrenamiento sin la necesidad de datos emparejados. En otras palabras, puede traducirse de un dominio a otro sin un mapeo uno a uno entre el dominio de origen y el de destino.

Esto abre la posibilidad de realizar muchas tareas interesantes como mejora de fotografías, coloración de imágenes, transferencia de estilo, etc. Todo lo que necesita es la fuente y el conjunto de datos de destino (que es simplemente un directorio de imágenes).

Imagen de salida 1Imagen de salida 2

Configurar la canalización de entrada

Instalar el tensorflow_examples paquete que permite la importación del generador y el discriminador.

pip install git+https://github.com/tensorflow/examples.git
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow_examples.models.pix2pix import pix2pix

import os
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output

AUTOTUNE = tf.data.AUTOTUNE

Canalización de entrada

Este tutorial entrena un modelo para traducir imágenes de caballos a imágenes de cebras. Puede encontrar este conjunto de datos y otros similares aquí .

Como se menciona en el documento , se aplican fluctuación aleatoria y el reflejo de la formación de datos. Estas son algunas de las técnicas de aumento de imagen que evitan el sobreajuste.

Esto es similar a lo que se hizo en pix2pix

  • En jittering aleatorio, la imagen se cambia el tamaño de 286 x 286 y luego recortada al azar a 256 x 256 .
  • En la duplicación aleatoria, la imagen se voltea de forma aleatoria horizontalmente, es decir, de izquierda a derecha.
dataset, metadata = tfds.load('cycle_gan/horse2zebra',
                              with_info=True, as_supervised=True)

train_horses, train_zebras = dataset['trainA'], dataset['trainB']
test_horses, test_zebras = dataset['testA'], dataset['testB']
BUFFER_SIZE = 1000
BATCH_SIZE = 1
IMG_WIDTH = 256
IMG_HEIGHT = 256
def random_crop(image):
  cropped_image = tf.image.random_crop(
      image, size=[IMG_HEIGHT, IMG_WIDTH, 3])

  return cropped_image
# normalizing the images to [-1, 1]
def normalize(image):
  image = tf.cast(image, tf.float32)
  image = (image / 127.5) - 1
  return image
def random_jitter(image):
  # resizing to 286 x 286 x 3
  image = tf.image.resize(image, [286, 286],
                          method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

  # randomly cropping to 256 x 256 x 3
  image = random_crop(image)

  # random mirroring
  image = tf.image.random_flip_left_right(image)

  return image
def preprocess_image_train(image, label):
  image = random_jitter(image)
  image = normalize(image)
  return image
def preprocess_image_test(image, label):
  image = normalize(image)
  return image
train_horses = train_horses.cache().map(
    preprocess_image_train, num_parallel_calls=AUTOTUNE).shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)

train_zebras = train_zebras.cache().map(
    preprocess_image_train, num_parallel_calls=AUTOTUNE).shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)

test_horses = test_horses.map(
    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)

test_zebras = test_zebras.map(
    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)
sample_horse = next(iter(train_horses))
sample_zebra = next(iter(train_zebras))
plt.subplot(121)
plt.title('Horse')
plt.imshow(sample_horse[0] * 0.5 + 0.5)

plt.subplot(122)
plt.title('Horse with random jitter')
plt.imshow(random_jitter(sample_horse[0]) * 0.5 + 0.5)
<matplotlib.image.AxesImage at 0x7fd518202090>

png

plt.subplot(121)
plt.title('Zebra')
plt.imshow(sample_zebra[0] * 0.5 + 0.5)

plt.subplot(122)
plt.title('Zebra with random jitter')
plt.imshow(random_jitter(sample_zebra[0]) * 0.5 + 0.5)
<matplotlib.image.AxesImage at 0x7fd5107cea90>

png

Importar y reutilizar los modelos Pix2Pix

Importe el generador y el discriminador utilizado en Pix2Pix a través de la instalación tensorflow_examples paquete.

La arquitectura modelo utilizado en este tutorial es muy similar al que se utilizó en pix2pix . Algunas de las diferencias son:

Aquí se entrenan 2 generadores (G y F) y 2 discriminadores (X e Y).

  • Generador G aprende a transformar la imagen X de imagen Y . \((G: X -> Y)\)
  • Generador F aprende a transformar la imagen Y a la imagen X . \((F: Y -> X)\)
  • Discriminador D_X aprende a diferenciar entre la imagen X y la imagen generada X ( F(Y) ).
  • Discriminador D_Y aprende a diferenciar entre la imagen Y y la imagen generada Y ( G(X) ).

Modelo Cyclegan

OUTPUT_CHANNELS = 3

generator_g = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
generator_f = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')

discriminator_x = pix2pix.discriminator(norm_type='instancenorm', target=False)
discriminator_y = pix2pix.discriminator(norm_type='instancenorm', target=False)
to_zebra = generator_g(sample_horse)
to_horse = generator_f(sample_zebra)
plt.figure(figsize=(8, 8))
contrast = 8

imgs = [sample_horse, to_zebra, sample_zebra, to_horse]
title = ['Horse', 'To Zebra', 'Zebra', 'To Horse']

for i in range(len(imgs)):
  plt.subplot(2, 2, i+1)
  plt.title(title[i])
  if i % 2 == 0:
    plt.imshow(imgs[i][0] * 0.5 + 0.5)
  else:
    plt.imshow(imgs[i][0] * 0.5 * contrast + 0.5)
plt.show()
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

png

plt.figure(figsize=(8, 8))

plt.subplot(121)
plt.title('Is a real zebra?')
plt.imshow(discriminator_y(sample_zebra)[0, ..., -1], cmap='RdBu_r')

plt.subplot(122)
plt.title('Is a real horse?')
plt.imshow(discriminator_x(sample_horse)[0, ..., -1], cmap='RdBu_r')

plt.show()

png

Funciones de pérdida

En CycleGAN, no hay datos emparejados para entrenar en, por lo tanto, no hay ninguna garantía de que la entrada x y el objetivo y par son significativos durante el entrenamiento. Por lo tanto, para lograr que la red aprenda el mapeo correcto, los autores proponen la pérdida de consistencia del ciclo.

La pérdida discriminador y la pérdida del generador son similares a los utilizados en pix2pix .

LAMBDA = 10
loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)
def discriminator_loss(real, generated):
  real_loss = loss_obj(tf.ones_like(real), real)

  generated_loss = loss_obj(tf.zeros_like(generated), generated)

  total_disc_loss = real_loss + generated_loss

  return total_disc_loss * 0.5
def generator_loss(generated):
  return loss_obj(tf.ones_like(generated), generated)

La consistencia del ciclo significa que el resultado debe estar cerca de la entrada original. Por ejemplo, si uno traduce una oración del inglés al francés y luego la vuelve a traducir del francés al inglés, la oración resultante debe ser la misma que la oración original.

En la pérdida de consistencia del ciclo,

  • Imagen \(X\) se pasa a través del generador \(G\) que los rendimientos generados imagen \(\hat{Y}\).
  • Imagen generada \(\hat{Y}\) se pasa a través de generador \(F\) que los rendimientos ciclan imagen \(\hat{X}\).
  • Error absoluto medio se calcula entre \(X\) y \(\hat{X}\).

\[forward\ cycle\ consistency\ loss: X -> G(X) -> F(G(X)) \sim \hat{X}\]

\[backward\ cycle\ consistency\ loss: Y -> F(Y) -> G(F(Y)) \sim \hat{Y}\]

Pérdida de ciclo

def calc_cycle_loss(real_image, cycled_image):
  loss1 = tf.reduce_mean(tf.abs(real_image - cycled_image))

  return LAMBDA * loss1

Como se muestra anteriormente, el generador \(G\) es responsable de traducir la imagen \(X\) a la imagen \(Y\). Pérdida de identidad dice que, si harto imagen \(Y\) al generador \(G\), debe dar la imagen real \(Y\) o algo cercano a la imagen \(Y\).

Si ejecuta el modelo de cebra a caballo en un caballo o el modelo de caballo a cebra en una cebra, no debería modificar mucho la imagen, ya que la imagen ya contiene la clase de destino.

\[Identity\ loss = |G(Y) - Y| + |F(X) - X|\]

def identity_loss(real_image, same_image):
  loss = tf.reduce_mean(tf.abs(real_image - same_image))
  return LAMBDA * 0.5 * loss

Inicialice los optimizadores para todos los generadores y discriminadores.

generator_g_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
generator_f_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

discriminator_x_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_y_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

Puntos de control

checkpoint_path = "./checkpoints/train"

ckpt = tf.train.Checkpoint(generator_g=generator_g,
                           generator_f=generator_f,
                           discriminator_x=discriminator_x,
                           discriminator_y=discriminator_y,
                           generator_g_optimizer=generator_g_optimizer,
                           generator_f_optimizer=generator_f_optimizer,
                           discriminator_x_optimizer=discriminator_x_optimizer,
                           discriminator_y_optimizer=discriminator_y_optimizer)

ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# if a checkpoint exists, restore the latest checkpoint.
if ckpt_manager.latest_checkpoint:
  ckpt.restore(ckpt_manager.latest_checkpoint)
  print ('Latest checkpoint restored!!')

Capacitación

EPOCHS = 40
def generate_images(model, test_input):
  prediction = model(test_input)

  plt.figure(figsize=(12, 12))

  display_list = [test_input[0], prediction[0]]
  title = ['Input Image', 'Predicted Image']

  for i in range(2):
    plt.subplot(1, 2, i+1)
    plt.title(title[i])
    # getting the pixel values between [0, 1] to plot it.
    plt.imshow(display_list[i] * 0.5 + 0.5)
    plt.axis('off')
  plt.show()

Aunque el ciclo de entrenamiento parece complicado, consta de cuatro pasos básicos:

  • Obtenga las predicciones.
  • Calcule la pérdida.
  • Calcula los gradientes usando retropropagación.
  • Aplica los degradados al optimizador.
@tf.function
def train_step(real_x, real_y):
  # persistent is set to True because the tape is used more than
  # once to calculate the gradients.
  with tf.GradientTape(persistent=True) as tape:
    # Generator G translates X -> Y
    # Generator F translates Y -> X.

    fake_y = generator_g(real_x, training=True)
    cycled_x = generator_f(fake_y, training=True)

    fake_x = generator_f(real_y, training=True)
    cycled_y = generator_g(fake_x, training=True)

    # same_x and same_y are used for identity loss.
    same_x = generator_f(real_x, training=True)
    same_y = generator_g(real_y, training=True)

    disc_real_x = discriminator_x(real_x, training=True)
    disc_real_y = discriminator_y(real_y, training=True)

    disc_fake_x = discriminator_x(fake_x, training=True)
    disc_fake_y = discriminator_y(fake_y, training=True)

    # calculate the loss
    gen_g_loss = generator_loss(disc_fake_y)
    gen_f_loss = generator_loss(disc_fake_x)

    total_cycle_loss = calc_cycle_loss(real_x, cycled_x) + calc_cycle_loss(real_y, cycled_y)

    # Total generator loss = adversarial loss + cycle loss
    total_gen_g_loss = gen_g_loss + total_cycle_loss + identity_loss(real_y, same_y)
    total_gen_f_loss = gen_f_loss + total_cycle_loss + identity_loss(real_x, same_x)

    disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)
    disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)

  # Calculate the gradients for generator and discriminator
  generator_g_gradients = tape.gradient(total_gen_g_loss, 
                                        generator_g.trainable_variables)
  generator_f_gradients = tape.gradient(total_gen_f_loss, 
                                        generator_f.trainable_variables)

  discriminator_x_gradients = tape.gradient(disc_x_loss, 
                                            discriminator_x.trainable_variables)
  discriminator_y_gradients = tape.gradient(disc_y_loss, 
                                            discriminator_y.trainable_variables)

  # Apply the gradients to the optimizer
  generator_g_optimizer.apply_gradients(zip(generator_g_gradients, 
                                            generator_g.trainable_variables))

  generator_f_optimizer.apply_gradients(zip(generator_f_gradients, 
                                            generator_f.trainable_variables))

  discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients,
                                                discriminator_x.trainable_variables))

  discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients,
                                                discriminator_y.trainable_variables))
for epoch in range(EPOCHS):
  start = time.time()

  n = 0
  for image_x, image_y in tf.data.Dataset.zip((train_horses, train_zebras)):
    train_step(image_x, image_y)
    if n % 10 == 0:
      print ('.', end='')
    n += 1

  clear_output(wait=True)
  # Using a consistent image (sample_horse) so that the progress of the model
  # is clearly visible.
  generate_images(generator_g, sample_horse)

  if (epoch + 1) % 5 == 0:
    ckpt_save_path = ckpt_manager.save()
    print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,
                                                         ckpt_save_path))

  print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
                                                      time.time()-start))

png

Saving checkpoint for epoch 40 at ./checkpoints/train/ckpt-8
Time taken for epoch 40 is 166.58266592025757 sec

Generar usando un conjunto de datos de prueba

# Run the trained model on the test dataset
for inp in test_horses.take(5):
  generate_images(generator_g, inp)

png

png

png

png

png

Próximos pasos

Este tutorial ha mostrado cómo implementar CycleGAN a partir del generador y discriminador implementado en el Pix2Pix tutorial. Como siguiente paso, se podría tratar de usar un conjunto de datos diferente de TensorFlow conjuntos de datos .

También podría entrenar para un mayor número de épocas para mejorar los resultados, o se puede poner en práctica el generador ResNet modificado utilizado en el papel en lugar del generador de T-Net utiliza aquí.