หน้านี้ได้รับการแปลโดย Cloud Translation API
Switch to English

CycleGAN

ดูใน TensorFlow.org เรียกใช้ใน Google Colab ดูแหล่งที่มาบน GitHub ดาวน์โหลดสมุดบันทึก

สมุดบันทึกนี้แสดงภาพที่ไม่ได้จับคู่เป็นการแปลรูปภาพโดยใช้ GAN แบบมีเงื่อนไขตามที่อธิบายไว้ใน การแปลรูปภาพเป็นรูปภาพที่ไม่จับคู่โดยใช้ Cycle-Consistent Adversarial Networks หรือที่เรียกว่า CycleGAN บทความนี้นำเสนอวิธีการที่สามารถจับลักษณะของโดเมนรูปภาพหนึ่งและค้นหาว่าคุณสมบัติเหล่านี้สามารถแปลเป็นโดเมนรูปภาพอื่นได้อย่างไรทั้งหมดนี้ไม่มีตัวอย่างการฝึกอบรมที่จับคู่

สมุดบันทึกนี้ถือว่าคุณคุ้นเคยกับ Pix2Pix ซึ่งคุณสามารถเรียนรู้ได้จาก บทช่วยสอน Pix2Pix รหัสสำหรับ CycleGAN นั้นคล้ายคลึงกันข้อแตกต่างหลักคือฟังก์ชันการสูญเสียเพิ่มเติมและการใช้ข้อมูลการฝึกอบรมที่ไม่มีการจับคู่

CycleGAN ใช้การสูญเสียความสอดคล้องของวงจรเพื่อเปิดใช้งานการฝึกอบรมโดยไม่ต้องใช้ข้อมูลที่จับคู่ กล่าวอีกนัยหนึ่งก็คือสามารถแปลจากโดเมนหนึ่งไปยังอีกโดเมนหนึ่งได้โดยไม่ต้องมีการแมปแบบหนึ่งต่อหนึ่งระหว่างโดเมนต้นทางและโดเมนเป้าหมาย

สิ่งนี้จะเปิดโอกาสในการทำงานที่น่าสนใจมากมายเช่นการปรับแต่งรูปภาพการปรับสีภาพการถ่ายโอนสไตล์ ฯลฯ สิ่งที่คุณต้องมีคือแหล่งที่มาและชุดข้อมูลเป้าหมาย (ซึ่งเป็นเพียงไดเร็กทอรีของรูปภาพ)

ภาพที่ส่งออก 1ภาพที่ส่งออก 2

ตั้งค่าท่อส่งข้อมูล

ติดตั้งแพ็กเกจ tensorflow_examples ที่เปิดใช้งานการนำเข้าเครื่องกำเนิดไฟฟ้าและตัวแยกแยะ

pip install -q git+https://github.com/tensorflow/examples.git
WARNING: You are using pip version 20.2.2; however, version 20.2.3 is available.
You should consider upgrading via the '/tmpfs/src/tf_docs_env/bin/python -m pip install --upgrade pip' command.

import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow_examples.models.pix2pix import pix2pix

import os
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output

AUTOTUNE = tf.data.experimental.AUTOTUNE

อินพุตไปป์ไลน์

บทช่วยสอนนี้ฝึกโมเดลในการแปลจากภาพม้าเป็นภาพม้าลาย คุณสามารถค้นหาชุดข้อมูลนี้และชุดข้อมูลที่คล้ายกันได้ที่ นี่

ตามที่ระบุไว้ใน กระดาษ ให้ใช้การกระตุกแบบสุ่มและการสะท้อนกับชุดข้อมูลการฝึกอบรม ต่อไปนี้คือเทคนิคการเพิ่มรูปภาพบางส่วนที่หลีกเลี่ยงการฟิตติ้งมากเกินไป

สิ่งนี้คล้ายกับที่ทำใน pix2pix

  • ในการกระตุกแบบสุ่มภาพจะถูกปรับขนาดเป็น 286 x 286 แล้วครอบตัดแบบสุ่มเป็น 256 x 256
  • ในการมิเรอร์แบบสุ่มภาพจะสุ่มพลิกในแนวนอนเช่นซ้ายไปขวา
dataset, metadata = tfds.load('cycle_gan/horse2zebra',
                              with_info=True, as_supervised=True)

train_horses, train_zebras = dataset['trainA'], dataset['trainB']
test_horses, test_zebras = dataset['testA'], dataset['testB']
Downloading and preparing dataset cycle_gan/horse2zebra/2.0.0 (download: 111.45 MiB, generated: Unknown size, total: 111.45 MiB) to /home/kbuilder/tensorflow_datasets/cycle_gan/horse2zebra/2.0.0...
Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/cycle_gan/horse2zebra/2.0.0.incomplete3I3N98/cycle_gan-trainA.tfrecord
Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/cycle_gan/horse2zebra/2.0.0.incomplete3I3N98/cycle_gan-trainB.tfrecord
Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/cycle_gan/horse2zebra/2.0.0.incomplete3I3N98/cycle_gan-testA.tfrecord
Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/cycle_gan/horse2zebra/2.0.0.incomplete3I3N98/cycle_gan-testB.tfrecord
Dataset cycle_gan downloaded and prepared to /home/kbuilder/tensorflow_datasets/cycle_gan/horse2zebra/2.0.0. Subsequent calls will reuse this data.

BUFFER_SIZE = 1000
BATCH_SIZE = 1
IMG_WIDTH = 256
IMG_HEIGHT = 256
def random_crop(image):
  cropped_image = tf.image.random_crop(
      image, size=[IMG_HEIGHT, IMG_WIDTH, 3])

  return cropped_image
# normalizing the images to [-1, 1]
def normalize(image):
  image = tf.cast(image, tf.float32)
  image = (image / 127.5) - 1
  return image
def random_jitter(image):
  # resizing to 286 x 286 x 3
  image = tf.image.resize(image, [286, 286],
                          method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

  # randomly cropping to 256 x 256 x 3
  image = random_crop(image)

  # random mirroring
  image = tf.image.random_flip_left_right(image)

  return image
def preprocess_image_train(image, label):
  image = random_jitter(image)
  image = normalize(image)
  return image
def preprocess_image_test(image, label):
  image = normalize(image)
  return image
train_horses = train_horses.map(
    preprocess_image_train, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(1)

train_zebras = train_zebras.map(
    preprocess_image_train, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(1)

test_horses = test_horses.map(
    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(1)

test_zebras = test_zebras.map(
    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(1)
sample_horse = next(iter(train_horses))
sample_zebra = next(iter(train_zebras))
plt.subplot(121)
plt.title('Horse')
plt.imshow(sample_horse[0] * 0.5 + 0.5)

plt.subplot(122)
plt.title('Horse with random jitter')
plt.imshow(random_jitter(sample_horse[0]) * 0.5 + 0.5)
<matplotlib.image.AxesImage at 0x7f850280e048>

png

plt.subplot(121)
plt.title('Zebra')
plt.imshow(sample_zebra[0] * 0.5 + 0.5)

plt.subplot(122)
plt.title('Zebra with random jitter')
plt.imshow(random_jitter(sample_zebra[0]) * 0.5 + 0.5)
<matplotlib.image.AxesImage at 0x7f84b015e400>

png

นำเข้าและใช้โมเดล Pix2Pix ซ้ำ

นำเข้าเครื่องกำเนิดไฟฟ้าและตัวจำแนกที่ใช้ใน Pix2Pix ผ่านแพ็คเกจ tensorflow_examples ที่ ติดตั้ง

สถาปัตยกรรมโมเดลที่ใช้ในบทช่วยสอนนี้คล้ายกับที่ใช้ใน pix2pix มาก ความแตกต่างบางประการ ได้แก่ :

มีเครื่องกำเนิดไฟฟ้า 2 เครื่อง (G และ F) และเครื่องจำแนก (X และ Y) 2 เครื่องกำลังได้รับการฝึกฝนที่นี่

  • Generator G เรียนรู้ที่จะแปลงภาพ X เป็นภาพ Y $ (G: X -> Y) $
  • Generator F เรียนรู้ที่จะเปลี่ยนรูปภาพ Y เป็นรูปภาพ X $ (F: Y -> X) $
  • Discriminator D_X เรียนรู้ที่จะแยกความแตกต่างระหว่างภาพ X และภาพที่สร้างขึ้น X ( F(Y) )
  • Discriminator D_Y เรียนรู้ที่จะแยกความแตกต่างระหว่างภาพ Y และภาพที่สร้างขึ้น Y ( G(X) )

แบบจำลอง Cyclegan

OUTPUT_CHANNELS = 3

generator_g = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
generator_f = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')

discriminator_x = pix2pix.discriminator(norm_type='instancenorm', target=False)
discriminator_y = pix2pix.discriminator(norm_type='instancenorm', target=False)
to_zebra = generator_g(sample_horse)
to_horse = generator_f(sample_zebra)
plt.figure(figsize=(8, 8))
contrast = 8

imgs = [sample_horse, to_zebra, sample_zebra, to_horse]
title = ['Horse', 'To Zebra', 'Zebra', 'To Horse']

for i in range(len(imgs)):
  plt.subplot(2, 2, i+1)
  plt.title(title[i])
  if i % 2 == 0:
    plt.imshow(imgs[i][0] * 0.5 + 0.5)
  else:
    plt.imshow(imgs[i][0] * 0.5 * contrast + 0.5)
plt.show()
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

png

plt.figure(figsize=(8, 8))

plt.subplot(121)
plt.title('Is a real zebra?')
plt.imshow(discriminator_y(sample_zebra)[0, ..., -1], cmap='RdBu_r')

plt.subplot(122)
plt.title('Is a real horse?')
plt.imshow(discriminator_x(sample_horse)[0, ..., -1], cmap='RdBu_r')

plt.show()

png

ฟังก์ชั่นการสูญเสีย

ใน CycleGAN ไม่มีจับคู่ข้อมูลในการฝึกอบรมในจึงไม่มีการรับประกันว่าไม่มีการป้อนข้อมูล x และเป้าหมาย y ทั้งคู่มีความหมายระหว่างการฝึกอบรม ดังนั้นเพื่อบังคับให้เครือข่ายเรียนรู้การทำแผนที่ที่ถูกต้องผู้เขียนจึงเสนอการสูญเสียความสอดคล้องของวงจร

การสูญเสียตัวแยกแยะและการสูญเสียของเครื่องกำเนิดจะคล้ายกับที่ใช้ใน pix2pix

LAMBDA = 10
loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)
def discriminator_loss(real, generated):
  real_loss = loss_obj(tf.ones_like(real), real)

  generated_loss = loss_obj(tf.zeros_like(generated), generated)

  total_disc_loss = real_loss + generated_loss

  return total_disc_loss * 0.5
def generator_loss(generated):
  return loss_obj(tf.ones_like(generated), generated)

ความสม่ำเสมอของวงจรหมายความว่าผลลัพธ์ควรใกล้เคียงกับอินพุตต้นฉบับ ตัวอย่างเช่นถ้าใครแปลประโยคจากภาษาอังกฤษเป็นภาษาฝรั่งเศสแล้วแปลกลับจากภาษาฝรั่งเศสเป็นภาษาอังกฤษประโยคที่ได้ก็ควรจะเหมือนกับประโยคต้นฉบับ

ในการสูญเสียความสม่ำเสมอของวงจร

  • อิมเมจ $ X $ ถูกส่งผ่านตัวสร้าง $ G $ ที่ให้ภาพ $ \ hat {Y} $ ที่สร้างขึ้น
  • รูปภาพที่สร้างขึ้น $ \ hat {Y} $ จะถูกส่งผ่านเครื่องกำเนิดไฟฟ้า $ F $ ที่ให้ภาพที่วนรอบ $ \ hat {X} $
  • ค่าเฉลี่ยข้อผิดพลาดสัมบูรณ์คำนวณระหว่าง $ X $ และ $ \ hat {X} $
$$forward\ cycle\ consistency\ loss: X -> G(X) -> F(G(X)) \sim \hat{X}$$
$$backward\ cycle\ consistency\ loss: Y -> F(Y) -> G(F(Y)) \sim \hat{Y}$$

การสูญเสียวงจร

def calc_cycle_loss(real_image, cycled_image):
  loss1 = tf.reduce_mean(tf.abs(real_image - cycled_image))
  
  return LAMBDA * loss1

ดังที่แสดงไว้ด้านบนตัวสร้าง $ G $ รับผิดชอบในการแปลรูปภาพ $ X $ เป็นรูปภาพ $ Y $ การสูญเสียข้อมูลประจำตัวกล่าวว่าหากคุณป้อนรูปภาพ $ Y $ เพื่อสร้าง $ G $ มันควรจะให้ภาพจริง $ Y $ หรืออะไรที่ใกล้เคียงกับรูปภาพ $ Y $

$$Identity\ loss = |G(Y) - Y| + |F(X) - X|$$
def identity_loss(real_image, same_image):
  loss = tf.reduce_mean(tf.abs(real_image - same_image))
  return LAMBDA * 0.5 * loss

เริ่มต้นเครื่องมือเพิ่มประสิทธิภาพสำหรับเครื่องกำเนิดไฟฟ้าและตัวแยกแยะทั้งหมด

generator_g_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
generator_f_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

discriminator_x_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_y_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

จุดตรวจ

checkpoint_path = "./checkpoints/train"

ckpt = tf.train.Checkpoint(generator_g=generator_g,
                           generator_f=generator_f,
                           discriminator_x=discriminator_x,
                           discriminator_y=discriminator_y,
                           generator_g_optimizer=generator_g_optimizer,
                           generator_f_optimizer=generator_f_optimizer,
                           discriminator_x_optimizer=discriminator_x_optimizer,
                           discriminator_y_optimizer=discriminator_y_optimizer)

ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# if a checkpoint exists, restore the latest checkpoint.
if ckpt_manager.latest_checkpoint:
  ckpt.restore(ckpt_manager.latest_checkpoint)
  print ('Latest checkpoint restored!!')

การฝึกอบรม

EPOCHS = 40
def generate_images(model, test_input):
  prediction = model(test_input)
    
  plt.figure(figsize=(12, 12))

  display_list = [test_input[0], prediction[0]]
  title = ['Input Image', 'Predicted Image']

  for i in range(2):
    plt.subplot(1, 2, i+1)
    plt.title(title[i])
    # getting the pixel values between [0, 1] to plot it.
    plt.imshow(display_list[i] * 0.5 + 0.5)
    plt.axis('off')
  plt.show()

แม้ว่าลูปการฝึกจะดูซับซ้อน แต่ก็ประกอบด้วยขั้นตอนพื้นฐานสี่ขั้นตอน:

  • รับคำทำนาย
  • คำนวณการสูญเสีย
  • คำนวณการไล่ระดับสีโดยใช้ backpropagation
  • ใช้การไล่ระดับสีกับเครื่องมือเพิ่มประสิทธิภาพ
@tf.function
def train_step(real_x, real_y):
  # persistent is set to True because the tape is used more than
  # once to calculate the gradients.
  with tf.GradientTape(persistent=True) as tape:
    # Generator G translates X -> Y
    # Generator F translates Y -> X.
    
    fake_y = generator_g(real_x, training=True)
    cycled_x = generator_f(fake_y, training=True)

    fake_x = generator_f(real_y, training=True)
    cycled_y = generator_g(fake_x, training=True)

    # same_x and same_y are used for identity loss.
    same_x = generator_f(real_x, training=True)
    same_y = generator_g(real_y, training=True)

    disc_real_x = discriminator_x(real_x, training=True)
    disc_real_y = discriminator_y(real_y, training=True)

    disc_fake_x = discriminator_x(fake_x, training=True)
    disc_fake_y = discriminator_y(fake_y, training=True)

    # calculate the loss
    gen_g_loss = generator_loss(disc_fake_y)
    gen_f_loss = generator_loss(disc_fake_x)
    
    total_cycle_loss = calc_cycle_loss(real_x, cycled_x) + calc_cycle_loss(real_y, cycled_y)
    
    # Total generator loss = adversarial loss + cycle loss
    total_gen_g_loss = gen_g_loss + total_cycle_loss + identity_loss(real_y, same_y)
    total_gen_f_loss = gen_f_loss + total_cycle_loss + identity_loss(real_x, same_x)

    disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)
    disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)
  
  # Calculate the gradients for generator and discriminator
  generator_g_gradients = tape.gradient(total_gen_g_loss, 
                                        generator_g.trainable_variables)
  generator_f_gradients = tape.gradient(total_gen_f_loss, 
                                        generator_f.trainable_variables)
  
  discriminator_x_gradients = tape.gradient(disc_x_loss, 
                                            discriminator_x.trainable_variables)
  discriminator_y_gradients = tape.gradient(disc_y_loss, 
                                            discriminator_y.trainable_variables)
  
  # Apply the gradients to the optimizer
  generator_g_optimizer.apply_gradients(zip(generator_g_gradients, 
                                            generator_g.trainable_variables))

  generator_f_optimizer.apply_gradients(zip(generator_f_gradients, 
                                            generator_f.trainable_variables))
  
  discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients,
                                                discriminator_x.trainable_variables))
  
  discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients,
                                                discriminator_y.trainable_variables))
for epoch in range(EPOCHS):
  start = time.time()

  n = 0
  for image_x, image_y in tf.data.Dataset.zip((train_horses, train_zebras)):
    train_step(image_x, image_y)
    if n % 10 == 0:
      print ('.', end='')
    n+=1

  clear_output(wait=True)
  # Using a consistent image (sample_horse) so that the progress of the model
  # is clearly visible.
  generate_images(generator_g, sample_horse)

  if (epoch + 1) % 5 == 0:
    ckpt_save_path = ckpt_manager.save()
    print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,
                                                         ckpt_save_path))

  print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
                                                      time.time()-start))

png

Saving checkpoint for epoch 40 at ./checkpoints/train/ckpt-8
Time taken for epoch 40 is 174.7995924949646 sec


สร้างโดยใช้ชุดข้อมูลทดสอบ

# Run the trained model on the test dataset
for inp in test_horses.take(5):
  generate_images(generator_g, inp)

png

png

png

png

png

ขั้นตอนถัดไป

บทช่วยสอนนี้แสดงวิธีการใช้งาน CycleGAN โดยเริ่มจากตัวสร้างและตัวเลือกที่นำไปใช้ในบทช่วยสอน Pix2Pix ในขั้นตอนต่อไปคุณสามารถลองใช้ชุดข้อมูลอื่นจาก TensorFlow Datasets

นอกจากนี้คุณยังสามารถฝึกอบรมเป็นจำนวนมากเพื่อปรับปรุงผลลัพธ์หรือคุณสามารถใช้เครื่องกำเนิด ResNet ที่แก้ไขแล้วซึ่งใช้ใน กระดาษ แทนเครื่องกำเนิด U-Net ที่ใช้ที่นี่