Google I / O'daki önemli notları, ürün oturumlarını, atölyeleri ve daha fazlasını izleyin Oynatma listesine bakın

CycleGAN

TensorFlow.org'da görüntüleyin Google Colab'da çalıştırın Kaynağı GitHub'da görüntüleyin Defteri indirin

Bu not defteri, CycleGAN olarak da bilinen Cycle-Consistent Adversarial Networks kullanılarak Eşleştirilmemiş Görüntüden Görüntüye Çeviri'de açıklandığı gibi koşullu GAN'lar kullanılarak eşleştirilmemiş görüntüden görüntüye çevirisini gösterir. Makale, herhangi bir eşleştirilmiş eğitim örneğinin yokluğunda, bir görüntü alanının özelliklerini yakalayabilen ve bu özelliklerin başka bir görüntü alanına nasıl çevrilebileceğini çözebilen bir yöntem önermektedir.

Bu defter, Pix2Pix eğitiminde öğrenebileceğiniz Pix2Pix ile aşina olduğunuzu varsayar. CycleGAN kodu benzerdir, temel fark ek bir kayıp işlevi ve eşleşmemiş eğitim verilerinin kullanılmasıdır.

CycleGAN, eşleştirilmiş verilere ihtiyaç duymadan eğitimi etkinleştirmek için bir döngü tutarlılık kaybı kullanır. Başka bir deyişle, kaynak ve hedef alan arasında bire bir eşleştirme olmadan bir alandan diğerine çeviri yapabilir.

Bu, fotoğraf geliştirme, görüntü renklendirme, stil aktarımı vb. Gibi birçok ilginç görevi yerine getirme olanağını açar. İhtiyacınız olan tek şey kaynak ve hedef veri kümesidir (sadece bir görüntü dizini).

Çıktı Resmi 1Çıktı Resmi 2

Giriş ardışık düzenini ayarlayın

Oluşturucu ve diskriminatörün içe aktarılmasını sağlayan tensorflow_examples paketini kurun.

pip install -q git+https://github.com/tensorflow/examples.git
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow_examples.models.pix2pix import pix2pix

import os
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output

AUTOTUNE = tf.data.AUTOTUNE

Giriş Boru Hattı

Bu eğitim, atların resimlerinden zebraların resimlerine çevirmek için bir model eğitiyor. Bu veri setini ve benzerlerini burada bulabilirsiniz .

Makalede belirtildiği gibi, eğitim veri kümesine rastgele değişim ve yansıtma uygulayın. Bunlar, aşırı uydurmayı önleyen görüntü büyütme tekniklerinden bazılarıdır.

Bu pix2pix'te yapılana benzer

  • Rastgele geçişte, görüntü 286 x 286 olarak yeniden boyutlandırılır ve ardından rastgele 256 x 256 kırpılır.
  • Rastgele aynalamada, görüntü rasgele yatay olarak, yani soldan sağa çevrilir.
dataset, metadata = tfds.load('cycle_gan/horse2zebra',
                              with_info=True, as_supervised=True)

train_horses, train_zebras = dataset['trainA'], dataset['trainB']
test_horses, test_zebras = dataset['testA'], dataset['testB']
BUFFER_SIZE = 1000
BATCH_SIZE = 1
IMG_WIDTH = 256
IMG_HEIGHT = 256
def random_crop(image):
  cropped_image = tf.image.random_crop(
      image, size=[IMG_HEIGHT, IMG_WIDTH, 3])

  return cropped_image
# normalizing the images to [-1, 1]
def normalize(image):
  image = tf.cast(image, tf.float32)
  image = (image / 127.5) - 1
  return image
def random_jitter(image):
  # resizing to 286 x 286 x 3
  image = tf.image.resize(image, [286, 286],
                          method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

  # randomly cropping to 256 x 256 x 3
  image = random_crop(image)

  # random mirroring
  image = tf.image.random_flip_left_right(image)

  return image
def preprocess_image_train(image, label):
  image = random_jitter(image)
  image = normalize(image)
  return image
def preprocess_image_test(image, label):
  image = normalize(image)
  return image
train_horses = train_horses.map(
    preprocess_image_train, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)

train_zebras = train_zebras.map(
    preprocess_image_train, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)

test_horses = test_horses.map(
    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)

test_zebras = test_zebras.map(
    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)
sample_horse = next(iter(train_horses))
sample_zebra = next(iter(train_zebras))
plt.subplot(121)
plt.title('Horse')
plt.imshow(sample_horse[0] * 0.5 + 0.5)

plt.subplot(122)
plt.title('Horse with random jitter')
plt.imshow(random_jitter(sample_horse[0]) * 0.5 + 0.5)
<matplotlib.image.AxesImage at 0x7fe8ac184e10>

png

plt.subplot(121)
plt.title('Zebra')
plt.imshow(sample_zebra[0] * 0.5 + 0.5)

plt.subplot(122)
plt.title('Zebra with random jitter')
plt.imshow(random_jitter(sample_zebra[0]) * 0.5 + 0.5)
<matplotlib.image.AxesImage at 0x7fe8ac248550>

png

Pix2Pix modellerini içe aktarın ve yeniden kullanın

Pix2Pix'te kullanılan jeneratör ve ayırıcıyı kurulu tensorflow_examples paketi aracılığıyla içe aktarın.

Bu eğitimde kullanılan model mimarisi, pix2pix'te kullanılana çok benzer. Bazı farklılıklar şunlardır:

Burada eğitilen 2 jeneratör (G ve F) ve 2 ayırıcı (X ve Y) vardır.

  • Generator G , X görüntüsünü Y görüntüsüne dönüştürmeyi öğrenir. $ (G: X -> Y) $
  • Jeneratör F , Y görüntüsünü X görüntüsüne dönüştürmeyi öğrenir. $ (F: Y -> X) $
  • Ayırıcı D_X , görüntü X ile oluşturulan görüntü X ( F(Y) ) arasında D_X öğrenir.
  • Ayırıcı D_Y , Y görüntüsü ile oluşturulan Y görüntüsü ( G(X) ) arasında D_Y öğrenir.

Cyclegan modeli

OUTPUT_CHANNELS = 3

generator_g = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
generator_f = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')

discriminator_x = pix2pix.discriminator(norm_type='instancenorm', target=False)
discriminator_y = pix2pix.discriminator(norm_type='instancenorm', target=False)
to_zebra = generator_g(sample_horse)
to_horse = generator_f(sample_zebra)
plt.figure(figsize=(8, 8))
contrast = 8

imgs = [sample_horse, to_zebra, sample_zebra, to_horse]
title = ['Horse', 'To Zebra', 'Zebra', 'To Horse']

for i in range(len(imgs)):
  plt.subplot(2, 2, i+1)
  plt.title(title[i])
  if i % 2 == 0:
    plt.imshow(imgs[i][0] * 0.5 + 0.5)
  else:
    plt.imshow(imgs[i][0] * 0.5 * contrast + 0.5)
plt.show()
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

png

plt.figure(figsize=(8, 8))

plt.subplot(121)
plt.title('Is a real zebra?')
plt.imshow(discriminator_y(sample_zebra)[0, ..., -1], cmap='RdBu_r')

plt.subplot(122)
plt.title('Is a real horse?')
plt.imshow(discriminator_x(sample_horse)[0, ..., -1], cmap='RdBu_r')

plt.show()

png

Kayıp fonksiyonları

CycleGAN'da eğitilecek eşleştirilmiş veri yoktur, bu nedenle eğitim sırasında x girişi ve hedef y çiftinin anlamlı olacağının garantisi yoktur. Bu nedenle, ağın doğru eşlemeyi öğrenmesini sağlamak için yazarlar döngü tutarlılık kaybını önermektedir.

Ayırıcı kaybı ve jeneratör kaybı pix2pix'te kullanılanlara benzer.

LAMBDA = 10
loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)
def discriminator_loss(real, generated):
  real_loss = loss_obj(tf.ones_like(real), real)

  generated_loss = loss_obj(tf.zeros_like(generated), generated)

  total_disc_loss = real_loss + generated_loss

  return total_disc_loss * 0.5
def generator_loss(generated):
  return loss_obj(tf.ones_like(generated), generated)

Döngü tutarlılığı, sonucun orijinal girdiye yakın olması gerektiği anlamına gelir. Örneğin, bir cümle İngilizceden Fransızcaya çevrilirse ve daha sonra Fransızcadan İngilizceye çevrilirse, ortaya çıkan cümle orijinal cümle ile aynı olmalıdır.

Döngüde tutarlılık kaybı,

  • $ X $ görüntüsü, $ \ hat {Y} $ oluşturulmuş görüntüsünü veren $ G $ oluşturucu aracılığıyla geçirilir.
  • Oluşturulan $ \ hat {Y} $ görüntüsü, $ \ hat {X} $ döngülü görüntüsünü veren $ F $ oluşturucu aracılığıyla geçirilir.
  • Ortalama mutlak hata $ X $ ile $ \ hat {X} $ arasında hesaplanır.
$$forward\ cycle\ consistency\ loss: X -> G(X) -> F(G(X)) \sim \hat{X}$$
$$backward\ cycle\ consistency\ loss: Y -> F(Y) -> G(F(Y)) \sim \hat{Y}$$

Döngü kaybı

def calc_cycle_loss(real_image, cycled_image):
  loss1 = tf.reduce_mean(tf.abs(real_image - cycled_image))

  return LAMBDA * loss1

Yukarıda gösterildiği gibi, $ X $ görüntüsünü $ Y $ görüntüsüne çevirmekten $ G $ oluşturucu sorumludur. Kimlik kaybı, $ Y $ görüntüsünü $ G $ oluşturucuya beslediyseniz, $ Y $ gerçek görüntüsünü veya $ Y $ görüntüsüne yakın bir şeyi vermesi gerektiğini söylüyor.

Zebradan ata modeli bir at üzerinde veya attan zebra modelini bir zebra üzerinde çalıştırırsanız, görüntü zaten hedef sınıfı içerdiğinden görüntüyü fazla değiştirmemelidir.

$$Identity\ loss = |G(Y) - Y| + |F(X) - X|$$
def identity_loss(real_image, same_image):
  loss = tf.reduce_mean(tf.abs(real_image - same_image))
  return LAMBDA * 0.5 * loss

Tüm jeneratörler ve ayırıcılar için optimize edicileri başlatın.

generator_g_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
generator_f_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

discriminator_x_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_y_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

Kontrol noktaları

checkpoint_path = "./checkpoints/train"

ckpt = tf.train.Checkpoint(generator_g=generator_g,
                           generator_f=generator_f,
                           discriminator_x=discriminator_x,
                           discriminator_y=discriminator_y,
                           generator_g_optimizer=generator_g_optimizer,
                           generator_f_optimizer=generator_f_optimizer,
                           discriminator_x_optimizer=discriminator_x_optimizer,
                           discriminator_y_optimizer=discriminator_y_optimizer)

ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# if a checkpoint exists, restore the latest checkpoint.
if ckpt_manager.latest_checkpoint:
  ckpt.restore(ckpt_manager.latest_checkpoint)
  print ('Latest checkpoint restored!!')

Eğitim

EPOCHS = 40
def generate_images(model, test_input):
  prediction = model(test_input)

  plt.figure(figsize=(12, 12))

  display_list = [test_input[0], prediction[0]]
  title = ['Input Image', 'Predicted Image']

  for i in range(2):
    plt.subplot(1, 2, i+1)
    plt.title(title[i])
    # getting the pixel values between [0, 1] to plot it.
    plt.imshow(display_list[i] * 0.5 + 0.5)
    plt.axis('off')
  plt.show()

Eğitim döngüsü karmaşık görünse de dört temel adımdan oluşur:

  • Tahminleri alın.
  • Kaybı hesaplayın.
  • Geri yayılımı kullanarak degradeleri hesaplayın.
  • Renk geçişlerini optimize ediciye uygulayın.
@tf.function
def train_step(real_x, real_y):
  # persistent is set to True because the tape is used more than
  # once to calculate the gradients.
  with tf.GradientTape(persistent=True) as tape:
    # Generator G translates X -> Y
    # Generator F translates Y -> X.

    fake_y = generator_g(real_x, training=True)
    cycled_x = generator_f(fake_y, training=True)

    fake_x = generator_f(real_y, training=True)
    cycled_y = generator_g(fake_x, training=True)

    # same_x and same_y are used for identity loss.
    same_x = generator_f(real_x, training=True)
    same_y = generator_g(real_y, training=True)

    disc_real_x = discriminator_x(real_x, training=True)
    disc_real_y = discriminator_y(real_y, training=True)

    disc_fake_x = discriminator_x(fake_x, training=True)
    disc_fake_y = discriminator_y(fake_y, training=True)

    # calculate the loss
    gen_g_loss = generator_loss(disc_fake_y)
    gen_f_loss = generator_loss(disc_fake_x)

    total_cycle_loss = calc_cycle_loss(real_x, cycled_x) + calc_cycle_loss(real_y, cycled_y)

    # Total generator loss = adversarial loss + cycle loss
    total_gen_g_loss = gen_g_loss + total_cycle_loss + identity_loss(real_y, same_y)
    total_gen_f_loss = gen_f_loss + total_cycle_loss + identity_loss(real_x, same_x)

    disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)
    disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)

  # Calculate the gradients for generator and discriminator
  generator_g_gradients = tape.gradient(total_gen_g_loss, 
                                        generator_g.trainable_variables)
  generator_f_gradients = tape.gradient(total_gen_f_loss, 
                                        generator_f.trainable_variables)

  discriminator_x_gradients = tape.gradient(disc_x_loss, 
                                            discriminator_x.trainable_variables)
  discriminator_y_gradients = tape.gradient(disc_y_loss, 
                                            discriminator_y.trainable_variables)

  # Apply the gradients to the optimizer
  generator_g_optimizer.apply_gradients(zip(generator_g_gradients, 
                                            generator_g.trainable_variables))

  generator_f_optimizer.apply_gradients(zip(generator_f_gradients, 
                                            generator_f.trainable_variables))

  discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients,
                                                discriminator_x.trainable_variables))

  discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients,
                                                discriminator_y.trainable_variables))
for epoch in range(EPOCHS):
  start = time.time()

  n = 0
  for image_x, image_y in tf.data.Dataset.zip((train_horses, train_zebras)):
    train_step(image_x, image_y)
    if n % 10 == 0:
      print ('.', end='')
    n += 1

  clear_output(wait=True)
  # Using a consistent image (sample_horse) so that the progress of the model
  # is clearly visible.
  generate_images(generator_g, sample_horse)

  if (epoch + 1) % 5 == 0:
    ckpt_save_path = ckpt_manager.save()
    print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,
                                                         ckpt_save_path))

  print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
                                                      time.time()-start))

png

Saving checkpoint for epoch 40 at ./checkpoints/train/ckpt-8
Time taken for epoch 40 is 167.14784979820251 sec

Test veri kümesini kullanarak oluşturun

# Run the trained model on the test dataset
for inp in test_horses.take(5):
  generate_images(generator_g, inp)

png

png

png

png

png

Sonraki adımlar

Bu eğitim, Pix2Pix eğitiminde uygulanan jeneratör ve ayırıcıdan başlayarak CycleGAN'ın nasıl uygulanacağını göstermiştir. Bir sonraki adım olarak, TensorFlow Veri Kümelerinden farklı bir veri kümesi kullanmayı deneyebilirsiniz.

Ayrıca sonuçları iyileştirmek için daha fazla sayıda dönem için eğitim yapabilirsiniz veya burada kullanılan U-Net oluşturucu yerine kağıtta kullanılan değiştirilmiş ResNet oluşturucuyu uygulayabilirsiniz.