Halaman ini diterjemahkan oleh Cloud Translation API.
Switch to English

CycleGAN

Lihat di TensorFlow.org Jalankan di Google Colab Lihat sumber di GitHub Unduh buku catatan

Notebook ini mendemonstrasikan terjemahan gambar ke gambar yang tidak dipasangkan menggunakan GAN bersyarat, seperti yang dijelaskan dalam Terjemahan Gambar-ke-Gambar yang Tidak Dipasangkan menggunakan Jaringan Adversarial yang Konsisten Siklus , juga dikenal sebagai CycleGAN. Makalah ini mengusulkan metode yang dapat menangkap karakteristik satu domain gambar dan mencari tahu bagaimana karakteristik ini dapat diterjemahkan ke dalam domain gambar lain, semua tanpa adanya contoh pelatihan berpasangan.

Notebook ini menganggap Anda sudah familiar dengan Pix2Pix, yang dapat Anda pelajari di tutorial Pix2Pix . Kode untuk CycleGAN serupa, perbedaan utamanya adalah fungsi kerugian tambahan, dan penggunaan data pelatihan yang tidak berpasangan.

CycleGAN menggunakan kehilangan konsistensi siklus untuk mengaktifkan pelatihan tanpa memerlukan data berpasangan. Dengan kata lain, ini dapat menerjemahkan dari satu domain ke domain lainnya tanpa pemetaan satu-ke-satu antara domain sumber dan target.

Ini membuka kemungkinan untuk melakukan banyak tugas menarik seperti peningkatan foto, pewarnaan gambar, transfer gaya, dll. Yang Anda butuhkan hanyalah sumber dan kumpulan data target (yang hanya berupa direktori gambar).

Gambar Keluaran 1Gambar Keluaran 2

Siapkan pipeline masukan

Instal paket tensorflow_examples yang memungkinkan impor generator dan diskriminator.

pip install -q git+https://github.com/tensorflow/examples.git
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow_examples.models.pix2pix import pix2pix

import os
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output

AUTOTUNE = tf.data.experimental.AUTOTUNE

Pipa Masukan

Tutorial ini melatih model untuk menerjemahkan dari gambar kuda, ke gambar zebra. Anda dapat menemukan kumpulan data ini dan yang serupa di sini .

Seperti disebutkan di makalah , terapkan jittering dan pencerminan acak ke set data pelatihan. Ini adalah beberapa teknik augmentasi gambar yang menghindari overfitting.

Ini mirip dengan apa yang dilakukan di pix2pix

  • Dalam jittering acak, gambar diubah ukurannya menjadi 286 x 286 dan kemudian dipotong secara acak menjadi 256 x 256 .
  • Dalam pencerminan acak, gambar dibalik secara acak secara horizontal, yaitu dari kiri ke kanan.
dataset, metadata = tfds.load('cycle_gan/horse2zebra',
                              with_info=True, as_supervised=True)

train_horses, train_zebras = dataset['trainA'], dataset['trainB']
test_horses, test_zebras = dataset['testA'], dataset['testB']
Downloading and preparing dataset cycle_gan/horse2zebra/2.0.0 (download: 111.45 MiB, generated: Unknown size, total: 111.45 MiB) to /home/kbuilder/tensorflow_datasets/cycle_gan/horse2zebra/2.0.0...
Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/cycle_gan/horse2zebra/2.0.0.incompleteI5EPXN/cycle_gan-trainA.tfrecord
Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/cycle_gan/horse2zebra/2.0.0.incompleteI5EPXN/cycle_gan-trainB.tfrecord
Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/cycle_gan/horse2zebra/2.0.0.incompleteI5EPXN/cycle_gan-testA.tfrecord
Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/cycle_gan/horse2zebra/2.0.0.incompleteI5EPXN/cycle_gan-testB.tfrecord
Dataset cycle_gan downloaded and prepared to /home/kbuilder/tensorflow_datasets/cycle_gan/horse2zebra/2.0.0. Subsequent calls will reuse this data.

BUFFER_SIZE = 1000
BATCH_SIZE = 1
IMG_WIDTH = 256
IMG_HEIGHT = 256
def random_crop(image):
  cropped_image = tf.image.random_crop(
      image, size=[IMG_HEIGHT, IMG_WIDTH, 3])

  return cropped_image
# normalizing the images to [-1, 1]
def normalize(image):
  image = tf.cast(image, tf.float32)
  image = (image / 127.5) - 1
  return image
def random_jitter(image):
  # resizing to 286 x 286 x 3
  image = tf.image.resize(image, [286, 286],
                          method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

  # randomly cropping to 256 x 256 x 3
  image = random_crop(image)

  # random mirroring
  image = tf.image.random_flip_left_right(image)

  return image
def preprocess_image_train(image, label):
  image = random_jitter(image)
  image = normalize(image)
  return image
def preprocess_image_test(image, label):
  image = normalize(image)
  return image
train_horses = train_horses.map(
    preprocess_image_train, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(1)

train_zebras = train_zebras.map(
    preprocess_image_train, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(1)

test_horses = test_horses.map(
    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(1)

test_zebras = test_zebras.map(
    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(1)
sample_horse = next(iter(train_horses))
sample_zebra = next(iter(train_zebras))
plt.subplot(121)
plt.title('Horse')
plt.imshow(sample_horse[0] * 0.5 + 0.5)

plt.subplot(122)
plt.title('Horse with random jitter')
plt.imshow(random_jitter(sample_horse[0]) * 0.5 + 0.5)
<matplotlib.image.AxesImage at 0x7f0e929e6d68>

png

plt.subplot(121)
plt.title('Zebra')
plt.imshow(sample_zebra[0] * 0.5 + 0.5)

plt.subplot(122)
plt.title('Zebra with random jitter')
plt.imshow(random_jitter(sample_zebra[0]) * 0.5 + 0.5)
<matplotlib.image.AxesImage at 0x7f0e92900e48>

png

Impor dan gunakan kembali model Pix2Pix

Impor generator dan diskriminator yang digunakan di Pix2Pix melalui paket tensorflow_examples .

Arsitektur model yang digunakan dalam tutorial ini sangat mirip dengan yang digunakan di pix2pix . Beberapa perbedaannya adalah:

Ada 2 generator (G dan F) dan 2 diskriminator (X dan Y) sedang dilatih di sini.

  • Generator G belajar mengubah gambar X menjadi gambar Y $ (G: X -> Y) $
  • Generator F belajar mengubah gambar Y menjadi gambar X $ (F: Y -> X) $
  • Diskriminator D_X belajar membedakan antara gambar X dan gambar yang dihasilkan X ( F(Y) ).
  • Diskriminator D_Y belajar membedakan antara citra Y dan citra yang dihasilkan Y ( G(X) ).

Model Cyclegan

OUTPUT_CHANNELS = 3

generator_g = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
generator_f = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')

discriminator_x = pix2pix.discriminator(norm_type='instancenorm', target=False)
discriminator_y = pix2pix.discriminator(norm_type='instancenorm', target=False)
to_zebra = generator_g(sample_horse)
to_horse = generator_f(sample_zebra)
plt.figure(figsize=(8, 8))
contrast = 8

imgs = [sample_horse, to_zebra, sample_zebra, to_horse]
title = ['Horse', 'To Zebra', 'Zebra', 'To Horse']

for i in range(len(imgs)):
  plt.subplot(2, 2, i+1)
  plt.title(title[i])
  if i % 2 == 0:
    plt.imshow(imgs[i][0] * 0.5 + 0.5)
  else:
    plt.imshow(imgs[i][0] * 0.5 * contrast + 0.5)
plt.show()
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

png

plt.figure(figsize=(8, 8))

plt.subplot(121)
plt.title('Is a real zebra?')
plt.imshow(discriminator_y(sample_zebra)[0, ..., -1], cmap='RdBu_r')

plt.subplot(122)
plt.title('Is a real horse?')
plt.imshow(discriminator_x(sample_horse)[0, ..., -1], cmap='RdBu_r')

plt.show()

png

Fungsi kerugian

Di CycleGAN, tidak ada data berpasangan untuk dilatih, oleh karena itu tidak ada jaminan bahwa pasangan input x dan target y bermakna selama pelatihan. Jadi untuk menegakkan bahwa jaringan mempelajari pemetaan yang benar, penulis mengusulkan kehilangan konsistensi siklus.

Kerugian diskriminator dan kerugian generator mirip dengan yang digunakan di pix2pix .

LAMBDA = 10
loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)
def discriminator_loss(real, generated):
  real_loss = loss_obj(tf.ones_like(real), real)

  generated_loss = loss_obj(tf.zeros_like(generated), generated)

  total_disc_loss = real_loss + generated_loss

  return total_disc_loss * 0.5
def generator_loss(generated):
  return loss_obj(tf.ones_like(generated), generated)

Konsistensi siklus berarti hasil harus mendekati masukan asli. Misalnya, jika seseorang menerjemahkan kalimat dari bahasa Inggris ke bahasa Prancis, lalu menerjemahkannya kembali dari bahasa Prancis ke bahasa Inggris, maka kalimat yang dihasilkan harus sama dengan kalimat aslinya.

Dalam kehilangan konsistensi siklus,

  • Gambar $ X $ diteruskan melalui generator $ G $ yang menghasilkan gambar $ \ hat {Y} $.
  • Gambar yang dihasilkan $ \ hat {Y} $ diteruskan melalui generator $ F $ yang menghasilkan gambar siklus $ \ hat {X} $.
  • Kesalahan absolut rata-rata dihitung antara $ X $ dan $ \ hat {X} $.
$$forward\ cycle\ consistency\ loss: X -> G(X) -> F(G(X)) \sim \hat{X}$$
$$backward\ cycle\ consistency\ loss: Y -> F(Y) -> G(F(Y)) \sim \hat{Y}$$

Kehilangan siklus

def calc_cycle_loss(real_image, cycled_image):
  loss1 = tf.reduce_mean(tf.abs(real_image - cycled_image))
  
  return LAMBDA * loss1

Seperti yang ditunjukkan di atas, generator $ G $ bertanggung jawab untuk menerjemahkan gambar $ X $ ke gambar $ Y $. Kehilangan identitas mengatakan bahwa, jika Anda memasukkan gambar $ Y $ ke generator $ G $, itu akan menghasilkan gambar asli $ Y $ atau sesuatu yang mendekati gambar $ Y $.

Jika Anda menjalankan model zebra-to-horse pada kuda atau model kuda-ke-zebra pada zebra, model tersebut tidak boleh banyak memodifikasi gambar karena gambar tersebut sudah berisi kelas target.

$$Identity\ loss = |G(Y) - Y| + |F(X) - X|$$
def identity_loss(real_image, same_image):
  loss = tf.reduce_mean(tf.abs(real_image - same_image))
  return LAMBDA * 0.5 * loss

Inisialisasi pengoptimal untuk semua generator dan diskriminator.

generator_g_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
generator_f_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

discriminator_x_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_y_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

Pos pemeriksaan

checkpoint_path = "./checkpoints/train"

ckpt = tf.train.Checkpoint(generator_g=generator_g,
                           generator_f=generator_f,
                           discriminator_x=discriminator_x,
                           discriminator_y=discriminator_y,
                           generator_g_optimizer=generator_g_optimizer,
                           generator_f_optimizer=generator_f_optimizer,
                           discriminator_x_optimizer=discriminator_x_optimizer,
                           discriminator_y_optimizer=discriminator_y_optimizer)

ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# if a checkpoint exists, restore the latest checkpoint.
if ckpt_manager.latest_checkpoint:
  ckpt.restore(ckpt_manager.latest_checkpoint)
  print ('Latest checkpoint restored!!')

Latihan

EPOCHS = 40
def generate_images(model, test_input):
  prediction = model(test_input)
    
  plt.figure(figsize=(12, 12))

  display_list = [test_input[0], prediction[0]]
  title = ['Input Image', 'Predicted Image']

  for i in range(2):
    plt.subplot(1, 2, i+1)
    plt.title(title[i])
    # getting the pixel values between [0, 1] to plot it.
    plt.imshow(display_list[i] * 0.5 + 0.5)
    plt.axis('off')
  plt.show()

Meskipun loop pelatihan terlihat rumit, ini terdiri dari empat langkah dasar:

  • Dapatkan prediksinya.
  • Hitung kerugiannya.
  • Hitung gradien menggunakan propagasi mundur.
  • Terapkan gradien ke pengoptimal.
@tf.function
def train_step(real_x, real_y):
  # persistent is set to True because the tape is used more than
  # once to calculate the gradients.
  with tf.GradientTape(persistent=True) as tape:
    # Generator G translates X -> Y
    # Generator F translates Y -> X.
    
    fake_y = generator_g(real_x, training=True)
    cycled_x = generator_f(fake_y, training=True)

    fake_x = generator_f(real_y, training=True)
    cycled_y = generator_g(fake_x, training=True)

    # same_x and same_y are used for identity loss.
    same_x = generator_f(real_x, training=True)
    same_y = generator_g(real_y, training=True)

    disc_real_x = discriminator_x(real_x, training=True)
    disc_real_y = discriminator_y(real_y, training=True)

    disc_fake_x = discriminator_x(fake_x, training=True)
    disc_fake_y = discriminator_y(fake_y, training=True)

    # calculate the loss
    gen_g_loss = generator_loss(disc_fake_y)
    gen_f_loss = generator_loss(disc_fake_x)
    
    total_cycle_loss = calc_cycle_loss(real_x, cycled_x) + calc_cycle_loss(real_y, cycled_y)
    
    # Total generator loss = adversarial loss + cycle loss
    total_gen_g_loss = gen_g_loss + total_cycle_loss + identity_loss(real_y, same_y)
    total_gen_f_loss = gen_f_loss + total_cycle_loss + identity_loss(real_x, same_x)

    disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)
    disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)
  
  # Calculate the gradients for generator and discriminator
  generator_g_gradients = tape.gradient(total_gen_g_loss, 
                                        generator_g.trainable_variables)
  generator_f_gradients = tape.gradient(total_gen_f_loss, 
                                        generator_f.trainable_variables)
  
  discriminator_x_gradients = tape.gradient(disc_x_loss, 
                                            discriminator_x.trainable_variables)
  discriminator_y_gradients = tape.gradient(disc_y_loss, 
                                            discriminator_y.trainable_variables)
  
  # Apply the gradients to the optimizer
  generator_g_optimizer.apply_gradients(zip(generator_g_gradients, 
                                            generator_g.trainable_variables))

  generator_f_optimizer.apply_gradients(zip(generator_f_gradients, 
                                            generator_f.trainable_variables))
  
  discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients,
                                                discriminator_x.trainable_variables))
  
  discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients,
                                                discriminator_y.trainable_variables))
for epoch in range(EPOCHS):
  start = time.time()

  n = 0
  for image_x, image_y in tf.data.Dataset.zip((train_horses, train_zebras)):
    train_step(image_x, image_y)
    if n % 10 == 0:
      print ('.', end='')
    n+=1

  clear_output(wait=True)
  # Using a consistent image (sample_horse) so that the progress of the model
  # is clearly visible.
  generate_images(generator_g, sample_horse)

  if (epoch + 1) % 5 == 0:
    ckpt_save_path = ckpt_manager.save()
    print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,
                                                         ckpt_save_path))

  print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
                                                      time.time()-start))

png

Saving checkpoint for epoch 40 at ./checkpoints/train/ckpt-8
Time taken for epoch 40 is 174.97624921798706 sec


Hasilkan menggunakan set data pengujian

# Run the trained model on the test dataset
for inp in test_horses.take(5):
  generate_images(generator_g, inp)

png

png

png

png

png

Langkah selanjutnya

Tutorial ini telah menunjukkan bagaimana mengimplementasikan CycleGAN mulai dari generator dan diskriminator yang diimplementasikan dalam tutorial Pix2Pix . Sebagai langkah selanjutnya, Anda dapat mencoba menggunakan kumpulan data yang berbeda dari Kumpulan Data TensorFlow .

Anda juga dapat melatih jumlah epoch yang lebih besar untuk meningkatkan hasil, atau Anda dapat menerapkan generator ResNet yang dimodifikasi yang digunakan di koran daripada generator U-Net yang digunakan di sini.