Halaman ini diterjemahkan oleh Cloud Translation API.
Switch to English

Jaringan Adversarial Generatif Konvolusional yang Dalam

Lihat di TensorFlow.org Jalankan di Google Colab Lihat sumber di GitHub Unduh buku catatan

Tutorial ini mendemonstrasikan cara menghasilkan gambar digit tulisan tangan menggunakan Deep Convolutional Generative Adversarial Network (DCGAN). Kode ini ditulis menggunakan Keras Sequential API dengan loop pelatihan tf.GradientTape .

Apa itu GAN?

Generative Adversarial Networks (GANs) adalah salah satu ide paling menarik dalam ilmu komputer saat ini. Dua model dilatih secara bersamaan melalui proses permusuhan. Generator ("artis") belajar membuat gambar yang terlihat nyata, sedangkan diskriminator ("kritikus seni") belajar membedakan gambar nyata dari yang palsu.

Diagram generator dan diskriminator

Selama pelatihan, generator secara bertahap menjadi lebih baik dalam membuat gambar yang terlihat nyata, sementara diskriminator menjadi lebih baik dalam membedakannya. Proses tersebut mencapai keseimbangan ketika pembeda tidak dapat lagi membedakan gambar nyata dari gambar palsu.

Diagram kedua generator dan diskriminator

Notebook ini mendemonstrasikan proses ini pada set data MNIST. Animasi berikut menampilkan serangkaian gambar yang dihasilkan oleh generator saat dilatih selama 50 epoch. Gambar dimulai sebagai noise acak, dan semakin mirip dengan angka tulisan tangan dari waktu ke waktu.

keluaran sampel

Untuk mempelajari lebih lanjut tentang GAN, kami merekomendasikan kursus Intro to Deep Learning MIT.

Mempersiapkan

import tensorflow as tf
tf.__version__
'2.3.0'
# To generate GIFs
!pip install -q imageio
!pip install -q git+https://github.com/tensorflow/docs
WARNING: You are using pip version 20.2.2; however, version 20.2.3 is available.
You should consider upgrading via the '/tmpfs/src/tf_docs_env/bin/python -m pip install --upgrade pip' command.
WARNING: You are using pip version 20.2.2; however, version 20.2.3 is available.
You should consider upgrading via the '/tmpfs/src/tf_docs_env/bin/python -m pip install --upgrade pip' command.

import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
from tensorflow.keras import layers
import time

from IPython import display

Muat dan siapkan dataset

Anda akan menggunakan dataset MNIST untuk melatih generator dan diskriminator. Generator akan menghasilkan angka tulisan tangan yang menyerupai data MNIST.

(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5 # Normalize the images to [-1, 1]
BUFFER_SIZE = 60000
BATCH_SIZE = 256
# Batch and shuffle the data
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

Buat modelnya

Baik generator dan diskriminator ditentukan menggunakan Keras Sequential API .

Generator

Generator menggunakan tf.keras.layers.Conv2DTranspose (upsampling) untuk menghasilkan gambar dari benih (gangguan acak). Mulailah dengan layer Dense yang mengambil seed ini sebagai input, kemudian lakukan upample beberapa kali hingga Anda mencapai ukuran gambar 28x28x1 yang diinginkan. Perhatikan aktivasi tf.keras.layers.LeakyReLU untuk setiap lapisan, kecuali lapisan keluaran yang menggunakan tanh.

def make_generator_model():
    model = tf.keras.Sequential()
    model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Reshape((7, 7, 256)))
    assert model.output_shape == (None, 7, 7, 256) # Note: None is the batch size

    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    assert model.output_shape == (None, 7, 7, 128)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 14, 14, 64)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    assert model.output_shape == (None, 28, 28, 1)

    return model

Gunakan generator (yang belum terlatih) untuk membuat gambar.

generator = make_generator_model()

noise = tf.random.normal([1, 100])
generated_image = generator(noise, training=False)

plt.imshow(generated_image[0, :, :, 0], cmap='gray')
<matplotlib.image.AxesImage at 0x7f2729b9f6d8>

png

Diskriminator

Diskriminator adalah pengklasifikasi gambar berbasis CNN.

def make_discriminator_model():
    model = tf.keras.Sequential()
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                                     input_shape=[28, 28, 1]))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Flatten())
    model.add(layers.Dense(1))

    return model

Gunakan diskriminator (yang belum terlatih) untuk mengklasifikasikan gambar yang dihasilkan sebagai nyata atau palsu. Model akan dilatih untuk menghasilkan nilai positif untuk gambar nyata, dan nilai negatif untuk gambar palsu.

discriminator = make_discriminator_model()
decision = discriminator(generated_image)
print (decision)
tf.Tensor([[0.0003284]], shape=(1, 1), dtype=float32)

Tentukan kerugian dan pengoptimalan

Tentukan fungsi kerugian dan pengoptimal untuk kedua model.

# This method returns a helper function to compute cross entropy loss
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

Kerugian diskriminator

Metode ini mengukur seberapa baik diskriminator mampu membedakan gambar nyata dari yang palsu. Ini membandingkan prediksi diskriminator pada gambar nyata dengan array 1s, dan prediksi diskriminator pada gambar palsu (dihasilkan) dengan array 0s.

def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

Kehilangan generator

Kehilangan generator mengukur seberapa baik ia mampu mengelabui diskriminator. Secara intuitif, jika generator berfungsi dengan baik, diskriminator akan mengklasifikasikan gambar palsu sebagai nyata (atau 1). Di sini, kami akan membandingkan keputusan diskriminator pada gambar yang dihasilkan dengan larik 1s.

def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)

Diskriminator dan pengoptimal generator berbeda karena kami akan melatih dua jaringan secara terpisah.

generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

Simpan pos pemeriksaan

Notebook ini juga mendemonstrasikan cara menyimpan dan memulihkan model, yang dapat membantu jika tugas pelatihan yang berjalan lama terganggu.

checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

Tentukan loop pelatihan

EPOCHS = 50
noise_dim = 100
num_examples_to_generate = 16

# We will reuse this seed overtime (so it's easier)
# to visualize progress in the animated GIF)
seed = tf.random.normal([num_examples_to_generate, noise_dim])

Loop pelatihan dimulai dengan generator yang menerima seed acak sebagai input. Benih itu digunakan untuk menghasilkan gambar. Diskriminator kemudian digunakan untuk mengklasifikasikan gambar nyata (diambil dari set pelatihan) dan gambar palsu (diproduksi oleh generator). Kerugian dihitung untuk masing-masing model ini, dan gradien digunakan untuk memperbarui generator dan diskriminator.

# Notice the use of `tf.function`
# This annotation causes the function to be "compiled".
@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
      generated_images = generator(noise, training=True)

      real_output = discriminator(images, training=True)
      fake_output = discriminator(generated_images, training=True)

      gen_loss = generator_loss(fake_output)
      disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
def train(dataset, epochs):
  for epoch in range(epochs):
    start = time.time()

    for image_batch in dataset:
      train_step(image_batch)

    # Produce images for the GIF as we go
    display.clear_output(wait=True)
    generate_and_save_images(generator,
                             epoch + 1,
                             seed)

    # Save the model every 15 epochs
    if (epoch + 1) % 15 == 0:
      checkpoint.save(file_prefix = checkpoint_prefix)

    print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))

  # Generate after the final epoch
  display.clear_output(wait=True)
  generate_and_save_images(generator,
                           epochs,
                           seed)

Buat dan simpan gambar

def generate_and_save_images(model, epoch, test_input):
  # Notice `training` is set to False.
  # This is so all layers run in inference mode (batchnorm).
  predictions = model(test_input, training=False)

  fig = plt.figure(figsize=(4,4))

  for i in range(predictions.shape[0]):
      plt.subplot(4, 4, i+1)
      plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
      plt.axis('off')

  plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
  plt.show()

Latih modelnya

Panggil metode train() yang ditentukan di atas untuk melatih generator dan diskriminator secara bersamaan. Perhatikan, melatih GAN bisa jadi rumit. Penting agar generator dan diskriminator tidak saling mengalahkan (misalnya, mereka berlatih dengan kecepatan yang sama).

Pada awal pelatihan, citra yang dihasilkan terlihat seperti noise random. Saat pelatihan berlangsung, angka yang dihasilkan akan terlihat semakin nyata. Setelah sekitar 50 periode, mereka menyerupai angka MNIST. Ini mungkin membutuhkan waktu sekitar satu menit / epoch dengan pengaturan default di Colab.

train(train_dataset, EPOCHS)

png

Pulihkan pos pemeriksaan terbaru.

checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f2729bc3128>

Buat GIF

# Display a single image using the epoch number
def display_image(epoch_no):
  return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))
display_image(EPOCHS)

png

Gunakan imageio untuk membuat gif animasi menggunakan gambar yang disimpan selama latihan.

anim_file = 'dcgan.gif'

with imageio.get_writer(anim_file, mode='I') as writer:
  filenames = glob.glob('image*.png')
  filenames = sorted(filenames)
  for filename in filenames:
    image = imageio.imread(filename)
    writer.append_data(image)
  image = imageio.imread(filename)
  writer.append_data(image)
import tensorflow_docs.vis.embed as embed
embed.embed_file(anim_file)

gif

Langkah selanjutnya

Tutorial ini telah menunjukkan kode lengkap yang diperlukan untuk menulis dan melatih GAN. Sebagai langkah selanjutnya, Anda mungkin ingin bereksperimen dengan kumpulan data yang berbeda, misalnya kumpulan data Large-scale Celeb Faces Attributes (CelebA) yang tersedia di Kaggle . Untuk mempelajari lebih lanjut tentang GAN, kami merekomendasikan Tutorial NIPS 2016: Generative Adversarial Networks .