דף זה תורגם על ידי Cloud Translation API.
Switch to English

CycleGAN

צפה ב- TensorFlow.org הפעל בגוגל קולאב צפה במקור ב- GitHub הורד מחברת

מחברת זו מדגימה תמונה לא מזווגת לתרגום תמונות באמצעות GAN מותנה, כמתואר בתרגום תמונה-ל-תמונה לא מזווג באמצעות רשתות יריבות מחזוריות, המכונות גם CycleGAN. העיתון מציע שיטה שתוכל ללכוד את המאפיינים של תחום תמונה אחד ולהבין כיצד ניתן לתרגם מאפיינים אלה לתחום תמונה אחר, כל זאת בהיעדר דוגמאות אימון זוגיות.

מחברת זו מניחה שאתה מכיר את Pix2Pix, שעליו תוכל ללמוד במדריך Pix2Pix . הקוד ל- CycleGAN דומה, ההבדל העיקרי הוא פונקציית אובדן נוספת ושימוש בנתוני אימון לא מזווגים.

CycleGAN משתמש באובדן עקביות של מחזור כדי לאפשר אימונים ללא צורך בנתונים זוגיים. במילים אחרות, זה יכול לתרגם מדומיין אחד לאחר ללא מיפוי אחד לאחד בין המקור לדומיין היעד.

זה פותח את האפשרות לבצע הרבה משימות מעניינות כמו שיפור תמונות, צבעי תמונה, העברת סגנונות וכו 'כל מה שאתה צריך זה המקור ומערך היעד (שהוא פשוט ספריה של תמונות).

תמונת פלט 1תמונת פלט 2

הגדר את צינור הקלט

התקן את חבילת tensorflow_examples המאפשרת ייבוא ​​של הגנרטור והמפלה.

pip install -q git+https://github.com/tensorflow/examples.git
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow_examples.models.pix2pix import pix2pix

import os
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output

AUTOTUNE = tf.data.AUTOTUNE

צינור קלט

הדרכה זו מכשירה מודל לתרגום מתמונות של סוסים, לתמונות של זברות. תוכל למצוא את מערך הנתונים הזה ודומים לו כאן .

כאמור בעיתון , החל רטט ושיקוף אקראי על מערך האימונים. אלה חלק מטכניקות הגדלת התמונה המונעות התאמת יתר.

זה דומה למה שנעשה ב- pix2pix

  • ברטט אקראי, גודל התמונה משתנה ל 286 x 286 ואז נחתך באופן אקראי ל 256 x 256 .
  • בשיקוף אקראי, התמונה הופכת באקראי אופקית כלומר משמאל לימין.
dataset, metadata = tfds.load('cycle_gan/horse2zebra',
                              with_info=True, as_supervised=True)

train_horses, train_zebras = dataset['trainA'], dataset['trainB']
test_horses, test_zebras = dataset['testA'], dataset['testB']
Downloading and preparing dataset 111.45 MiB (download: 111.45 MiB, generated: Unknown size, total: 111.45 MiB) to /home/kbuilder/tensorflow_datasets/cycle_gan/horse2zebra/2.0.0...
Dataset cycle_gan downloaded and prepared to /home/kbuilder/tensorflow_datasets/cycle_gan/horse2zebra/2.0.0. Subsequent calls will reuse this data.

BUFFER_SIZE = 1000
BATCH_SIZE = 1
IMG_WIDTH = 256
IMG_HEIGHT = 256
def random_crop(image):
  cropped_image = tf.image.random_crop(
      image, size=[IMG_HEIGHT, IMG_WIDTH, 3])

  return cropped_image
# normalizing the images to [-1, 1]
def normalize(image):
  image = tf.cast(image, tf.float32)
  image = (image / 127.5) - 1
  return image
def random_jitter(image):
  # resizing to 286 x 286 x 3
  image = tf.image.resize(image, [286, 286],
                          method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

  # randomly cropping to 256 x 256 x 3
  image = random_crop(image)

  # random mirroring
  image = tf.image.random_flip_left_right(image)

  return image
def preprocess_image_train(image, label):
  image = random_jitter(image)
  image = normalize(image)
  return image
def preprocess_image_test(image, label):
  image = normalize(image)
  return image
train_horses = train_horses.map(
    preprocess_image_train, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(1)

train_zebras = train_zebras.map(
    preprocess_image_train, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(1)

test_horses = test_horses.map(
    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(1)

test_zebras = test_zebras.map(
    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(1)
sample_horse = next(iter(train_horses))
sample_zebra = next(iter(train_zebras))
plt.subplot(121)
plt.title('Horse')
plt.imshow(sample_horse[0] * 0.5 + 0.5)

plt.subplot(122)
plt.title('Horse with random jitter')
plt.imshow(random_jitter(sample_horse[0]) * 0.5 + 0.5)
<matplotlib.image.AxesImage at 0x7fb9682d80f0>

png

plt.subplot(121)
plt.title('Zebra')
plt.imshow(sample_zebra[0] * 0.5 + 0.5)

plt.subplot(122)
plt.title('Zebra with random jitter')
plt.imshow(random_jitter(sample_zebra[0]) * 0.5 + 0.5)
<matplotlib.image.AxesImage at 0x7fb968237208>

png

יבוא ושימוש חוזר בדגמי Pix2Pix

ייבא את הגנרטור ואת המפלה המשמשים ב- Pix2Pix באמצעות חבילת tensorflow_examples המותקנת.

ארכיטקטורת המודל המשמשת במדריך זה דומה מאוד למה ששימש ב- pix2pix . חלק מההבדלים הם:

יש כאן 2 גנרטורים (G ו- F) ו -2 מפלים (X ו- Y).

  • המחולל G לומד להפוך את התמונה X לתמונה Y $ (G: X -> Y) $
  • המחולל F לומד להפוך את התמונה Y לתמונה X $ (F: Y -> X) $
  • המפלה D_X לומד להבדיל בין תמונה X לתמונה שנוצרה X ( F(Y) ).
  • D_Y לומד להבדיל בין תמונה Y לתמונה שנוצרה Y ( G(X) ).

מודל סייקלגן

OUTPUT_CHANNELS = 3

generator_g = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
generator_f = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')

discriminator_x = pix2pix.discriminator(norm_type='instancenorm', target=False)
discriminator_y = pix2pix.discriminator(norm_type='instancenorm', target=False)
to_zebra = generator_g(sample_horse)
to_horse = generator_f(sample_zebra)
plt.figure(figsize=(8, 8))
contrast = 8

imgs = [sample_horse, to_zebra, sample_zebra, to_horse]
title = ['Horse', 'To Zebra', 'Zebra', 'To Horse']

for i in range(len(imgs)):
  plt.subplot(2, 2, i+1)
  plt.title(title[i])
  if i % 2 == 0:
    plt.imshow(imgs[i][0] * 0.5 + 0.5)
  else:
    plt.imshow(imgs[i][0] * 0.5 * contrast + 0.5)
plt.show()
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

png

plt.figure(figsize=(8, 8))

plt.subplot(121)
plt.title('Is a real zebra?')
plt.imshow(discriminator_y(sample_zebra)[0, ..., -1], cmap='RdBu_r')

plt.subplot(122)
plt.title('Is a real horse?')
plt.imshow(discriminator_x(sample_horse)[0, ..., -1], cmap='RdBu_r')

plt.show()

png

פונקציות אובדן

ב- CycleGAN אין נתונים זוגיים להתאמן עליהם, ולכן אין שום ערובה לכך שהקלט x וצמד ה y הם בעלי משמעות במהלך האימון. לכן על מנת לאכוף כי הרשת לומדת את המיפוי הנכון, הכותבים מציעים את אובדן עקביות המחזור.

אובדן המפלה ואובדן הגנרטור דומים לאלה המשמשים ב- pix2pix .

LAMBDA = 10
loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)
def discriminator_loss(real, generated):
  real_loss = loss_obj(tf.ones_like(real), real)

  generated_loss = loss_obj(tf.zeros_like(generated), generated)

  total_disc_loss = real_loss + generated_loss

  return total_disc_loss * 0.5
def generator_loss(generated):
  return loss_obj(tf.ones_like(generated), generated)

עקביות מחזור פירושה שהתוצאה צריכה להיות קרובה לקלט המקורי. לדוגמא, אם מתרגמים משפט מאנגלית לצרפתית ואז מתרגמים אותו חזרה מצרפתית לאנגלית, אז המשפט שהתקבל צריך להיות זהה למשפט המקורי.

באובדן עקביות מחזור,

  • תמונה $ X $ מועברת באמצעות גנרטור $ G $ שמניב תמונה שנוצרה $ \ hat {Y} $.
  • תמונה שנוצרה $ \ hat {Y} $ מועברת באמצעות גנרטור $ F $ שמניב תמונה מחזורית $ \ hat {X} $.
  • שגיאה מוחלטת ממוצעת מחושבת בין $ X $ ל- $ \ hat {X} $.
$$forward\ cycle\ consistency\ loss: X -> G(X) -> F(G(X)) \sim \hat{X}$$
$$backward\ cycle\ consistency\ loss: Y -> F(Y) -> G(F(Y)) \sim \hat{Y}$$

אובדן מחזור

def calc_cycle_loss(real_image, cycled_image):
  loss1 = tf.reduce_mean(tf.abs(real_image - cycled_image))

  return LAMBDA * loss1

כפי שמוצג לעיל, הגנרטור $ G $ אחראי לתרגום התמונה $ X $ לתמונה $ Y $. אובדן זהות אומר שאם הזנת תמונה $ Y $ לגנרטור $ G $, זה אמור להניב את התמונה האמיתית $ Y $ או משהו קרוב לתמונה $ Y $.

אם אתה מפעיל את דגם הזברה לסוס על סוס או את הדגם סוס לזברה על הזברה, זה לא אמור לשנות את התמונה הרבה מכיוון שהתמונה כבר מכילה את מחלקת היעד.

$$Identity\ loss = |G(Y) - Y| + |F(X) - X|$$
def identity_loss(real_image, same_image):
  loss = tf.reduce_mean(tf.abs(real_image - same_image))
  return LAMBDA * 0.5 * loss

אתחל את האופטימיזציה לכל הגנרטורים והמפלים.

generator_g_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
generator_f_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

discriminator_x_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_y_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

מחסומים

checkpoint_path = "./checkpoints/train"

ckpt = tf.train.Checkpoint(generator_g=generator_g,
                           generator_f=generator_f,
                           discriminator_x=discriminator_x,
                           discriminator_y=discriminator_y,
                           generator_g_optimizer=generator_g_optimizer,
                           generator_f_optimizer=generator_f_optimizer,
                           discriminator_x_optimizer=discriminator_x_optimizer,
                           discriminator_y_optimizer=discriminator_y_optimizer)

ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# if a checkpoint exists, restore the latest checkpoint.
if ckpt_manager.latest_checkpoint:
  ckpt.restore(ckpt_manager.latest_checkpoint)
  print ('Latest checkpoint restored!!')

הַדְרָכָה

EPOCHS = 40
def generate_images(model, test_input):
  prediction = model(test_input)

  plt.figure(figsize=(12, 12))

  display_list = [test_input[0], prediction[0]]
  title = ['Input Image', 'Predicted Image']

  for i in range(2):
    plt.subplot(1, 2, i+1)
    plt.title(title[i])
    # getting the pixel values between [0, 1] to plot it.
    plt.imshow(display_list[i] * 0.5 + 0.5)
    plt.axis('off')
  plt.show()

למרות שלולאת האימונים נראית מסובכת, היא מורכבת מארבעה שלבים בסיסיים:

  • קבל את התחזיות.
  • חשב את ההפסד.
  • חישוב הדרגתיות באמצעות backpropagation.
  • החל את הדרגתיות על האופטימיזציה.
@tf.function
def train_step(real_x, real_y):
  # persistent is set to True because the tape is used more than
  # once to calculate the gradients.
  with tf.GradientTape(persistent=True) as tape:
    # Generator G translates X -> Y
    # Generator F translates Y -> X.

    fake_y = generator_g(real_x, training=True)
    cycled_x = generator_f(fake_y, training=True)

    fake_x = generator_f(real_y, training=True)
    cycled_y = generator_g(fake_x, training=True)

    # same_x and same_y are used for identity loss.
    same_x = generator_f(real_x, training=True)
    same_y = generator_g(real_y, training=True)

    disc_real_x = discriminator_x(real_x, training=True)
    disc_real_y = discriminator_y(real_y, training=True)

    disc_fake_x = discriminator_x(fake_x, training=True)
    disc_fake_y = discriminator_y(fake_y, training=True)

    # calculate the loss
    gen_g_loss = generator_loss(disc_fake_y)
    gen_f_loss = generator_loss(disc_fake_x)

    total_cycle_loss = calc_cycle_loss(real_x, cycled_x) + calc_cycle_loss(real_y, cycled_y)

    # Total generator loss = adversarial loss + cycle loss
    total_gen_g_loss = gen_g_loss + total_cycle_loss + identity_loss(real_y, same_y)
    total_gen_f_loss = gen_f_loss + total_cycle_loss + identity_loss(real_x, same_x)

    disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)
    disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)

  # Calculate the gradients for generator and discriminator
  generator_g_gradients = tape.gradient(total_gen_g_loss, 
                                        generator_g.trainable_variables)
  generator_f_gradients = tape.gradient(total_gen_f_loss, 
                                        generator_f.trainable_variables)

  discriminator_x_gradients = tape.gradient(disc_x_loss, 
                                            discriminator_x.trainable_variables)
  discriminator_y_gradients = tape.gradient(disc_y_loss, 
                                            discriminator_y.trainable_variables)

  # Apply the gradients to the optimizer
  generator_g_optimizer.apply_gradients(zip(generator_g_gradients, 
                                            generator_g.trainable_variables))

  generator_f_optimizer.apply_gradients(zip(generator_f_gradients, 
                                            generator_f.trainable_variables))

  discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients,
                                                discriminator_x.trainable_variables))

  discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients,
                                                discriminator_y.trainable_variables))
for epoch in range(EPOCHS):
  start = time.time()

  n = 0
  for image_x, image_y in tf.data.Dataset.zip((train_horses, train_zebras)):
    train_step(image_x, image_y)
    if n % 10 == 0:
      print ('.', end='')
    n+=1

  clear_output(wait=True)
  # Using a consistent image (sample_horse) so that the progress of the model
  # is clearly visible.
  generate_images(generator_g, sample_horse)

  if (epoch + 1) % 5 == 0:
    ckpt_save_path = ckpt_manager.save()
    print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,
                                                         ckpt_save_path))

  print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
                                                      time.time()-start))

png

Saving checkpoint for epoch 40 at ./checkpoints/train/ckpt-8
Time taken for epoch 40 is 169.26107788085938 sec


צור באמצעות מערך בדיקה

# Run the trained model on the test dataset
for inp in test_horses.take(5):
  generate_images(generator_g, inp)

png

png

png

png

png

הצעדים הבאים

מדריך זה הראה כיצד ליישם את CycleGAN החל מהגנרטור והמאבחן המיושם במדריך Pix2Pix . כשלב הבא, תוכל לנסות להשתמש במערך נתונים שונה ממערכי הנתונים של TensorFlow .

אתה יכול גם להתאמן למספר רב יותר של תקופות לשיפור התוצאות, או ליישם את מחולל ResNet ששונה בשימוש בעיתון במקום מחולל ה- U-Net המשמש כאן.