Esta página foi traduzida pela API Cloud Translation.
Switch to English

CycleGAN

Ver em TensorFlow.org Executar no Google Colab Ver fonte no GitHub Download do caderno

Este bloco de notas demonstra a tradução de imagem para imagem não emparelhada usando GANs condicionais, conforme descrito em Tradução de imagem para imagem não emparelhada usando redes adversas consistentes em ciclo , também conhecido como CycleGAN. O artigo propõe um método que pode capturar as características de um domínio de imagem e descobrir como essas características podem ser traduzidas em outro domínio de imagem, tudo na ausência de exemplos de treinamento emparelhados.

Este caderno pressupõe que você esteja familiarizado com o Pix2Pix, sobre o qual você pode aprender no tutorial do Pix2Pix . O código para o CycleGAN é semelhante, a principal diferença é uma função de perda adicional e o uso de dados de treinamento não emparelhados.

O CycleGAN usa uma perda de consistência do ciclo para permitir o treinamento sem a necessidade de dados emparelhados. Em outras palavras, ele pode converter de um domínio para outro sem um mapeamento individual entre o domínio de origem e o destino.

Isso abre a possibilidade de realizar muitas tarefas interessantes, como aprimoramento de fotos, colorização de imagens, transferência de estilos, etc. Tudo o que você precisa é da fonte e do conjunto de dados de destino (que é simplesmente um diretório de imagens).

Imagem de saída 1Imagem de saída 2

Configurar o pipeline de entrada

Instale o pacote tensorflow_examples que permite a importação do gerador e do discriminador.

pip install -q git+https://github.com/tensorflow/examples.git
 import tensorflow as tf
 
 import tensorflow_datasets as tfds
from tensorflow_examples.models.pix2pix import pix2pix

import os
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output

tfds.disable_progress_bar()
AUTOTUNE = tf.data.experimental.AUTOTUNE
 

Pipeline de entrada

Este tutorial treina um modelo para traduzir de imagens de cavalos para imagens de zebras. Você pode encontrar esse conjunto de dados e similares aqui .

Conforme mencionado no artigo , aplique tremulação e espelhamento aleatórios no conjunto de dados de treinamento. Estas são algumas das técnicas de aumento de imagem que evitam o ajuste excessivo.

Isso é semelhante ao que foi feito no pix2pix

  • No tremor aleatório, a imagem é redimensionada para 286 x 286 e cortada aleatoriamente para 256 x 256 .
  • No espelhamento aleatório, a imagem é invertida horizontalmente, ou seja, da esquerda para a direita.
 dataset, metadata = tfds.load('cycle_gan/horse2zebra',
                              with_info=True, as_supervised=True)

train_horses, train_zebras = dataset['trainA'], dataset['trainB']
test_horses, test_zebras = dataset['testA'], dataset['testB']
 
Downloading and preparing dataset cycle_gan/horse2zebra/2.0.0 (download: 111.45 MiB, generated: Unknown size, total: 111.45 MiB) to /home/kbuilder/tensorflow_datasets/cycle_gan/horse2zebra/2.0.0...
Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/cycle_gan/horse2zebra/2.0.0.incompleteKMK6GL/cycle_gan-trainA.tfrecord
Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/cycle_gan/horse2zebra/2.0.0.incompleteKMK6GL/cycle_gan-trainB.tfrecord
Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/cycle_gan/horse2zebra/2.0.0.incompleteKMK6GL/cycle_gan-testA.tfrecord
Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/cycle_gan/horse2zebra/2.0.0.incompleteKMK6GL/cycle_gan-testB.tfrecord
Dataset cycle_gan downloaded and prepared to /home/kbuilder/tensorflow_datasets/cycle_gan/horse2zebra/2.0.0. Subsequent calls will reuse this data.

 BUFFER_SIZE = 1000
BATCH_SIZE = 1
IMG_WIDTH = 256
IMG_HEIGHT = 256
 
 def random_crop(image):
  cropped_image = tf.image.random_crop(
      image, size=[IMG_HEIGHT, IMG_WIDTH, 3])

  return cropped_image
 
 # normalizing the images to [-1, 1]
def normalize(image):
  image = tf.cast(image, tf.float32)
  image = (image / 127.5) - 1
  return image
 
 def random_jitter(image):
  # resizing to 286 x 286 x 3
  image = tf.image.resize(image, [286, 286],
                          method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

  # randomly cropping to 256 x 256 x 3
  image = random_crop(image)

  # random mirroring
  image = tf.image.random_flip_left_right(image)

  return image
 
 def preprocess_image_train(image, label):
  image = random_jitter(image)
  image = normalize(image)
  return image
 
 def preprocess_image_test(image, label):
  image = normalize(image)
  return image
 
 train_horses = train_horses.map(
    preprocess_image_train, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(1)

train_zebras = train_zebras.map(
    preprocess_image_train, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(1)

test_horses = test_horses.map(
    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(1)

test_zebras = test_zebras.map(
    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(1)
 
 sample_horse = next(iter(train_horses))
sample_zebra = next(iter(train_zebras))
 
 plt.subplot(121)
plt.title('Horse')
plt.imshow(sample_horse[0] * 0.5 + 0.5)

plt.subplot(122)
plt.title('Horse with random jitter')
plt.imshow(random_jitter(sample_horse[0]) * 0.5 + 0.5)
 
<matplotlib.image.AxesImage at 0x7fab5c109f98>

png

 plt.subplot(121)
plt.title('Zebra')
plt.imshow(sample_zebra[0] * 0.5 + 0.5)

plt.subplot(122)
plt.title('Zebra with random jitter')
plt.imshow(random_jitter(sample_zebra[0]) * 0.5 + 0.5)
 
<matplotlib.image.AxesImage at 0x7faaf830c748>

png

Importar e reutilizar os modelos Pix2Pix

Importe o gerador e o discriminador usado no Pix2Pix através do pacote tensorflow_examples instalado.

A arquitetura do modelo usada neste tutorial é muito semelhante à usada no pix2pix . Algumas das diferenças são:

Existem 2 geradores (G e F) e 2 discriminadores (X e Y) sendo treinados aqui.

  • O gerador G aprende a transformar a imagem X na imagem Y $ (G: X -> Y) $
  • O gerador F aprende a transformar a imagem Y na imagem X $ (F: Y -> X) $
  • O discriminador D_X aprende a diferenciar entre a imagem X e a imagem gerada X ( F(Y) ).
  • O discriminador D_Y aprende a diferenciar entre a imagem Y e a imagem gerada Y ( G(X) ).

Modelo cyclegan

 OUTPUT_CHANNELS = 3

generator_g = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
generator_f = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')

discriminator_x = pix2pix.discriminator(norm_type='instancenorm', target=False)
discriminator_y = pix2pix.discriminator(norm_type='instancenorm', target=False)
 
 to_zebra = generator_g(sample_horse)
to_horse = generator_f(sample_zebra)
plt.figure(figsize=(8, 8))
contrast = 8

imgs = [sample_horse, to_zebra, sample_zebra, to_horse]
title = ['Horse', 'To Zebra', 'Zebra', 'To Horse']

for i in range(len(imgs)):
  plt.subplot(2, 2, i+1)
  plt.title(title[i])
  if i % 2 == 0:
    plt.imshow(imgs[i][0] * 0.5 + 0.5)
  else:
    plt.imshow(imgs[i][0] * 0.5 * contrast + 0.5)
plt.show()
 
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

png

 plt.figure(figsize=(8, 8))

plt.subplot(121)
plt.title('Is a real zebra?')
plt.imshow(discriminator_y(sample_zebra)[0, ..., -1], cmap='RdBu_r')

plt.subplot(122)
plt.title('Is a real horse?')
plt.imshow(discriminator_x(sample_horse)[0, ..., -1], cmap='RdBu_r')

plt.show()
 

png

Funções de perda

No CycleGAN, não há dados emparelhados para treinamento, portanto, não há garantia de que a entrada x e o par de destino y sejam significativos durante o treinamento. Assim, para garantir que a rede aprenda o mapeamento correto, os autores propõem a perda de consistência do ciclo.

A perda discriminadora e a perda do gerador são semelhantes às usadas no pix2pix .

 LAMBDA = 10
 
 loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)
 
 def discriminator_loss(real, generated):
  real_loss = loss_obj(tf.ones_like(real), real)

  generated_loss = loss_obj(tf.zeros_like(generated), generated)

  total_disc_loss = real_loss + generated_loss

  return total_disc_loss * 0.5
 
 def generator_loss(generated):
  return loss_obj(tf.ones_like(generated), generated)
 

Consistência do ciclo significa que o resultado deve estar próximo da entrada original. Por exemplo, se alguém traduz uma frase de inglês para francês e depois a traduz de francês para inglês, a frase resultante deve ser igual à frase original.

Na perda de consistência do ciclo,

  • A imagem $ X $ é passada pelo gerador $ G $ que gera a imagem gerada $ \ hat {Y} $.
  • A imagem gerada $ \ hat {Y} $ é passada pelo gerador $ F $ que gera a imagem ciclada $ \ hat {X} $.
  • O erro absoluto médio é calculado entre $ X $ e $ \ hat {X} $.
$$ frente \ ciclo \ consistência \ perda: X -> G (X) -> F (G (X)) \ sim \ hat {X} $$
$$ retrocesso \ ciclo \ consistência \ perda: Y -> F (Y) -> G (F (Y)) \ sim \ hat {Y} $$

Perda de ciclo

 def calc_cycle_loss(real_image, cycled_image):
  loss1 = tf.reduce_mean(tf.abs(real_image - cycled_image))
  
  return LAMBDA * loss1
 

Como mostrado acima, o gerador $ G $ é responsável por converter a imagem $ X $ na imagem $ Y $. A perda de identidade diz que, se você alimentou a imagem $ Y $ ao gerador $ G $, ela renderia a imagem real $ Y $ ou algo próximo à imagem $ Y $.

$$ Identidade \ perda = | G (Y) - Y | + | F (X) - X | $$
 def identity_loss(real_image, same_image):
  loss = tf.reduce_mean(tf.abs(real_image - same_image))
  return LAMBDA * 0.5 * loss
 

Inicialize os otimizadores para todos os geradores e discriminadores.

 generator_g_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
generator_f_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

discriminator_x_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_y_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
 

Pontos de verificação

 checkpoint_path = "./checkpoints/train"

ckpt = tf.train.Checkpoint(generator_g=generator_g,
                           generator_f=generator_f,
                           discriminator_x=discriminator_x,
                           discriminator_y=discriminator_y,
                           generator_g_optimizer=generator_g_optimizer,
                           generator_f_optimizer=generator_f_optimizer,
                           discriminator_x_optimizer=discriminator_x_optimizer,
                           discriminator_y_optimizer=discriminator_y_optimizer)

ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# if a checkpoint exists, restore the latest checkpoint.
if ckpt_manager.latest_checkpoint:
  ckpt.restore(ckpt_manager.latest_checkpoint)
  print ('Latest checkpoint restored!!')
 

Treinamento

 EPOCHS = 40
 
 def generate_images(model, test_input):
  prediction = model(test_input)
    
  plt.figure(figsize=(12, 12))

  display_list = [test_input[0], prediction[0]]
  title = ['Input Image', 'Predicted Image']

  for i in range(2):
    plt.subplot(1, 2, i+1)
    plt.title(title[i])
    # getting the pixel values between [0, 1] to plot it.
    plt.imshow(display_list[i] * 0.5 + 0.5)
    plt.axis('off')
  plt.show()
 

Embora o ciclo de treinamento pareça complicado, ele consiste em quatro etapas básicas:

  • Receba as previsões.
  • Calcular a perda.
  • Calcule os gradientes usando a retropropagação.
  • Aplique os gradientes ao otimizador.
 @tf.function
def train_step(real_x, real_y):
  # persistent is set to True because the tape is used more than
  # once to calculate the gradients.
  with tf.GradientTape(persistent=True) as tape:
    # Generator G translates X -> Y
    # Generator F translates Y -> X.
    
    fake_y = generator_g(real_x, training=True)
    cycled_x = generator_f(fake_y, training=True)

    fake_x = generator_f(real_y, training=True)
    cycled_y = generator_g(fake_x, training=True)

    # same_x and same_y are used for identity loss.
    same_x = generator_f(real_x, training=True)
    same_y = generator_g(real_y, training=True)

    disc_real_x = discriminator_x(real_x, training=True)
    disc_real_y = discriminator_y(real_y, training=True)

    disc_fake_x = discriminator_x(fake_x, training=True)
    disc_fake_y = discriminator_y(fake_y, training=True)

    # calculate the loss
    gen_g_loss = generator_loss(disc_fake_y)
    gen_f_loss = generator_loss(disc_fake_x)
    
    total_cycle_loss = calc_cycle_loss(real_x, cycled_x) + calc_cycle_loss(real_y, cycled_y)
    
    # Total generator loss = adversarial loss + cycle loss
    total_gen_g_loss = gen_g_loss + total_cycle_loss + identity_loss(real_y, same_y)
    total_gen_f_loss = gen_f_loss + total_cycle_loss + identity_loss(real_x, same_x)

    disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)
    disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)
  
  # Calculate the gradients for generator and discriminator
  generator_g_gradients = tape.gradient(total_gen_g_loss, 
                                        generator_g.trainable_variables)
  generator_f_gradients = tape.gradient(total_gen_f_loss, 
                                        generator_f.trainable_variables)
  
  discriminator_x_gradients = tape.gradient(disc_x_loss, 
                                            discriminator_x.trainable_variables)
  discriminator_y_gradients = tape.gradient(disc_y_loss, 
                                            discriminator_y.trainable_variables)
  
  # Apply the gradients to the optimizer
  generator_g_optimizer.apply_gradients(zip(generator_g_gradients, 
                                            generator_g.trainable_variables))

  generator_f_optimizer.apply_gradients(zip(generator_f_gradients, 
                                            generator_f.trainable_variables))
  
  discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients,
                                                discriminator_x.trainable_variables))
  
  discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients,
                                                discriminator_y.trainable_variables))
 
 for epoch in range(EPOCHS):
  start = time.time()

  n = 0
  for image_x, image_y in tf.data.Dataset.zip((train_horses, train_zebras)):
    train_step(image_x, image_y)
    if n % 10 == 0:
      print ('.', end='')
    n+=1

  clear_output(wait=True)
  # Using a consistent image (sample_horse) so that the progress of the model
  # is clearly visible.
  generate_images(generator_g, sample_horse)

  if (epoch + 1) % 5 == 0:
    ckpt_save_path = ckpt_manager.save()
    print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,
                                                         ckpt_save_path))

  print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
                                                      time.time()-start))
 

png

Saving checkpoint for epoch 40 at ./checkpoints/train/ckpt-8
Time taken for epoch 40 is 174.27903032302856 sec


Gerar usando o conjunto de dados de teste

 # Run the trained model on the test dataset
for inp in test_horses.take(5):
  generate_images(generator_g, inp)
 

png

png

png

png

png

Próximos passos

Este tutorial mostrou como implementar o CycleGAN a partir do gerador e discriminador implementado no tutorial do Pix2Pix . Como próxima etapa, você pode tentar usar um conjunto de dados diferente dos conjuntos de dados TensorFlow .

Você também pode treinar para um número maior de épocas para melhorar os resultados ou implementar o gerador ResNet modificado usado no documento, em vez do gerador U-Net usado aqui.