Giúp bảo vệ Great Barrier Reef với TensorFlow trên Kaggle Tham Challenge

CycleGAN

Xem trên TensorFlow.org Chạy trong Google Colab Xem nguồn trên GitHub Tải xuống sổ ghi chép

Máy tính xách tay này cho thấy hình ảnh lẻ để dịch hình ảnh sử dụng có điều kiện GAN, như được mô tả trong lẻ Dịch Hình ảnh-to-Image sử dụng Cycle-Phù hợp gây tranh cãi Networks , còn được gọi là CycleGAN. Bài báo đề xuất một phương pháp có thể nắm bắt các đặc điểm của một miền ảnh và tìm ra cách những đặc điểm này có thể được dịch sang miền ảnh khác, tất cả đều không có bất kỳ ví dụ đào tạo ghép nối nào.

Máy tính xách tay này giả định bạn đã quen thuộc với Pix2Pix, mà bạn có thể tìm hiểu về trong hướng dẫn Pix2Pix . Mã cho CycleGAN cũng tương tự, sự khác biệt chính là một chức năng mất bổ sung và việc sử dụng dữ liệu đào tạo chưa được ghép nối.

CycleGAN sử dụng mất tính nhất quán chu kỳ để cho phép đào tạo mà không cần dữ liệu được ghép nối. Nói cách khác, nó có thể dịch từ miền này sang miền khác mà không cần ánh xạ 1-1 giữa miền nguồn và miền đích.

Điều này mở ra khả năng thực hiện nhiều tác vụ thú vị như chỉnh sửa ảnh, chỉnh màu ảnh, chuyển kiểu, v.v. Tất cả những gì bạn cần là nguồn và tập dữ liệu đích (đơn giản là một thư mục hình ảnh).

Hình ảnh đầu ra 1Hình ảnh đầu ra 2

Thiết lập đường dẫn đầu vào

Cài đặt tensorflow_examples gói cho phép nhập khẩu của máy phát điện và bộ phân biệt.

pip install git+https://github.com/tensorflow/examples.git
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow_examples.models.pix2pix import pix2pix

import os
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output

AUTOTUNE = tf.data.AUTOTUNE

Đường ống đầu vào

Hướng dẫn này đào tạo một mô hình để dịch từ hình ảnh ngựa sang hình ảnh ngựa vằn. Bạn có thể tìm thấy dữ liệu này và những người tương tự ở đây .

Như đã đề cập trong bài báo , áp dụng jittering ngẫu nhiên và mirroring cho tập dữ liệu huấn luyện. Đây là một số kỹ thuật nâng cao hình ảnh để tránh trang bị quá nhiều.

Điều này tương tự với những gì đã được thực hiện trong pix2pix

  • Trong jittering ngẫu nhiên, hình ảnh được thay đổi kích cỡ để 286 x 286 và sau đó cắt một cách ngẫu nhiên để 256 x 256 .
  • Trong phản chiếu ngẫu nhiên, hình ảnh được lật ngẫu nhiên theo chiều ngang, tức là từ trái sang phải.
dataset, metadata = tfds.load('cycle_gan/horse2zebra',
                              with_info=True, as_supervised=True)

train_horses, train_zebras = dataset['trainA'], dataset['trainB']
test_horses, test_zebras = dataset['testA'], dataset['testB']
BUFFER_SIZE = 1000
BATCH_SIZE = 1
IMG_WIDTH = 256
IMG_HEIGHT = 256
def random_crop(image):
  cropped_image = tf.image.random_crop(
      image, size=[IMG_HEIGHT, IMG_WIDTH, 3])

  return cropped_image
# normalizing the images to [-1, 1]
def normalize(image):
  image = tf.cast(image, tf.float32)
  image = (image / 127.5) - 1
  return image
def random_jitter(image):
  # resizing to 286 x 286 x 3
  image = tf.image.resize(image, [286, 286],
                          method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

  # randomly cropping to 256 x 256 x 3
  image = random_crop(image)

  # random mirroring
  image = tf.image.random_flip_left_right(image)

  return image
def preprocess_image_train(image, label):
  image = random_jitter(image)
  image = normalize(image)
  return image
def preprocess_image_test(image, label):
  image = normalize(image)
  return image
train_horses = train_horses.cache().map(
    preprocess_image_train, num_parallel_calls=AUTOTUNE).shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)

train_zebras = train_zebras.cache().map(
    preprocess_image_train, num_parallel_calls=AUTOTUNE).shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)

test_horses = test_horses.map(
    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)

test_zebras = test_zebras.map(
    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)
sample_horse = next(iter(train_horses))
sample_zebra = next(iter(train_zebras))
plt.subplot(121)
plt.title('Horse')
plt.imshow(sample_horse[0] * 0.5 + 0.5)

plt.subplot(122)
plt.title('Horse with random jitter')
plt.imshow(random_jitter(sample_horse[0]) * 0.5 + 0.5)
<matplotlib.image.AxesImage at 0x7fd518202090>

png

plt.subplot(121)
plt.title('Zebra')
plt.imshow(sample_zebra[0] * 0.5 + 0.5)

plt.subplot(122)
plt.title('Zebra with random jitter')
plt.imshow(random_jitter(sample_zebra[0]) * 0.5 + 0.5)
<matplotlib.image.AxesImage at 0x7fd5107cea90>

png

Nhập và sử dụng lại các mô hình Pix2Pix

Nhập khẩu các máy phát điện và bộ phân biệt được sử dụng trong Pix2Pix thông qua cài đặt tensorflow_examples gói.

Kiến trúc mô hình được sử dụng trong hướng dẫn này rất giống với những gì được sử dụng trong pix2pix . Một số khác biệt là:

Có 2 bộ tạo (G và F) và 2 bộ phân biệt (X và Y) đang được đào tạo ở đây.

  • Máy phát điện G học để chuyển đổi hình ảnh X để hình ảnh Y . \((G: X -> Y)\)
  • Generator F học để chuyển đổi hình ảnh Y để hình ảnh X . \((F: Y -> X)\)
  • Phân biệt D_X học để phân biệt giữa hình ảnh X và tạo ra hình ảnh X ( F(Y) ).
  • Phân biệt D_Y học để phân biệt giữa hình ảnh Y và hình ảnh được tạo ra Y ( G(X) ).

Mô hình cyclegan

OUTPUT_CHANNELS = 3

generator_g = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
generator_f = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')

discriminator_x = pix2pix.discriminator(norm_type='instancenorm', target=False)
discriminator_y = pix2pix.discriminator(norm_type='instancenorm', target=False)
to_zebra = generator_g(sample_horse)
to_horse = generator_f(sample_zebra)
plt.figure(figsize=(8, 8))
contrast = 8

imgs = [sample_horse, to_zebra, sample_zebra, to_horse]
title = ['Horse', 'To Zebra', 'Zebra', 'To Horse']

for i in range(len(imgs)):
  plt.subplot(2, 2, i+1)
  plt.title(title[i])
  if i % 2 == 0:
    plt.imshow(imgs[i][0] * 0.5 + 0.5)
  else:
    plt.imshow(imgs[i][0] * 0.5 * contrast + 0.5)
plt.show()
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

png

plt.figure(figsize=(8, 8))

plt.subplot(121)
plt.title('Is a real zebra?')
plt.imshow(discriminator_y(sample_zebra)[0, ..., -1], cmap='RdBu_r')

plt.subplot(122)
plt.title('Is a real horse?')
plt.imshow(discriminator_x(sample_horse)[0, ..., -1], cmap='RdBu_r')

plt.show()

png

Mất chức năng

Trong CycleGAN, không có cặp dữ liệu để đào tạo trên, do đó không có đảm bảo rằng các đầu vào x và mục tiêu y cặp có ý nghĩa trong quá trình đào tạo. Do đó, để thực thi rằng mạng học được ánh xạ chính xác, các tác giả đề xuất sự mất nhất quán chu trình.

Sự mất mát phân biệt và sự mất mát máy phát điện cũng tương tự như những người sử dụng trong pix2pix .

LAMBDA = 10
loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)
def discriminator_loss(real, generated):
  real_loss = loss_obj(tf.ones_like(real), real)

  generated_loss = loss_obj(tf.zeros_like(generated), generated)

  total_disc_loss = real_loss + generated_loss

  return total_disc_loss * 0.5
def generator_loss(generated):
  return loss_obj(tf.ones_like(generated), generated)

Tính nhất quán của chu kỳ có nghĩa là kết quả phải gần với đầu vào ban đầu. Ví dụ: nếu một người dịch một câu từ tiếng Anh sang tiếng Pháp, và sau đó dịch ngược lại từ tiếng Pháp sang tiếng Anh, thì câu kết quả phải giống với câu gốc.

Trong chu kỳ mất nhất quán,

  • Hình ảnh \(X\) được truyền qua máy phát điện \(G\) rằng sản lượng tạo ra hình ảnh \(\hat{Y}\).
  • Tạo ra hình ảnh \(\hat{Y}\) được truyền qua máy phát điện \(F\) rằng sản lượng xoay vòng hình ảnh \(\hat{X}\).
  • Sai số tuyệt đối trung bình được tính toán giữa \(X\) và \(\hat{X}\).

\[forward\ cycle\ consistency\ loss: X -> G(X) -> F(G(X)) \sim \hat{X}\]

\[backward\ cycle\ consistency\ loss: Y -> F(Y) -> G(F(Y)) \sim \hat{Y}\]

Mất chu kỳ

def calc_cycle_loss(real_image, cycled_image):
  loss1 = tf.reduce_mean(tf.abs(real_image - cycled_image))

  return LAMBDA * loss1

Như đã trình bày ở trên, máy phát điện \(G\) có trách nhiệm chuyển hình ảnh \(X\) để hình ảnh \(Y\). Mất bản sắc nói rằng, nếu bạn ăn ảnh \(Y\) để phát \(G\), cần mang lại hình ảnh thực \(Y\) hoặc một cái gì đó gần gũi với hình ảnh \(Y\).

Nếu bạn chạy mô hình ngựa vằn trên ngựa hoặc mô hình ngựa vằn trên ngựa vằn, nó không nên sửa đổi hình ảnh nhiều vì hình ảnh đã chứa lớp đích.

\[Identity\ loss = |G(Y) - Y| + |F(X) - X|\]

def identity_loss(real_image, same_image):
  loss = tf.reduce_mean(tf.abs(real_image - same_image))
  return LAMBDA * 0.5 * loss

Khởi tạo trình tối ưu hóa cho tất cả các trình tạo và trình phân biệt.

generator_g_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
generator_f_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

discriminator_x_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_y_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

Trạm kiểm soát

checkpoint_path = "./checkpoints/train"

ckpt = tf.train.Checkpoint(generator_g=generator_g,
                           generator_f=generator_f,
                           discriminator_x=discriminator_x,
                           discriminator_y=discriminator_y,
                           generator_g_optimizer=generator_g_optimizer,
                           generator_f_optimizer=generator_f_optimizer,
                           discriminator_x_optimizer=discriminator_x_optimizer,
                           discriminator_y_optimizer=discriminator_y_optimizer)

ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# if a checkpoint exists, restore the latest checkpoint.
if ckpt_manager.latest_checkpoint:
  ckpt.restore(ckpt_manager.latest_checkpoint)
  print ('Latest checkpoint restored!!')

Đào tạo

EPOCHS = 40
def generate_images(model, test_input):
  prediction = model(test_input)

  plt.figure(figsize=(12, 12))

  display_list = [test_input[0], prediction[0]]
  title = ['Input Image', 'Predicted Image']

  for i in range(2):
    plt.subplot(1, 2, i+1)
    plt.title(title[i])
    # getting the pixel values between [0, 1] to plot it.
    plt.imshow(display_list[i] * 0.5 + 0.5)
    plt.axis('off')
  plt.show()

Mặc dù vòng lặp đào tạo trông phức tạp, nó bao gồm bốn bước cơ bản:

  • Nhận các dự đoán.
  • Tính toán sự mất mát.
  • Tính toán các gradient bằng cách sử dụng backpropagation.
  • Áp dụng các gradient cho trình tối ưu hóa.
@tf.function
def train_step(real_x, real_y):
  # persistent is set to True because the tape is used more than
  # once to calculate the gradients.
  with tf.GradientTape(persistent=True) as tape:
    # Generator G translates X -> Y
    # Generator F translates Y -> X.

    fake_y = generator_g(real_x, training=True)
    cycled_x = generator_f(fake_y, training=True)

    fake_x = generator_f(real_y, training=True)
    cycled_y = generator_g(fake_x, training=True)

    # same_x and same_y are used for identity loss.
    same_x = generator_f(real_x, training=True)
    same_y = generator_g(real_y, training=True)

    disc_real_x = discriminator_x(real_x, training=True)
    disc_real_y = discriminator_y(real_y, training=True)

    disc_fake_x = discriminator_x(fake_x, training=True)
    disc_fake_y = discriminator_y(fake_y, training=True)

    # calculate the loss
    gen_g_loss = generator_loss(disc_fake_y)
    gen_f_loss = generator_loss(disc_fake_x)

    total_cycle_loss = calc_cycle_loss(real_x, cycled_x) + calc_cycle_loss(real_y, cycled_y)

    # Total generator loss = adversarial loss + cycle loss
    total_gen_g_loss = gen_g_loss + total_cycle_loss + identity_loss(real_y, same_y)
    total_gen_f_loss = gen_f_loss + total_cycle_loss + identity_loss(real_x, same_x)

    disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)
    disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)

  # Calculate the gradients for generator and discriminator
  generator_g_gradients = tape.gradient(total_gen_g_loss, 
                                        generator_g.trainable_variables)
  generator_f_gradients = tape.gradient(total_gen_f_loss, 
                                        generator_f.trainable_variables)

  discriminator_x_gradients = tape.gradient(disc_x_loss, 
                                            discriminator_x.trainable_variables)
  discriminator_y_gradients = tape.gradient(disc_y_loss, 
                                            discriminator_y.trainable_variables)

  # Apply the gradients to the optimizer
  generator_g_optimizer.apply_gradients(zip(generator_g_gradients, 
                                            generator_g.trainable_variables))

  generator_f_optimizer.apply_gradients(zip(generator_f_gradients, 
                                            generator_f.trainable_variables))

  discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients,
                                                discriminator_x.trainable_variables))

  discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients,
                                                discriminator_y.trainable_variables))
for epoch in range(EPOCHS):
  start = time.time()

  n = 0
  for image_x, image_y in tf.data.Dataset.zip((train_horses, train_zebras)):
    train_step(image_x, image_y)
    if n % 10 == 0:
      print ('.', end='')
    n += 1

  clear_output(wait=True)
  # Using a consistent image (sample_horse) so that the progress of the model
  # is clearly visible.
  generate_images(generator_g, sample_horse)

  if (epoch + 1) % 5 == 0:
    ckpt_save_path = ckpt_manager.save()
    print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,
                                                         ckpt_save_path))

  print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
                                                      time.time()-start))

png

Saving checkpoint for epoch 40 at ./checkpoints/train/ckpt-8
Time taken for epoch 40 is 166.58266592025757 sec

Tạo bằng cách sử dụng tập dữ liệu thử nghiệm

# Run the trained model on the test dataset
for inp in test_horses.take(5):
  generate_images(generator_g, inp)

png

png

png

png

png

Bước tiếp theo

Hướng dẫn này đã cho thấy làm thế nào để thực hiện CycleGAN bắt đầu từ các máy phát điện và phân biệt thực hiện trong Pix2Pix hướng dẫn. Bước tiếp theo, bạn có thể thử sử dụng một tập dữ liệu khác nhau từ TensorFlow Datasets .

Bạn cũng có thể đào tạo cho một số lượng lớn của thời đại để cải thiện kết quả, hoặc bạn có thể thực hiện các máy phát điện ResNet sửa đổi được sử dụng trong các giấy thay vì các máy phát điện U-Net sử dụng ở đây.