Google I / O kehrt vom 18. bis 20. Mai zurück! Reservieren Sie Platz und erstellen Sie Ihren Zeitplan Registrieren Sie sich jetzt
Diese Seite wurde von der Cloud Translation API übersetzt.
Switch to English

CycleGAN

Ansicht auf TensorFlow.org In Google Colab ausführen Quelle auf GitHub anzeigen Notizbuch herunterladen

Dieses Notizbuch demonstriert die ungepaarte Bild-zu-Bild-Übersetzung mit bedingten GANs , wie unter Ungepaarte Bild-zu-Bild-Übersetzung mit zykluskonsistenten kontradiktorischen Netzwerken , auch als CycleGAN bezeichnet, beschrieben. In diesem Artikel wird eine Methode vorgeschlagen, mit der die Merkmale einer Bilddomäne erfasst und herausgefunden werden können, wie diese Merkmale in eine andere Bilddomäne übersetzt werden können, wenn keine paarigen Trainingsbeispiele vorhanden sind.

In diesem Notizbuch wird davon ausgegangen, dass Sie mit Pix2Pix vertraut sind, das Sie im Pix2Pix-Tutorial kennenlernen können . Der Code für CycleGAN ist ähnlich, der Hauptunterschied ist eine zusätzliche Verlustfunktion und die Verwendung ungepaarter Trainingsdaten.

CycleGAN verwendet einen Zykluskonsistenzverlust, um das Training zu ermöglichen, ohne dass gepaarte Daten erforderlich sind. Mit anderen Worten, es kann von einer Domäne in eine andere übersetzt werden, ohne dass eine Eins-zu-Eins-Zuordnung zwischen der Quell- und der Zieldomäne erfolgt.

Dies eröffnet die Möglichkeit, viele interessante Aufgaben wie Fotoverbesserung, Bildfärbung, Stilübertragung usw. auszuführen. Sie benötigen lediglich den Quell- und den Zieldatensatz (der einfach ein Verzeichnis von Bildern ist).

Ausgabebild 1Ausgabebild 2

Richten Sie die Eingabepipeline ein

Installieren Sie das Paket tensorflow_examples , mit dem der Generator und der Diskriminator importiert werden können.

pip install -q git+https://github.com/tensorflow/examples.git
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow_examples.models.pix2pix import pix2pix

import os
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output

AUTOTUNE = tf.data.AUTOTUNE

Eingabe-Pipeline

In diesem Tutorial wird ein Modell trainiert, um Bilder von Pferden in Bilder von Zebras zu übersetzen. Diesen und ähnliche Datensätze finden Sie hier .

Wie in dem genannten Papier gelten zufällige Flackern und Spiegelung auf den Trainingsdaten. Dies sind einige der Bildvergrößerungstechniken, die eine Überanpassung vermeiden.

Dies ähnelt dem, was in pix2pix gemacht wurde

  • Bei zufälligem Jitter wird die Bildgröße auf 286 x 286 und dann zufällig auf 256 x 256 .
  • Bei der zufälligen Spiegelung wird das Bild zufällig horizontal gespiegelt, dh von links nach rechts.
dataset, metadata = tfds.load('cycle_gan/horse2zebra',
                              with_info=True, as_supervised=True)

train_horses, train_zebras = dataset['trainA'], dataset['trainB']
test_horses, test_zebras = dataset['testA'], dataset['testB']
BUFFER_SIZE = 1000
BATCH_SIZE = 1
IMG_WIDTH = 256
IMG_HEIGHT = 256
def random_crop(image):
  cropped_image = tf.image.random_crop(
      image, size=[IMG_HEIGHT, IMG_WIDTH, 3])

  return cropped_image
# normalizing the images to [-1, 1]
def normalize(image):
  image = tf.cast(image, tf.float32)
  image = (image / 127.5) - 1
  return image
def random_jitter(image):
  # resizing to 286 x 286 x 3
  image = tf.image.resize(image, [286, 286],
                          method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

  # randomly cropping to 256 x 256 x 3
  image = random_crop(image)

  # random mirroring
  image = tf.image.random_flip_left_right(image)

  return image
def preprocess_image_train(image, label):
  image = random_jitter(image)
  image = normalize(image)
  return image
def preprocess_image_test(image, label):
  image = normalize(image)
  return image
train_horses = train_horses.map(
    preprocess_image_train, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)

train_zebras = train_zebras.map(
    preprocess_image_train, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)

test_horses = test_horses.map(
    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)

test_zebras = test_zebras.map(
    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)
sample_horse = next(iter(train_horses))
sample_zebra = next(iter(train_zebras))
plt.subplot(121)
plt.title('Horse')
plt.imshow(sample_horse[0] * 0.5 + 0.5)

plt.subplot(122)
plt.title('Horse with random jitter')
plt.imshow(random_jitter(sample_horse[0]) * 0.5 + 0.5)
<matplotlib.image.AxesImage at 0x7f299b0f1eb8>

png

plt.subplot(121)
plt.title('Zebra')
plt.imshow(sample_zebra[0] * 0.5 + 0.5)

plt.subplot(122)
plt.title('Zebra with random jitter')
plt.imshow(random_jitter(sample_zebra[0]) * 0.5 + 0.5)
<matplotlib.image.AxesImage at 0x7f299af90fd0>

png

Importieren und verwenden Sie die Pix2Pix-Modelle erneut

Importieren Sie den in Pix2Pix verwendeten Generator und Diskriminator über das installierte Paket tensorflow_examples .

Die in diesem Lernprogramm verwendete Modellarchitektur ist der in pix2pix verwendeten sehr ähnlich. Einige der Unterschiede sind:

Hier werden 2 Generatoren (G und F) und 2 Diskriminatoren (X und Y) trainiert.

  • Generator G lernt, Bild X in Bild Y umzuwandeln. $ (G: X -> Y) $
  • Generator F lernt, Bild Y in Bild X umzuwandeln. $ (F: Y -> X) $
  • Der Diskriminator D_X lernt, zwischen Bild X und erzeugtem Bild X ( F(Y) ) zu unterscheiden.
  • Der Diskriminator D_Y lernt, zwischen Bild Y und erzeugtem Bild Y ( G(X) ) zu unterscheiden.

Cyclegan-Modell

OUTPUT_CHANNELS = 3

generator_g = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
generator_f = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')

discriminator_x = pix2pix.discriminator(norm_type='instancenorm', target=False)
discriminator_y = pix2pix.discriminator(norm_type='instancenorm', target=False)
to_zebra = generator_g(sample_horse)
to_horse = generator_f(sample_zebra)
plt.figure(figsize=(8, 8))
contrast = 8

imgs = [sample_horse, to_zebra, sample_zebra, to_horse]
title = ['Horse', 'To Zebra', 'Zebra', 'To Horse']

for i in range(len(imgs)):
  plt.subplot(2, 2, i+1)
  plt.title(title[i])
  if i % 2 == 0:
    plt.imshow(imgs[i][0] * 0.5 + 0.5)
  else:
    plt.imshow(imgs[i][0] * 0.5 * contrast + 0.5)
plt.show()
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

png

plt.figure(figsize=(8, 8))

plt.subplot(121)
plt.title('Is a real zebra?')
plt.imshow(discriminator_y(sample_zebra)[0, ..., -1], cmap='RdBu_r')

plt.subplot(122)
plt.title('Is a real horse?')
plt.imshow(discriminator_x(sample_horse)[0, ..., -1], cmap='RdBu_r')

plt.show()

png

Verlustfunktionen

In CycleGAN gibt es keine gepaarten Daten zum Trainieren, daher gibt es keine Garantie dafür, dass das Eingabe- x und das Ziel- y Paar während des Trainings von Bedeutung sind. Um zu erzwingen, dass das Netzwerk die richtige Zuordnung lernt, schlagen die Autoren den Verlust der Zykluskonsistenz vor.

Der Diskriminatorverlust und der Generatorverlust ähneln denen, die in pix2pix verwendet werden .

LAMBDA = 10
loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)
def discriminator_loss(real, generated):
  real_loss = loss_obj(tf.ones_like(real), real)

  generated_loss = loss_obj(tf.zeros_like(generated), generated)

  total_disc_loss = real_loss + generated_loss

  return total_disc_loss * 0.5
def generator_loss(generated):
  return loss_obj(tf.ones_like(generated), generated)

Zykluskonsistenz bedeutet, dass das Ergebnis nahe an der ursprünglichen Eingabe liegen sollte. Wenn man beispielsweise einen Satz vom Englischen ins Französische übersetzt und ihn dann vom Französischen ins Englische zurückübersetzt, sollte der resultierende Satz mit dem ursprünglichen Satz identisch sein.

Bei Verlust der Zykluskonsistenz

  • Das Bild $ X $ wird über den Generator $ G $ übergeben, der das erzeugte Bild $ \ hat {Y} $ liefert.
  • Das generierte Bild $ \ hat {Y} $ wird über den Generator $ F $ übergeben, der das zyklische Bild $ \ hat {X} $ liefert.
  • Der mittlere absolute Fehler wird zwischen $ X $ und $ \ hat {X} $ berechnet.
$$forward\ cycle\ consistency\ loss: X -> G(X) -> F(G(X)) \sim \hat{X}$$
$$backward\ cycle\ consistency\ loss: Y -> F(Y) -> G(F(Y)) \sim \hat{Y}$$

Zyklusverlust

def calc_cycle_loss(real_image, cycled_image):
  loss1 = tf.reduce_mean(tf.abs(real_image - cycled_image))

  return LAMBDA * loss1

Wie oben gezeigt, ist der Generator $ G $ für die Übersetzung des Bildes $ X $ in das Bild $ Y $ verantwortlich. Identitätsverlust besagt, dass, wenn Sie Bild $ Y $ dem Generator $ G $ zugeführt haben, es das reale Bild $ Y $ oder etwas in der Nähe von Bild $ Y $ ergeben sollte.

Wenn Sie das Zebra-zu-Pferd-Modell auf einem Pferd oder das Pferd-zu-Zebra-Modell auf einem Zebra ausführen, sollte das Bild nicht wesentlich geändert werden, da das Bild bereits die Zielklasse enthält.

$$Identity\ loss = |G(Y) - Y| + |F(X) - X|$$
def identity_loss(real_image, same_image):
  loss = tf.reduce_mean(tf.abs(real_image - same_image))
  return LAMBDA * 0.5 * loss

Initialisieren Sie die Optimierer für alle Generatoren und Diskriminatoren.

generator_g_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
generator_f_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

discriminator_x_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_y_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

Checkpoints

checkpoint_path = "./checkpoints/train"

ckpt = tf.train.Checkpoint(generator_g=generator_g,
                           generator_f=generator_f,
                           discriminator_x=discriminator_x,
                           discriminator_y=discriminator_y,
                           generator_g_optimizer=generator_g_optimizer,
                           generator_f_optimizer=generator_f_optimizer,
                           discriminator_x_optimizer=discriminator_x_optimizer,
                           discriminator_y_optimizer=discriminator_y_optimizer)

ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# if a checkpoint exists, restore the latest checkpoint.
if ckpt_manager.latest_checkpoint:
  ckpt.restore(ckpt_manager.latest_checkpoint)
  print ('Latest checkpoint restored!!')

Ausbildung

EPOCHS = 40
def generate_images(model, test_input):
  prediction = model(test_input)

  plt.figure(figsize=(12, 12))

  display_list = [test_input[0], prediction[0]]
  title = ['Input Image', 'Predicted Image']

  for i in range(2):
    plt.subplot(1, 2, i+1)
    plt.title(title[i])
    # getting the pixel values between [0, 1] to plot it.
    plt.imshow(display_list[i] * 0.5 + 0.5)
    plt.axis('off')
  plt.show()

Obwohl die Trainingsschleife kompliziert aussieht, besteht sie aus vier grundlegenden Schritten:

  • Holen Sie sich die Vorhersagen.
  • Berechnen Sie den Verlust.
  • Berechnen Sie die Verläufe mit Backpropagation.
  • Wenden Sie die Farbverläufe auf den Optimierer an.
@tf.function
def train_step(real_x, real_y):
  # persistent is set to True because the tape is used more than
  # once to calculate the gradients.
  with tf.GradientTape(persistent=True) as tape:
    # Generator G translates X -> Y
    # Generator F translates Y -> X.

    fake_y = generator_g(real_x, training=True)
    cycled_x = generator_f(fake_y, training=True)

    fake_x = generator_f(real_y, training=True)
    cycled_y = generator_g(fake_x, training=True)

    # same_x and same_y are used for identity loss.
    same_x = generator_f(real_x, training=True)
    same_y = generator_g(real_y, training=True)

    disc_real_x = discriminator_x(real_x, training=True)
    disc_real_y = discriminator_y(real_y, training=True)

    disc_fake_x = discriminator_x(fake_x, training=True)
    disc_fake_y = discriminator_y(fake_y, training=True)

    # calculate the loss
    gen_g_loss = generator_loss(disc_fake_y)
    gen_f_loss = generator_loss(disc_fake_x)

    total_cycle_loss = calc_cycle_loss(real_x, cycled_x) + calc_cycle_loss(real_y, cycled_y)

    # Total generator loss = adversarial loss + cycle loss
    total_gen_g_loss = gen_g_loss + total_cycle_loss + identity_loss(real_y, same_y)
    total_gen_f_loss = gen_f_loss + total_cycle_loss + identity_loss(real_x, same_x)

    disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)
    disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)

  # Calculate the gradients for generator and discriminator
  generator_g_gradients = tape.gradient(total_gen_g_loss, 
                                        generator_g.trainable_variables)
  generator_f_gradients = tape.gradient(total_gen_f_loss, 
                                        generator_f.trainable_variables)

  discriminator_x_gradients = tape.gradient(disc_x_loss, 
                                            discriminator_x.trainable_variables)
  discriminator_y_gradients = tape.gradient(disc_y_loss, 
                                            discriminator_y.trainable_variables)

  # Apply the gradients to the optimizer
  generator_g_optimizer.apply_gradients(zip(generator_g_gradients, 
                                            generator_g.trainable_variables))

  generator_f_optimizer.apply_gradients(zip(generator_f_gradients, 
                                            generator_f.trainable_variables))

  discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients,
                                                discriminator_x.trainable_variables))

  discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients,
                                                discriminator_y.trainable_variables))
for epoch in range(EPOCHS):
  start = time.time()

  n = 0
  for image_x, image_y in tf.data.Dataset.zip((train_horses, train_zebras)):
    train_step(image_x, image_y)
    if n % 10 == 0:
      print ('.', end='')
    n += 1

  clear_output(wait=True)
  # Using a consistent image (sample_horse) so that the progress of the model
  # is clearly visible.
  generate_images(generator_g, sample_horse)

  if (epoch + 1) % 5 == 0:
    ckpt_save_path = ckpt_manager.save()
    print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,
                                                         ckpt_save_path))

  print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
                                                      time.time()-start))

png

Saving checkpoint for epoch 40 at ./checkpoints/train/ckpt-8
Time taken for epoch 40 is 169.29227685928345 sec

Generieren Sie mit dem Testdatensatz

# Run the trained model on the test dataset
for inp in test_horses.take(5):
  generate_images(generator_g, inp)

png

png

png

png

png

Nächste Schritte

Dieses Tutorial hat gezeigt, wie CycleGAN ausgehend von dem im Pix2Pix- Tutorial implementierten Generator und Diskriminator implementiert wird. Als nächsten Schritt können Sie versuchen, ein anderes Dataset als TensorFlow-Datasets zu verwenden .

Sie können auch eine größere Anzahl von Epochen trainieren, um die Ergebnisse zu verbessern, oder Sie können den im Papier verwendeten modifizierten ResNet-Generator anstelle des hier verwendeten U-Net-Generators implementieren.