このページは Cloud Translation API によって翻訳されました。
Switch to English

深い畳み込み生成的敵対的ネットワーク

TensorFlow.orgで見る Google Colabで実行 GitHubでソースを表示する ノートブックをダウンロード

このチュートリアルでは、 Deep Convolutional Generative Adversarial Network (DCGAN)を使用して手書き数字の画像を生成する方法を示します。コードは、 tf.GradientTapeトレーニングループでtf.GradientTape Sequential APIを使用して記述されています。

GANとは何ですか?

Generative Adversarial Network (GAN)は、今日のコンピュータサイエンスで最も興味深いアイデアの1つです。 2つのモデルが敵対的なプロセスによって同時にトレーニングされます。 ジェネレーター (「アーティスト」)は本物に見える画像を作成することを学び、 弁別子 (「芸術評論家」)は本物の画像を偽物と区別することを学びます。

ジェネレーターとディスクリミネーターの図

トレーニング中、 ジェネレーターは徐々にリアルに見えるイメージを作成するのが得意になり、ディスクリミネーターはイメージを区別するのが得意になります。 弁別担当者が実際の画像と偽物を区別できなくなると、プロセスは平衡状態になります。

ジェネレーターとディスクリミネーターの2番目の図

このノートブックは、MNISTデータセットでこのプロセスを示しています。次のアニメーションは、50エポックでトレーニングされたジェネレーターによって生成された一連の画像を示しています。画像はランダムノイズとして始まり、時間の経過とともに手書きの数字にますます似ています。

出力例

GANの詳細については、MITの深層学習入門コースをお勧めします。

セットアップ

import tensorflow as tf
tf.__version__
'2.3.0'
# To generate GIFs
!pip install -q imageio
!pip install -q git+https://github.com/tensorflow/docs
WARNING: You are using pip version 20.2.2; however, version 20.2.3 is available.
You should consider upgrading via the '/tmpfs/src/tf_docs_env/bin/python -m pip install --upgrade pip' command.
WARNING: You are using pip version 20.2.2; however, version 20.2.3 is available.
You should consider upgrading via the '/tmpfs/src/tf_docs_env/bin/python -m pip install --upgrade pip' command.

import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
from tensorflow.keras import layers
import time

from IPython import display

データセットを読み込んで準備する

MNISTデータセットを使用して、ジェネレーターと弁別子をトレーニングします。ジェネレーターは、MNISTデータに似た手書きの数字を生成します。

(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5 # Normalize the images to [-1, 1]
BUFFER_SIZE = 60000
BATCH_SIZE = 256
# Batch and shuffle the data
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

モデルを作成する

ジェネレーターとディスクリミネーターの両方は、 Keras Sequential APIを使用して定義されます

ジェネレーター

ジェネレーターは、 tf.keras.layers.Conv2DTranspose (アップサンプリング)レイヤーを使用して、シード(ランダムノイズ)から画像を生成します。このシードを入力として使用するDenseレイヤーから開始し、希望の画像サイズ28x28x1に達するまで数回アップサンプリングします。 tanhを使用する出力層を除いて、各層のtf.keras.layers.LeakyReLUアクティブ化に注意してtf.keras.layers.LeakyReLU

def make_generator_model():
    model = tf.keras.Sequential()
    model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Reshape((7, 7, 256)))
    assert model.output_shape == (None, 7, 7, 256) # Note: None is the batch size

    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    assert model.output_shape == (None, 7, 7, 128)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 14, 14, 64)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    assert model.output_shape == (None, 28, 28, 1)

    return model

(まだトレーニングされていない)ジェネレーターを使用して画像を作成します。

generator = make_generator_model()

noise = tf.random.normal([1, 100])
generated_image = generator(noise, training=False)

plt.imshow(generated_image[0, :, :, 0], cmap='gray')
<matplotlib.image.AxesImage at 0x7f2729b9f6d8>

png

弁別者

弁別器はCNNベースの画像分類器です。

def make_discriminator_model():
    model = tf.keras.Sequential()
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                                     input_shape=[28, 28, 1]))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Flatten())
    model.add(layers.Dense(1))

    return model

(まだ訓練されていない)弁別器を使用して、生成された画像を本物または偽物として分類します。モデルは、実際の画像では正の値を、偽の画像では負の値を出力するようにトレーニングされます。

discriminator = make_discriminator_model()
decision = discriminator(generated_image)
print (decision)
tf.Tensor([[0.0003284]], shape=(1, 1), dtype=float32)

損失とオプティマイザを定義する

両方のモデルの損失関数とオプティマイザを定義します。

# This method returns a helper function to compute cross entropy loss
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

弁別器の損失

この方法は、識別器が実際の画像と偽物をどの程度うまく区別できるかを定量化します。実際の画像に対する弁別子の予測を1の配列と比較し、偽の(生成された)画像に対する弁別器の予測を0の配列と比較します。

def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

発電機の損失

ジェネレーターの損失は、ディスクリミネーターをだますことができた程度を数値化します。直感的に、ジェネレーターが適切に機能している場合、ディスクリミネーターは偽のイメージを実(または1)として分類します。ここでは、生成された画像に対する判別子の決定を1の配列と比較します。

def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)

2つのネットワークを個別にトレーニングするため、弁別子と生成器のオプティマイザーは異なります。

generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

チェックポイントを保存

このノートブックは、モデルの保存と復元の方法も示しています。これは、長時間実行されるトレーニングタスクが中断された場合に役立ちます。

checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

トレーニングループを定義する

EPOCHS = 50
noise_dim = 100
num_examples_to_generate = 16

# We will reuse this seed overtime (so it's easier)
# to visualize progress in the animated GIF)
seed = tf.random.normal([num_examples_to_generate, noise_dim])

トレーニングループは、ジェネレータがランダムシードを入力として受け取ることから始まります。そのシードは画像を生成するために使用されます。次に、識別器を使用して、(トレーニングセットから描画された)実際の画像と(ジェネレーターによって生成された)偽の画像を分類します。これらのモデルごとに損失が計算され、勾配を使用してジェネレーターとディスクリミネーターが更新されます。

# Notice the use of `tf.function`
# This annotation causes the function to be "compiled".
@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
      generated_images = generator(noise, training=True)

      real_output = discriminator(images, training=True)
      fake_output = discriminator(generated_images, training=True)

      gen_loss = generator_loss(fake_output)
      disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
def train(dataset, epochs):
  for epoch in range(epochs):
    start = time.time()

    for image_batch in dataset:
      train_step(image_batch)

    # Produce images for the GIF as we go
    display.clear_output(wait=True)
    generate_and_save_images(generator,
                             epoch + 1,
                             seed)

    # Save the model every 15 epochs
    if (epoch + 1) % 15 == 0:
      checkpoint.save(file_prefix = checkpoint_prefix)

    print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))

  # Generate after the final epoch
  display.clear_output(wait=True)
  generate_and_save_images(generator,
                           epochs,
                           seed)

画像を生成して保存する

def generate_and_save_images(model, epoch, test_input):
  # Notice `training` is set to False.
  # This is so all layers run in inference mode (batchnorm).
  predictions = model(test_input, training=False)

  fig = plt.figure(figsize=(4,4))

  for i in range(predictions.shape[0]):
      plt.subplot(4, 4, i+1)
      plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
      plt.axis('off')

  plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
  plt.show()

モデルをトレーニングする

上記で定義したtrain()メソッドを呼び出して、ジェネレーターとディスクリミネーターを同時にトレーニングします。 GANのトレーニングは注意が必要です。ジェネレーターとディスクリミネーターがお互いを圧倒しないことが重要です(たとえば、同じような速度でトレーニングすること)。

トレーニングの開始時に、生成された画像はランダムノイズのように見えます。トレーニングが進むにつれて、生成される数字はますますリアルに見えます。約50エポック後、それらはMNISTディジットに似ています。 Colabのデフォルト設定では、これに約1分/エポックがかかる場合があります。

train(train_dataset, EPOCHS)

png

最新のチェックポイントを復元します。

checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f2729bc3128>

GIFを作成する

# Display a single image using the epoch number
def display_image(epoch_no):
  return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))
display_image(EPOCHS)

png

imageioを使用して、トレーニング中に保存された画像を使用してアニメーションGIFを作成します。

anim_file = 'dcgan.gif'

with imageio.get_writer(anim_file, mode='I') as writer:
  filenames = glob.glob('image*.png')
  filenames = sorted(filenames)
  for filename in filenames:
    image = imageio.imread(filename)
    writer.append_data(image)
  image = imageio.imread(filename)
  writer.append_data(image)
import tensorflow_docs.vis.embed as embed
embed.embed_file(anim_file)

gif

次のステップ

このチュートリアルでは、GANの記述とトレーニングに必要な完全なコードを示しました。次のステップとして、Kaggleで利用できる大規模なCeleb Faces属性(CelebA)データセットなど、別のデータセットを試してみることができます 。 GANの詳細については、 NIPS 2016チュートリアル:Generative Adversarial Networksをお勧めします。