Se usó la API de Cloud Translation para traducir esta página.
Switch to English

CycleGAN

Ver en TensorFlow.org Ejecutar en Google Colab Ver código fuente en GitHub Descargar cuaderno

Este cuaderno muestra la traducción de imagen a imagen sin emparejar usando GAN condicionales, como se describe en Traducción de imagen a imagen sin emparejar usando redes adversas consistentes en ciclo , también conocido como CycleGAN. El documento propone un método que puede capturar las características de un dominio de imagen y descubrir cómo estas características podrían traducirse a otro dominio de imagen, todo en ausencia de ejemplos de capacitación emparejados.

Este cuaderno supone que está familiarizado con Pix2Pix, que puede aprender en el tutorial de Pix2Pix . El código para CycleGAN es similar, la principal diferencia es una función de pérdida adicional y el uso de datos de entrenamiento no apareados.

CycleGAN utiliza una pérdida de consistencia del ciclo para permitir la capacitación sin la necesidad de datos emparejados. En otras palabras, puede traducir de un dominio a otro sin un mapeo uno a uno entre el dominio de origen y el de destino.

Esto abre la posibilidad de realizar muchas tareas interesantes como mejora de fotos, coloración de imágenes, transferencia de estilos, etc. Todo lo que necesita es el conjunto de datos de origen y destino (que es simplemente un directorio de imágenes).

Imagen de salida 1Imagen de salida 2

Configurar la tubería de entrada

Instale el paquete tensorflow_examples que permite importar el generador y el discriminador.

pip install -q git+https://github.com/tensorflow/examples.git
 import tensorflow as tf
 
 import tensorflow_datasets as tfds
from tensorflow_examples.models.pix2pix import pix2pix

import os
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output

tfds.disable_progress_bar()
AUTOTUNE = tf.data.experimental.AUTOTUNE
 

Tubería de entrada

Este tutorial entrena un modelo para traducir de imágenes de caballos a imágenes de cebras. Puede encontrar este conjunto de datos y otros similares aquí .

Como se menciona en el documento , aplique fluctuaciones y reflejos aleatorios al conjunto de datos de entrenamiento. Estas son algunas de las técnicas de aumento de imagen que evitan el sobreajuste.

Esto es similar a lo que se hizo en pix2pix

  • En la fluctuación aleatoria, la imagen cambia de tamaño a 286 x 286 y luego se recorta aleatoriamente a 256 x 256 .
  • En la duplicación aleatoria, la imagen se voltea aleatoriamente horizontalmente, es decir, de izquierda a derecha.
 dataset, metadata = tfds.load('cycle_gan/horse2zebra',
                              with_info=True, as_supervised=True)

train_horses, train_zebras = dataset['trainA'], dataset['trainB']
test_horses, test_zebras = dataset['testA'], dataset['testB']
 
Downloading and preparing dataset cycle_gan/horse2zebra/2.0.0 (download: 111.45 MiB, generated: Unknown size, total: 111.45 MiB) to /home/kbuilder/tensorflow_datasets/cycle_gan/horse2zebra/2.0.0...
Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/cycle_gan/horse2zebra/2.0.0.incompleteKMK6GL/cycle_gan-trainA.tfrecord
Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/cycle_gan/horse2zebra/2.0.0.incompleteKMK6GL/cycle_gan-trainB.tfrecord
Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/cycle_gan/horse2zebra/2.0.0.incompleteKMK6GL/cycle_gan-testA.tfrecord
Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/cycle_gan/horse2zebra/2.0.0.incompleteKMK6GL/cycle_gan-testB.tfrecord
Dataset cycle_gan downloaded and prepared to /home/kbuilder/tensorflow_datasets/cycle_gan/horse2zebra/2.0.0. Subsequent calls will reuse this data.

 BUFFER_SIZE = 1000
BATCH_SIZE = 1
IMG_WIDTH = 256
IMG_HEIGHT = 256
 
 def random_crop(image):
  cropped_image = tf.image.random_crop(
      image, size=[IMG_HEIGHT, IMG_WIDTH, 3])

  return cropped_image
 
 # normalizing the images to [-1, 1]
def normalize(image):
  image = tf.cast(image, tf.float32)
  image = (image / 127.5) - 1
  return image
 
 def random_jitter(image):
  # resizing to 286 x 286 x 3
  image = tf.image.resize(image, [286, 286],
                          method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

  # randomly cropping to 256 x 256 x 3
  image = random_crop(image)

  # random mirroring
  image = tf.image.random_flip_left_right(image)

  return image
 
 def preprocess_image_train(image, label):
  image = random_jitter(image)
  image = normalize(image)
  return image
 
 def preprocess_image_test(image, label):
  image = normalize(image)
  return image
 
 train_horses = train_horses.map(
    preprocess_image_train, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(1)

train_zebras = train_zebras.map(
    preprocess_image_train, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(1)

test_horses = test_horses.map(
    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(1)

test_zebras = test_zebras.map(
    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(1)
 
 sample_horse = next(iter(train_horses))
sample_zebra = next(iter(train_zebras))
 
 plt.subplot(121)
plt.title('Horse')
plt.imshow(sample_horse[0] * 0.5 + 0.5)

plt.subplot(122)
plt.title('Horse with random jitter')
plt.imshow(random_jitter(sample_horse[0]) * 0.5 + 0.5)
 
<matplotlib.image.AxesImage at 0x7fab5c109f98>

png

 plt.subplot(121)
plt.title('Zebra')
plt.imshow(sample_zebra[0] * 0.5 + 0.5)

plt.subplot(122)
plt.title('Zebra with random jitter')
plt.imshow(random_jitter(sample_zebra[0]) * 0.5 + 0.5)
 
<matplotlib.image.AxesImage at 0x7faaf830c748>

png

Importa y reutiliza los modelos Pix2Pix

Importe el generador y el discriminador utilizado en Pix2Pix a través del paquete de tensorflow_examples instalado.

La arquitectura del modelo utilizada en este tutorial es muy similar a la utilizada en pix2pix . Algunas de las diferencias son:

Aquí se entrenan 2 generadores (G y F) y 2 discriminadores (X e Y).

  • El generador G aprende a transformar la imagen X en la imagen Y $ (G: X -> Y) $
  • El generador F aprende a transformar la imagen Y en la imagen X $ (F: Y -> X) $
  • El discriminador D_X aprende a diferenciar entre la imagen X y la imagen generada X ( F(Y) ).
  • El discriminador D_Y aprende a diferenciar entre la imagen Y y la imagen generada Y ( G(X) ).

Modelo Cyclegan

 OUTPUT_CHANNELS = 3

generator_g = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
generator_f = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')

discriminator_x = pix2pix.discriminator(norm_type='instancenorm', target=False)
discriminator_y = pix2pix.discriminator(norm_type='instancenorm', target=False)
 
 to_zebra = generator_g(sample_horse)
to_horse = generator_f(sample_zebra)
plt.figure(figsize=(8, 8))
contrast = 8

imgs = [sample_horse, to_zebra, sample_zebra, to_horse]
title = ['Horse', 'To Zebra', 'Zebra', 'To Horse']

for i in range(len(imgs)):
  plt.subplot(2, 2, i+1)
  plt.title(title[i])
  if i % 2 == 0:
    plt.imshow(imgs[i][0] * 0.5 + 0.5)
  else:
    plt.imshow(imgs[i][0] * 0.5 * contrast + 0.5)
plt.show()
 
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

png

 plt.figure(figsize=(8, 8))

plt.subplot(121)
plt.title('Is a real zebra?')
plt.imshow(discriminator_y(sample_zebra)[0, ..., -1], cmap='RdBu_r')

plt.subplot(122)
plt.title('Is a real horse?')
plt.imshow(discriminator_x(sample_horse)[0, ..., -1], cmap='RdBu_r')

plt.show()
 

png

Funciones de pérdida

En CycleGAN, no hay datos emparejados para entrenar, por lo tanto, no hay garantía de que la entrada x el par objetivo y sean significativos durante el entrenamiento. Por lo tanto, para garantizar que la red aprenda el mapeo correcto, los autores proponen la pérdida de consistencia del ciclo.

La pérdida discriminadora y la pérdida del generador son similares a las utilizadas en pix2pix .

 LAMBDA = 10
 
 loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)
 
 def discriminator_loss(real, generated):
  real_loss = loss_obj(tf.ones_like(real), real)

  generated_loss = loss_obj(tf.zeros_like(generated), generated)

  total_disc_loss = real_loss + generated_loss

  return total_disc_loss * 0.5
 
 def generator_loss(generated):
  return loss_obj(tf.ones_like(generated), generated)
 

La consistencia del ciclo significa que el resultado debe estar cerca de la entrada original. Por ejemplo, si uno traduce una oración del inglés al francés, y luego la traduce nuevamente del francés al inglés, entonces la oración resultante debe ser la misma que la oración original.

En ciclo de pérdida de consistencia,

  • La imagen $ X $ se pasa a través del generador $ G $ que produce la imagen generada $ \ hat {Y} $.
  • La imagen generada $ \ hat {Y} $ se pasa a través del generador $ F $ que produce la imagen ciclada $ \ hat {X} $.
  • El error absoluto medio se calcula entre $ X $ y $ \ hat {X} $.
$$ adelante \ ciclo \ consistencia \ pérdida: X -> G (X) -> F (G (X)) \ sim \ hat {X} $$
$$ hacia atrás \ ciclo \ consistencia \ pérdida: Y -> F (Y) -> G (F (Y)) \ sim \ hat {Y} $$

Pérdida de ciclo

 def calc_cycle_loss(real_image, cycled_image):
  loss1 = tf.reduce_mean(tf.abs(real_image - cycled_image))
  
  return LAMBDA * loss1
 

Como se muestra arriba, el generador $ G $ es responsable de traducir la imagen $ X $ a la imagen $ Y $. La pérdida de identidad dice que, si alimentaste la imagen $ Y $ al generador $ G $, debería producir la imagen real $ Y $ o algo similar a la imagen $ Y $.

$$ Identidad \ pérdida = | G (Y) - Y | + | F (X) - X | $$
 def identity_loss(real_image, same_image):
  loss = tf.reduce_mean(tf.abs(real_image - same_image))
  return LAMBDA * 0.5 * loss
 

Inicialice los optimizadores para todos los generadores y los discriminadores.

 generator_g_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
generator_f_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

discriminator_x_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_y_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
 

Puntos de control

 checkpoint_path = "./checkpoints/train"

ckpt = tf.train.Checkpoint(generator_g=generator_g,
                           generator_f=generator_f,
                           discriminator_x=discriminator_x,
                           discriminator_y=discriminator_y,
                           generator_g_optimizer=generator_g_optimizer,
                           generator_f_optimizer=generator_f_optimizer,
                           discriminator_x_optimizer=discriminator_x_optimizer,
                           discriminator_y_optimizer=discriminator_y_optimizer)

ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# if a checkpoint exists, restore the latest checkpoint.
if ckpt_manager.latest_checkpoint:
  ckpt.restore(ckpt_manager.latest_checkpoint)
  print ('Latest checkpoint restored!!')
 

Formación

 EPOCHS = 40
 
 def generate_images(model, test_input):
  prediction = model(test_input)
    
  plt.figure(figsize=(12, 12))

  display_list = [test_input[0], prediction[0]]
  title = ['Input Image', 'Predicted Image']

  for i in range(2):
    plt.subplot(1, 2, i+1)
    plt.title(title[i])
    # getting the pixel values between [0, 1] to plot it.
    plt.imshow(display_list[i] * 0.5 + 0.5)
    plt.axis('off')
  plt.show()
 

Aunque el ciclo de entrenamiento parece complicado, consta de cuatro pasos básicos:

  • Obtén las predicciones.
  • Calcule la pérdida.
  • Calcule los gradientes usando la propagación hacia atrás.
  • Aplicar los gradientes al optimizador.
 @tf.function
def train_step(real_x, real_y):
  # persistent is set to True because the tape is used more than
  # once to calculate the gradients.
  with tf.GradientTape(persistent=True) as tape:
    # Generator G translates X -> Y
    # Generator F translates Y -> X.
    
    fake_y = generator_g(real_x, training=True)
    cycled_x = generator_f(fake_y, training=True)

    fake_x = generator_f(real_y, training=True)
    cycled_y = generator_g(fake_x, training=True)

    # same_x and same_y are used for identity loss.
    same_x = generator_f(real_x, training=True)
    same_y = generator_g(real_y, training=True)

    disc_real_x = discriminator_x(real_x, training=True)
    disc_real_y = discriminator_y(real_y, training=True)

    disc_fake_x = discriminator_x(fake_x, training=True)
    disc_fake_y = discriminator_y(fake_y, training=True)

    # calculate the loss
    gen_g_loss = generator_loss(disc_fake_y)
    gen_f_loss = generator_loss(disc_fake_x)
    
    total_cycle_loss = calc_cycle_loss(real_x, cycled_x) + calc_cycle_loss(real_y, cycled_y)
    
    # Total generator loss = adversarial loss + cycle loss
    total_gen_g_loss = gen_g_loss + total_cycle_loss + identity_loss(real_y, same_y)
    total_gen_f_loss = gen_f_loss + total_cycle_loss + identity_loss(real_x, same_x)

    disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)
    disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)
  
  # Calculate the gradients for generator and discriminator
  generator_g_gradients = tape.gradient(total_gen_g_loss, 
                                        generator_g.trainable_variables)
  generator_f_gradients = tape.gradient(total_gen_f_loss, 
                                        generator_f.trainable_variables)
  
  discriminator_x_gradients = tape.gradient(disc_x_loss, 
                                            discriminator_x.trainable_variables)
  discriminator_y_gradients = tape.gradient(disc_y_loss, 
                                            discriminator_y.trainable_variables)
  
  # Apply the gradients to the optimizer
  generator_g_optimizer.apply_gradients(zip(generator_g_gradients, 
                                            generator_g.trainable_variables))

  generator_f_optimizer.apply_gradients(zip(generator_f_gradients, 
                                            generator_f.trainable_variables))
  
  discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients,
                                                discriminator_x.trainable_variables))
  
  discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients,
                                                discriminator_y.trainable_variables))
 
 for epoch in range(EPOCHS):
  start = time.time()

  n = 0
  for image_x, image_y in tf.data.Dataset.zip((train_horses, train_zebras)):
    train_step(image_x, image_y)
    if n % 10 == 0:
      print ('.', end='')
    n+=1

  clear_output(wait=True)
  # Using a consistent image (sample_horse) so that the progress of the model
  # is clearly visible.
  generate_images(generator_g, sample_horse)

  if (epoch + 1) % 5 == 0:
    ckpt_save_path = ckpt_manager.save()
    print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,
                                                         ckpt_save_path))

  print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
                                                      time.time()-start))
 

png

Saving checkpoint for epoch 40 at ./checkpoints/train/ckpt-8
Time taken for epoch 40 is 174.27903032302856 sec


Generar utilizando el conjunto de datos de prueba

 # Run the trained model on the test dataset
for inp in test_horses.take(5):
  generate_images(generator_g, inp)
 

png

png

png

png

png

Próximos pasos

Este tutorial ha mostrado cómo implementar CycleGAN a partir del generador y discriminador implementado en el tutorial de Pix2Pix . Como siguiente paso, puede intentar usar un conjunto de datos diferente del conjunto de datos TensorFlow .

También podría entrenar para una mayor cantidad de épocas para mejorar los resultados, o podría implementar el generador ResNet modificado que se utiliza en el documento en lugar del generador U-Net que se usa aquí.