עזרה להגן על שונית המחסום הגדולה עם TensorFlow על Kaggle הצטרפו אתגר

CycleGAN

הצג באתר TensorFlow.org הפעל בגוגל קולאב צפה במקור ב-GitHub הורד מחברת

מחברת זו ממחישה תמונה מזווגת תרגום תמונה באמצעות גאן המותנה, כמתואר תרגום Image-to-התמונה מזווג באמצעות רשתות מחזור-עקב תשובה: , הידוע גם בשם CycleGAN. המאמר מציע שיטה שיכולה ללכוד את המאפיינים של תחום תמונה אחד ולהבין כיצד ניתן לתרגם את המאפיינים הללו לתחום תמונה אחר, כל זאת בהיעדר דוגמאות להכשרה זוגיות.

מחברת זו מניחה שאתה מכיר Pix2Pix, שבו אתה יכול ללמוד על ב הדרכת Pix2Pix . הקוד של CycleGAN דומה, ההבדל העיקרי הוא פונקציית אובדן נוספת, ושימוש בנתוני אימון לא מזווגים.

CycleGAN משתמש באובדן עקביות במחזור כדי לאפשר אימון ללא צורך בנתונים מותאמים. במילים אחרות, הוא יכול לתרגם מתחום אחד לאחר ללא מיפוי אחד לאחד בין תחום המקור לתחום היעד.

זה פותח את האפשרות לבצע הרבה משימות מעניינות כמו שיפור תמונות, צביעת תמונה, העברת סגנון וכו'. כל מה שאתה צריך זה את המקור ואת מערך הנתונים של היעד (שהוא פשוט ספריית תמונות).

תמונת פלט 1תמונת פלט 2

הגדר את צינור הקלט

התקן את tensorflow_examples החבילה המאפשרת יבוא של הגנרטור ואת המאבחן.

pip install git+https://github.com/tensorflow/examples.git
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow_examples.models.pix2pix import pix2pix

import os
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output

AUTOTUNE = tf.data.AUTOTUNE

צינור קלט

הדרכה זו מכשירה מודל לתרגם מתמונות של סוסים, לתמונות של זברות. אתה יכול למצוא במערך הזה ודומים כאן .

כפי שהוזכר נייר , להחיל jittering אקראי שיקוף אל במערך ההדרכה. אלו הן חלק מטכניקות הגדלת התמונה המונעות התאמה יתר.

זה דומה למה שנעשה pix2pix

  • בשנת jittering אקראית, התמונה משתנה כדי 286 x 286 ואז קצוץ באופן אקראי 256 x 256 .
  • בשיקוף אקראי, התמונה מתהפכת באופן אקראי אופקית כלומר משמאל לימין.
dataset, metadata = tfds.load('cycle_gan/horse2zebra',
                              with_info=True, as_supervised=True)

train_horses, train_zebras = dataset['trainA'], dataset['trainB']
test_horses, test_zebras = dataset['testA'], dataset['testB']
BUFFER_SIZE = 1000
BATCH_SIZE = 1
IMG_WIDTH = 256
IMG_HEIGHT = 256
def random_crop(image):
  cropped_image = tf.image.random_crop(
      image, size=[IMG_HEIGHT, IMG_WIDTH, 3])

  return cropped_image
# normalizing the images to [-1, 1]
def normalize(image):
  image = tf.cast(image, tf.float32)
  image = (image / 127.5) - 1
  return image
def random_jitter(image):
  # resizing to 286 x 286 x 3
  image = tf.image.resize(image, [286, 286],
                          method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

  # randomly cropping to 256 x 256 x 3
  image = random_crop(image)

  # random mirroring
  image = tf.image.random_flip_left_right(image)

  return image
def preprocess_image_train(image, label):
  image = random_jitter(image)
  image = normalize(image)
  return image
def preprocess_image_test(image, label):
  image = normalize(image)
  return image
train_horses = train_horses.cache().map(
    preprocess_image_train, num_parallel_calls=AUTOTUNE).shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)

train_zebras = train_zebras.cache().map(
    preprocess_image_train, num_parallel_calls=AUTOTUNE).shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)

test_horses = test_horses.map(
    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)

test_zebras = test_zebras.map(
    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)
sample_horse = next(iter(train_horses))
sample_zebra = next(iter(train_zebras))
plt.subplot(121)
plt.title('Horse')
plt.imshow(sample_horse[0] * 0.5 + 0.5)

plt.subplot(122)
plt.title('Horse with random jitter')
plt.imshow(random_jitter(sample_horse[0]) * 0.5 + 0.5)
<matplotlib.image.AxesImage at 0x7fd518202090>

png

plt.subplot(121)
plt.title('Zebra')
plt.imshow(sample_zebra[0] * 0.5 + 0.5)

plt.subplot(122)
plt.title('Zebra with random jitter')
plt.imshow(random_jitter(sample_zebra[0]) * 0.5 + 0.5)
<matplotlib.image.AxesImage at 0x7fd5107cea90>

png

יבוא ועשה שימוש חוזר בדגמי Pix2Pix

ייבא את הגנרטור ואת המאבחן המשמש Pix2Pix דרך מותקן tensorflow_examples החבילה.

ארכיטקטורת המודל המשמשת הדרכה זו היא מאוד דומה למה שמש pix2pix . חלק מההבדלים הם:

יש כאן 2 גנרטורים (G ו-F) ו-2 מפלים (X ו-Y).

  • מחולל G לומד לשנות את התמונה X לתמונה Y . \((G: X -> Y)\)
  • מחולל F לומד לשנות את התמונה Y לתמונה X . \((F: Y -> X)\)
  • מאבחן D_X לומד להבחין בין התמונה X ותמונה שנוצר X ( F(Y) ).
  • מאבחן D_Y לומד להבחין בין התמונה Y ו- שנוצר דימוי Y ( G(X) ).

דגם Cyclegan

OUTPUT_CHANNELS = 3

generator_g = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
generator_f = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')

discriminator_x = pix2pix.discriminator(norm_type='instancenorm', target=False)
discriminator_y = pix2pix.discriminator(norm_type='instancenorm', target=False)
to_zebra = generator_g(sample_horse)
to_horse = generator_f(sample_zebra)
plt.figure(figsize=(8, 8))
contrast = 8

imgs = [sample_horse, to_zebra, sample_zebra, to_horse]
title = ['Horse', 'To Zebra', 'Zebra', 'To Horse']

for i in range(len(imgs)):
  plt.subplot(2, 2, i+1)
  plt.title(title[i])
  if i % 2 == 0:
    plt.imshow(imgs[i][0] * 0.5 + 0.5)
  else:
    plt.imshow(imgs[i][0] * 0.5 * contrast + 0.5)
plt.show()
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

png

plt.figure(figsize=(8, 8))

plt.subplot(121)
plt.title('Is a real zebra?')
plt.imshow(discriminator_y(sample_zebra)[0, ..., -1], cmap='RdBu_r')

plt.subplot(122)
plt.title('Is a real horse?')
plt.imshow(discriminator_x(sample_horse)[0, ..., -1], cmap='RdBu_r')

plt.show()

png

פונקציות הפסד

בשנת CycleGAN, אין הוא זיווג נתונים להכשיר על, ומכאן שאין כל ערובה כי קלט x ואת היעד y זוג בעלי משמעות במהלך האימונים. לפיכך, על מנת לאכוף שהרשת תלמד את המיפוי הנכון, הכותבים מציעים את אובדן עקביות המחזור.

ההפסד מאבחן ואת אובדן גנרטור דומה לאלה המשמשים pix2pix .

LAMBDA = 10
loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)
def discriminator_loss(real, generated):
  real_loss = loss_obj(tf.ones_like(real), real)

  generated_loss = loss_obj(tf.zeros_like(generated), generated)

  total_disc_loss = real_loss + generated_loss

  return total_disc_loss * 0.5
def generator_loss(generated):
  return loss_obj(tf.ones_like(generated), generated)

עקביות מחזור פירושה שהתוצאה צריכה להיות קרובה לקלט המקורי. לדוגמה, אם מתרגמים משפט מאנגלית לצרפתית, ואז מתרגמים אותו בחזרה מצרפתית לאנגלית, אז המשפט המתקבל צריך להיות זהה למשפט המקורי.

באובדן עקביות במחזור,

  • תמונת \(X\) מועבר באמצעות גנרטור \(G\) כי תשואות שנוצרו דימוי \(\hat{Y}\).
  • תמונה Generated \(\hat{Y}\) מועבר באמצעות גנרטור \(F\) כי תשואות אופניו תמונה \(\hat{X}\).
  • שגיאה מוחלטת ממוצעת מחושבת בין \(X\) ו \(\hat{X}\).

\[forward\ cycle\ consistency\ loss: X -> G(X) -> F(G(X)) \sim \hat{X}\]

\[backward\ cycle\ consistency\ loss: Y -> F(Y) -> G(F(Y)) \sim \hat{Y}\]

אובדן מחזור

def calc_cycle_loss(real_image, cycled_image):
  loss1 = tf.reduce_mean(tf.abs(real_image - cycled_image))

  return LAMBDA * loss1

כפי שהראה, גנרטור \(G\) אחראי תרגום התמונה \(X\) לתמונה \(Y\). אובדן זהות, אומר כי, אם מאכילים תמונה \(Y\) כדי גנרטור \(G\), זה אמור להניב את התמונה האמיתית \(Y\) קרוב או משהו תמונת \(Y\).

אם אתה מריץ את המודל של זברה לסוס על סוס או את המודל של סוס לזברה על זברה, זה לא צריך לשנות את התמונה הרבה מכיוון שהתמונה כבר מכילה את מחלקת היעד.

\[Identity\ loss = |G(Y) - Y| + |F(X) - X|\]

def identity_loss(real_image, same_image):
  loss = tf.reduce_mean(tf.abs(real_image - same_image))
  return LAMBDA * 0.5 * loss

אתחל את המיטובים עבור כל המחוללים והמפלים.

generator_g_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
generator_f_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

discriminator_x_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_y_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

מחסומים

checkpoint_path = "./checkpoints/train"

ckpt = tf.train.Checkpoint(generator_g=generator_g,
                           generator_f=generator_f,
                           discriminator_x=discriminator_x,
                           discriminator_y=discriminator_y,
                           generator_g_optimizer=generator_g_optimizer,
                           generator_f_optimizer=generator_f_optimizer,
                           discriminator_x_optimizer=discriminator_x_optimizer,
                           discriminator_y_optimizer=discriminator_y_optimizer)

ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# if a checkpoint exists, restore the latest checkpoint.
if ckpt_manager.latest_checkpoint:
  ckpt.restore(ckpt_manager.latest_checkpoint)
  print ('Latest checkpoint restored!!')

הַדְרָכָה

EPOCHS = 40
def generate_images(model, test_input):
  prediction = model(test_input)

  plt.figure(figsize=(12, 12))

  display_list = [test_input[0], prediction[0]]
  title = ['Input Image', 'Predicted Image']

  for i in range(2):
    plt.subplot(1, 2, i+1)
    plt.title(title[i])
    # getting the pixel values between [0, 1] to plot it.
    plt.imshow(display_list[i] * 0.5 + 0.5)
    plt.axis('off')
  plt.show()

למרות שלולאת האימון נראית מסובכת, היא מורכבת מארבעה שלבים בסיסיים:

  • קבלו את התחזיות.
  • חשב את ההפסד.
  • חשב את השיפועים באמצעות התפשטות לאחור.
  • החל את ההדרגות על כלי האופטימיזציה.
@tf.function
def train_step(real_x, real_y):
  # persistent is set to True because the tape is used more than
  # once to calculate the gradients.
  with tf.GradientTape(persistent=True) as tape:
    # Generator G translates X -> Y
    # Generator F translates Y -> X.

    fake_y = generator_g(real_x, training=True)
    cycled_x = generator_f(fake_y, training=True)

    fake_x = generator_f(real_y, training=True)
    cycled_y = generator_g(fake_x, training=True)

    # same_x and same_y are used for identity loss.
    same_x = generator_f(real_x, training=True)
    same_y = generator_g(real_y, training=True)

    disc_real_x = discriminator_x(real_x, training=True)
    disc_real_y = discriminator_y(real_y, training=True)

    disc_fake_x = discriminator_x(fake_x, training=True)
    disc_fake_y = discriminator_y(fake_y, training=True)

    # calculate the loss
    gen_g_loss = generator_loss(disc_fake_y)
    gen_f_loss = generator_loss(disc_fake_x)

    total_cycle_loss = calc_cycle_loss(real_x, cycled_x) + calc_cycle_loss(real_y, cycled_y)

    # Total generator loss = adversarial loss + cycle loss
    total_gen_g_loss = gen_g_loss + total_cycle_loss + identity_loss(real_y, same_y)
    total_gen_f_loss = gen_f_loss + total_cycle_loss + identity_loss(real_x, same_x)

    disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)
    disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)

  # Calculate the gradients for generator and discriminator
  generator_g_gradients = tape.gradient(total_gen_g_loss, 
                                        generator_g.trainable_variables)
  generator_f_gradients = tape.gradient(total_gen_f_loss, 
                                        generator_f.trainable_variables)

  discriminator_x_gradients = tape.gradient(disc_x_loss, 
                                            discriminator_x.trainable_variables)
  discriminator_y_gradients = tape.gradient(disc_y_loss, 
                                            discriminator_y.trainable_variables)

  # Apply the gradients to the optimizer
  generator_g_optimizer.apply_gradients(zip(generator_g_gradients, 
                                            generator_g.trainable_variables))

  generator_f_optimizer.apply_gradients(zip(generator_f_gradients, 
                                            generator_f.trainable_variables))

  discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients,
                                                discriminator_x.trainable_variables))

  discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients,
                                                discriminator_y.trainable_variables))
for epoch in range(EPOCHS):
  start = time.time()

  n = 0
  for image_x, image_y in tf.data.Dataset.zip((train_horses, train_zebras)):
    train_step(image_x, image_y)
    if n % 10 == 0:
      print ('.', end='')
    n += 1

  clear_output(wait=True)
  # Using a consistent image (sample_horse) so that the progress of the model
  # is clearly visible.
  generate_images(generator_g, sample_horse)

  if (epoch + 1) % 5 == 0:
    ckpt_save_path = ckpt_manager.save()
    print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,
                                                         ckpt_save_path))

  print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
                                                      time.time()-start))

png

Saving checkpoint for epoch 40 at ./checkpoints/train/ckpt-8
Time taken for epoch 40 is 166.58266592025757 sec

צור באמצעות מערך נתונים לבדיקה

# Run the trained model on the test dataset
for inp in test_horses.take(5):
  generate_images(generator_g, inp)

png

png

png

png

png

הצעדים הבאים

הדרכה זו מלמדת כיצד ליישם CycleGAN החל גנרטור ו מאבחן מיושם Pix2Pix הדרכה. כצעד הבא, אתה יכול לנסות להשתמש במערך שונה מערכי נתונים TensorFlow .

אתה גם יכול לאמן עבור מספר גדול יותר של תקופות לשפר את התוצאות, או שאתה יכול ליישם את הגנרטור ResNet השונה בשימוש בנייר במקום גנרטור U-Net משמש כאן.