![]() | ![]() | ![]() | ![]() |
このノートブックは、CycleGANとしても知られるCycle-Consistent Adversarial Networksを使用した対になっていない画像から画像への変換で説明されているように、条件付きGANを使用した対になっていない画像から画像への変換を示します。この論文では、ペアのトレーニング例がない場合に、1つの画像ドメインの特性をキャプチャし、これらの特性を別の画像ドメインに変換する方法を理解できる方法を提案します。
このノートブックは、 Pix2Pixチュートリアルで学習できるPix2Pixに精通していることを前提としています。 CycleGANのコードは類似しており、主な違いは、追加の損失関数と、対になっていないトレーニングデータの使用です。
CycleGANは、サイクルの一貫性の喪失を使用して、ペアのデータを必要とせずにトレーニングを可能にします。つまり、ソースドメインとターゲットドメインを1対1でマッピングしなくても、あるドメインから別のドメインに変換できます。
これにより、写真のエンハンスメント、画像のカラー化、スタイルの転送など、多くの興味深いタスクを実行できるようになります。必要なのは、ソースとターゲットのデータセット(単なる画像のディレクトリ)だけです。
入力パイプラインを設定します
ジェネレーターとディスクリミネーターのインポートを可能にするtensorflow_examplesパッケージをインストールします。
pip install -q git+https://github.com/tensorflow/examples.git
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow_examples.models.pix2pix import pix2pix
import os
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output
AUTOTUNE = tf.data.AUTOTUNE
入力パイプライン
このチュートリアルでは、馬の画像からシマウマの画像に変換するモデルをトレーニングします。このデータセットと同様のデータセットはここにあります。
論文で述べたように、トレーニングデータセットにランダムジッターとミラーリングを適用します。これらは、過剰適合を回避する画像拡張技術の一部です。
- ランダムジッターでは、画像のサイズが
286 x 286
変更されてから、ランダムに256 x 256
x256にトリミングされます。 - ランダムミラーリングでは、画像はランダムに水平方向、つまり左から右に反転します。
dataset, metadata = tfds.load('cycle_gan/horse2zebra',
with_info=True, as_supervised=True)
train_horses, train_zebras = dataset['trainA'], dataset['trainB']
test_horses, test_zebras = dataset['testA'], dataset['testB']
Downloading and preparing dataset 111.45 MiB (download: 111.45 MiB, generated: Unknown size, total: 111.45 MiB) to /home/kbuilder/tensorflow_datasets/cycle_gan/horse2zebra/2.0.0... Dataset cycle_gan downloaded and prepared to /home/kbuilder/tensorflow_datasets/cycle_gan/horse2zebra/2.0.0. Subsequent calls will reuse this data.
BUFFER_SIZE = 1000
BATCH_SIZE = 1
IMG_WIDTH = 256
IMG_HEIGHT = 256
def random_crop(image):
cropped_image = tf.image.random_crop(
image, size=[IMG_HEIGHT, IMG_WIDTH, 3])
return cropped_image
# normalizing the images to [-1, 1]
def normalize(image):
image = tf.cast(image, tf.float32)
image = (image / 127.5) - 1
return image
def random_jitter(image):
# resizing to 286 x 286 x 3
image = tf.image.resize(image, [286, 286],
method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
# randomly cropping to 256 x 256 x 3
image = random_crop(image)
# random mirroring
image = tf.image.random_flip_left_right(image)
return image
def preprocess_image_train(image, label):
image = random_jitter(image)
image = normalize(image)
return image
def preprocess_image_test(image, label):
image = normalize(image)
return image
train_horses = train_horses.map(
preprocess_image_train, num_parallel_calls=AUTOTUNE).cache().shuffle(
BUFFER_SIZE).batch(1)
train_zebras = train_zebras.map(
preprocess_image_train, num_parallel_calls=AUTOTUNE).cache().shuffle(
BUFFER_SIZE).batch(1)
test_horses = test_horses.map(
preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
BUFFER_SIZE).batch(1)
test_zebras = test_zebras.map(
preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
BUFFER_SIZE).batch(1)
sample_horse = next(iter(train_horses))
sample_zebra = next(iter(train_zebras))
plt.subplot(121)
plt.title('Horse')
plt.imshow(sample_horse[0] * 0.5 + 0.5)
plt.subplot(122)
plt.title('Horse with random jitter')
plt.imshow(random_jitter(sample_horse[0]) * 0.5 + 0.5)
<matplotlib.image.AxesImage at 0x7fb9682d80f0>
plt.subplot(121)
plt.title('Zebra')
plt.imshow(sample_zebra[0] * 0.5 + 0.5)
plt.subplot(122)
plt.title('Zebra with random jitter')
plt.imshow(random_jitter(sample_zebra[0]) * 0.5 + 0.5)
<matplotlib.image.AxesImage at 0x7fb968237208>
Pix2Pixモデルをインポートして再利用する
インストールされているtensorflow_examplesパッケージを介して、 Pix2Pixで使用されるジェネレーターとディスクリミネーターをインポートします。
このチュートリアルで使用されるモデルアーキテクチャは、 pix2pixで使用されたものと非常によく似ています。いくつかの違いは次のとおりです。
- Cycleganは、バッチ正規化の代わりにインスタンス正規化を使用します。
- CycleGANペーパーは、変更された
resnet
ベースのジェネレーターを使用します。このチュートリアルでは、簡単にするために変更されたunet
ジェネレーターを使用しています。
ここでは、2つのジェネレーター(GとF)と2つのディスクリミネーター(XとY)がトレーニングされています。
- ジェネレータ
G
は、画像X
を画像Y
に変換することを学習します。 $(G:X-> Y)$ - ジェネレータ
F
は、画像Y
を画像X
に変換することを学習します。 $(F:Y-> X)$ - 弁別
D_X
画像を区別することを学習X
と生成された画像X
(F(Y)
- 弁別
D_Y
は、画像Y
と生成された画像Y
(G(X)
)をD_Y
ことを学習します。
OUTPUT_CHANNELS = 3
generator_g = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
generator_f = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
discriminator_x = pix2pix.discriminator(norm_type='instancenorm', target=False)
discriminator_y = pix2pix.discriminator(norm_type='instancenorm', target=False)
to_zebra = generator_g(sample_horse)
to_horse = generator_f(sample_zebra)
plt.figure(figsize=(8, 8))
contrast = 8
imgs = [sample_horse, to_zebra, sample_zebra, to_horse]
title = ['Horse', 'To Zebra', 'Zebra', 'To Horse']
for i in range(len(imgs)):
plt.subplot(2, 2, i+1)
plt.title(title[i])
if i % 2 == 0:
plt.imshow(imgs[i][0] * 0.5 + 0.5)
else:
plt.imshow(imgs[i][0] * 0.5 * contrast + 0.5)
plt.show()
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
plt.figure(figsize=(8, 8))
plt.subplot(121)
plt.title('Is a real zebra?')
plt.imshow(discriminator_y(sample_zebra)[0, ..., -1], cmap='RdBu_r')
plt.subplot(122)
plt.title('Is a real horse?')
plt.imshow(discriminator_x(sample_horse)[0, ..., -1], cmap='RdBu_r')
plt.show()
損失関数
CycleGANには、トレーニングするペアのデータがないため、トレーニング中に入力x
とターゲットy
ペアが意味を持つという保証はありません。したがって、ネットワークが正しいマッピングを学習することを強制するために、著者はサイクルの一貫性の喪失を提案します。
弁別器の損失とジェネレータの損失は、 pix2pixで使用されているものと同様です。
LAMBDA = 10
loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)
def discriminator_loss(real, generated):
real_loss = loss_obj(tf.ones_like(real), real)
generated_loss = loss_obj(tf.zeros_like(generated), generated)
total_disc_loss = real_loss + generated_loss
return total_disc_loss * 0.5
def generator_loss(generated):
return loss_obj(tf.ones_like(generated), generated)
サイクルの一貫性は、結果が元の入力に近いはずであることを意味します。たとえば、ある文を英語からフランス語に翻訳してから、それをフランス語から英語に戻す場合、結果の文は元の文と同じになります。
サイクルの一貫性の喪失では、
- 画像$ X $は、生成された画像$ \ hat {Y} $を生成するジェネレータ$ G $を介して渡されます。
- 生成された画像$ \ hat {Y} $は、循環画像$ \ hat {X} $を生成するジェネレータ$ F $を介して渡されます。
- 平均絶対誤差は、$ X $と$ \ hat {X} $の間で計算されます。
def calc_cycle_loss(real_image, cycled_image):
loss1 = tf.reduce_mean(tf.abs(real_image - cycled_image))
return LAMBDA * loss1
上に示したように、ジェネレータ$ G $は、画像$ X $を画像$ Y $に変換する役割を果たします。 IDの喪失によると、画像$ Y $をジェネレーター$ G $にフィードすると、実際の画像$ Y $または画像$ Y $に近いものが生成されます。
馬でゼブラから馬へのモデル、またはゼブラで馬からゼブラへのモデルを実行する場合、画像にはすでにターゲットクラスが含まれているため、画像をあまり変更しないでください。
def identity_loss(real_image, same_image):
loss = tf.reduce_mean(tf.abs(real_image - same_image))
return LAMBDA * 0.5 * loss
すべてのジェネレーターとディスクリミネーターのオプティマイザーを初期化します。
generator_g_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
generator_f_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_x_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_y_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
チェックポイント
checkpoint_path = "./checkpoints/train"
ckpt = tf.train.Checkpoint(generator_g=generator_g,
generator_f=generator_f,
discriminator_x=discriminator_x,
discriminator_y=discriminator_y,
generator_g_optimizer=generator_g_optimizer,
generator_f_optimizer=generator_f_optimizer,
discriminator_x_optimizer=discriminator_x_optimizer,
discriminator_y_optimizer=discriminator_y_optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)
# if a checkpoint exists, restore the latest checkpoint.
if ckpt_manager.latest_checkpoint:
ckpt.restore(ckpt_manager.latest_checkpoint)
print ('Latest checkpoint restored!!')
トレーニング
EPOCHS = 40
def generate_images(model, test_input):
prediction = model(test_input)
plt.figure(figsize=(12, 12))
display_list = [test_input[0], prediction[0]]
title = ['Input Image', 'Predicted Image']
for i in range(2):
plt.subplot(1, 2, i+1)
plt.title(title[i])
# getting the pixel values between [0, 1] to plot it.
plt.imshow(display_list[i] * 0.5 + 0.5)
plt.axis('off')
plt.show()
トレーニングループは複雑に見えますが、4つの基本的なステップで構成されています。
- 予測を取得します。
- 損失を計算します。
- バックプロパゲーションを使用して勾配を計算します。
- グラデーションをオプティマイザーに適用します。
@tf.function
def train_step(real_x, real_y):
# persistent is set to True because the tape is used more than
# once to calculate the gradients.
with tf.GradientTape(persistent=True) as tape:
# Generator G translates X -> Y
# Generator F translates Y -> X.
fake_y = generator_g(real_x, training=True)
cycled_x = generator_f(fake_y, training=True)
fake_x = generator_f(real_y, training=True)
cycled_y = generator_g(fake_x, training=True)
# same_x and same_y are used for identity loss.
same_x = generator_f(real_x, training=True)
same_y = generator_g(real_y, training=True)
disc_real_x = discriminator_x(real_x, training=True)
disc_real_y = discriminator_y(real_y, training=True)
disc_fake_x = discriminator_x(fake_x, training=True)
disc_fake_y = discriminator_y(fake_y, training=True)
# calculate the loss
gen_g_loss = generator_loss(disc_fake_y)
gen_f_loss = generator_loss(disc_fake_x)
total_cycle_loss = calc_cycle_loss(real_x, cycled_x) + calc_cycle_loss(real_y, cycled_y)
# Total generator loss = adversarial loss + cycle loss
total_gen_g_loss = gen_g_loss + total_cycle_loss + identity_loss(real_y, same_y)
total_gen_f_loss = gen_f_loss + total_cycle_loss + identity_loss(real_x, same_x)
disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)
disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)
# Calculate the gradients for generator and discriminator
generator_g_gradients = tape.gradient(total_gen_g_loss,
generator_g.trainable_variables)
generator_f_gradients = tape.gradient(total_gen_f_loss,
generator_f.trainable_variables)
discriminator_x_gradients = tape.gradient(disc_x_loss,
discriminator_x.trainable_variables)
discriminator_y_gradients = tape.gradient(disc_y_loss,
discriminator_y.trainable_variables)
# Apply the gradients to the optimizer
generator_g_optimizer.apply_gradients(zip(generator_g_gradients,
generator_g.trainable_variables))
generator_f_optimizer.apply_gradients(zip(generator_f_gradients,
generator_f.trainable_variables))
discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients,
discriminator_x.trainable_variables))
discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients,
discriminator_y.trainable_variables))
for epoch in range(EPOCHS):
start = time.time()
n = 0
for image_x, image_y in tf.data.Dataset.zip((train_horses, train_zebras)):
train_step(image_x, image_y)
if n % 10 == 0:
print ('.', end='')
n+=1
clear_output(wait=True)
# Using a consistent image (sample_horse) so that the progress of the model
# is clearly visible.
generate_images(generator_g, sample_horse)
if (epoch + 1) % 5 == 0:
ckpt_save_path = ckpt_manager.save()
print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,
ckpt_save_path))
print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
time.time()-start))
Saving checkpoint for epoch 40 at ./checkpoints/train/ckpt-8 Time taken for epoch 40 is 169.26107788085938 sec
テストデータセットを使用して生成
# Run the trained model on the test dataset
for inp in test_horses.take(5):
generate_images(generator_g, inp)
次のステップ
このチュートリアルでは、 Pix2Pixチュートリアルで実装されたジェネレーターとディスクリミネーターから始めてCycleGANを実装する方法を示しました。次のステップとして、 TensorFlowデータセットとは異なるデータセットを使用してみることができます。
結果を改善するために、より多くのエポックをトレーニングすることもできます。または、ここで使用するU-Netジェネレーターの代わりに、ペーパーで使用する変更されたResNetジェネレーターを実装することもできます。