Deep Convolutional Generative Adversarial Network

Auf TensorFlow.org ansehen In Google Colab ausführen Quelle auf GitHub anzeigen Notizbuch herunterladen

Dieses Tutorial zeigt, wie Sie Bilder von handgeschriebenen Ziffern mit einem Deep Convolutional Generative Adversarial Network (DCGAN) generieren. Der Code wird mit der Keras Sequential API mit einertf.GradientTape Trainingsschleife geschrieben.

Was sind GANs?

Generative Adversarial Networks (GANs) sind heute eine der interessantesten Ideen in der Informatik. Zwei Modelle werden gleichzeitig durch einen kontradiktorischen Prozess trainiert. Ein Generator ("der Künstler") lernt, Bilder zu erzeugen, die echt aussehen, während ein Diskriminator ("der Kunstkritiker") lernt, echte Bilder von Fälschungen zu unterscheiden.

Ein Diagramm eines Generators und Diskriminators

Während des Trainings wird der Generator immer besser darin, Bilder zu erzeugen, die echt aussehen, während der Diskriminator sie besser unterscheiden kann. Der Prozess erreicht ein Gleichgewicht, wenn der Diskriminator echte Bilder nicht mehr von Fälschungen unterscheiden kann.

Ein zweites Diagramm eines Generators und Diskriminators

Dieses Notebook demonstriert diesen Prozess für das MNIST-Dataset. Die folgende Animation zeigt eine Reihe von Bildern, die vom Generator erzeugt wurden, während er für 50 Epochen trainiert wurde. Die Bilder beginnen als zufälliges Rauschen und ähneln im Laufe der Zeit zunehmend handgeschriebenen Ziffern.

Beispielausgabe

Weitere Informationen zu GANs finden Sie im MIT-Kurs Intro to Deep Learning .

Einrichten

import tensorflow as tf
tf.__version__
'2.5.0'
# To generate GIFs
pip install imageio
pip install git+https://github.com/tensorflow/docs
import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
from tensorflow.keras import layers
import time

from IPython import display

Laden Sie den Datensatz und bereiten Sie ihn vor

Sie verwenden den MNIST-Datensatz, um den Generator und den Diskriminator zu trainieren. Der Generator generiert handgeschriebene Ziffern, die den MNIST-Daten ähneln.

(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5  # Normalize the images to [-1, 1]
BUFFER_SIZE = 60000
BATCH_SIZE = 256
# Batch and shuffle the data
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

Erstellen Sie die Modelle

Sowohl der Generator als auch der Diskriminator werden mit der Keras Sequential API definiert .

Der Generator

Der Generator verwendet tf.keras.layers.Conv2DTranspose (Upsampling) tf.keras.layers.Conv2DTranspose , um ein Bild aus einem Seed (zufälliges Rauschen) zu erzeugen. Beginnen Sie mit einer Dense Ebene, die diesen Seed als Eingabe verwendet, und führen Sie dann mehrmals ein Upsampling durch, bis Sie die gewünschte Bildgröße von 28 x 28 x 1 erreichen. Beachten Sie die tf.keras.layers.LeakyReLU Aktivierung für jede Schicht, mit Ausnahme der Ausgabeschicht, die tanh verwendet.

def make_generator_model():
    model = tf.keras.Sequential()
    model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Reshape((7, 7, 256)))
    assert model.output_shape == (None, 7, 7, 256)  # Note: None is the batch size

    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    assert model.output_shape == (None, 7, 7, 128)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 14, 14, 64)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    assert model.output_shape == (None, 28, 28, 1)

    return model

Verwenden Sie den (noch ungeübten) Generator, um ein Bild zu erstellen.

generator = make_generator_model()

noise = tf.random.normal([1, 100])
generated_image = generator(noise, training=False)

plt.imshow(generated_image[0, :, :, 0], cmap='gray')
<matplotlib.image.AxesImage at 0x7f7322b54fd0>

png

Der Diskriminierende

Der Diskriminator ist ein CNN-basierter Bildklassifikator.

def make_discriminator_model():
    model = tf.keras.Sequential()
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                                     input_shape=[28, 28, 1]))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Flatten())
    model.add(layers.Dense(1))

    return model

Verwenden Sie den (noch ungeübten) Diskriminator, um die erzeugten Bilder als echt oder gefälscht zu klassifizieren. Das Modell wird trainiert, um positive Werte für echte Bilder und negative Werte für gefälschte Bilder auszugeben.

discriminator = make_discriminator_model()
decision = discriminator(generated_image)
print (decision)
tf.Tensor([[-0.00085865]], shape=(1, 1), dtype=float32)

Definieren Sie den Verlust und die Optimierer

Definieren Sie Verlustfunktionen und Optimierer für beide Modelle.

# This method returns a helper function to compute cross entropy loss
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

Diskriminatorverlust

Dieses Verfahren quantifiziert, wie gut der Diskriminator echte Bilder von Fälschungen unterscheiden kann. Es vergleicht die Vorhersagen des Diskriminators zu echten Bildern mit einem Array von Einsen und die Vorhersagen des Diskriminators für gefälschte (erzeugte) Bilder mit einem Array von Nullen.

def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

Generatorverlust

Der Verlust des Generators quantifiziert, wie gut er den Diskriminator austricksen konnte. Wenn der Generator gut funktioniert, wird der Diskriminator die gefälschten Bilder intuitiv als echt (oder 1) klassifizieren. Vergleichen Sie hier die Entscheidungen des Diskriminators über die erzeugten Bilder mit einem Array von Einsen.

def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)

Der Diskriminator und die Generatoroptimierer sind unterschiedlich, da Sie zwei Netze getrennt trainieren.

generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

Prüfpunkte speichern

Dieses Notebook demonstriert auch das Speichern und Wiederherstellen von Modellen, was hilfreich sein kann, falls eine lange laufende Trainingsaufgabe unterbrochen wird.

checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

Definiere die Trainingsschleife

EPOCHS = 50
noise_dim = 100
num_examples_to_generate = 16

# You will reuse this seed overtime (so it's easier)
# to visualize progress in the animated GIF)
seed = tf.random.normal([num_examples_to_generate, noise_dim])

Die Trainingsschleife beginnt damit, dass der Generator einen zufälligen Startwert als Eingabe empfängt. Dieser Seed wird verwendet, um ein Bild zu erzeugen. Der Diskriminator wird dann verwendet, um reale Bilder (aus dem Trainingssatz gezogen) und gefälschte Bilder (vom Generator erzeugt) zu klassifizieren. Der Verlust wird für jedes dieser Modelle berechnet und die Gradienten werden verwendet, um den Generator und den Diskriminator zu aktualisieren.

# Notice the use of `tf.function`
# This annotation causes the function to be "compiled".
@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
      generated_images = generator(noise, training=True)

      real_output = discriminator(images, training=True)
      fake_output = discriminator(generated_images, training=True)

      gen_loss = generator_loss(fake_output)
      disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
def train(dataset, epochs):
  for epoch in range(epochs):
    start = time.time()

    for image_batch in dataset:
      train_step(image_batch)

    # Produce images for the GIF as you go
    display.clear_output(wait=True)
    generate_and_save_images(generator,
                             epoch + 1,
                             seed)

    # Save the model every 15 epochs
    if (epoch + 1) % 15 == 0:
      checkpoint.save(file_prefix = checkpoint_prefix)

    print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))

  # Generate after the final epoch
  display.clear_output(wait=True)
  generate_and_save_images(generator,
                           epochs,
                           seed)

Bilder erstellen und speichern

def generate_and_save_images(model, epoch, test_input):
  # Notice `training` is set to False.
  # This is so all layers run in inference mode (batchnorm).
  predictions = model(test_input, training=False)

  fig = plt.figure(figsize=(4, 4))

  for i in range(predictions.shape[0]):
      plt.subplot(4, 4, i+1)
      plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
      plt.axis('off')

  plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
  plt.show()

Trainiere das Modell

Rufen Sie die oben definierte train() Methode auf, um den Generator und den Diskriminator gleichzeitig zu trainieren. Beachten Sie, dass das Training von GANs schwierig sein kann. Es ist wichtig, dass sich Generator und Diskriminator nicht gegenseitig überwältigen (z. B. mit ähnlicher Geschwindigkeit trainieren).

Zu Beginn des Trainings sehen die erzeugten Bilder wie zufälliges Rauschen aus. Mit fortschreitendem Training werden die generierten Ziffern immer echter aussehen. Nach etwa 50 Epochen ähneln sie MNIST-Ziffern. Dies kann mit den Standardeinstellungen von Colab etwa eine Minute / Epoche dauern.

train(train_dataset, EPOCHS)

png

Stellen Sie den neuesten Prüfpunkt wieder her.

checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f72983d2bd0>

Erstellen Sie ein GIF

# Display a single image using the epoch number
def display_image(epoch_no):
  return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))
display_image(EPOCHS)

png

Verwenden Sie imageio , um ein animiertes Gif mit den während des Trainings gespeicherten Bildern zu erstellen.

anim_file = 'dcgan.gif'

with imageio.get_writer(anim_file, mode='I') as writer:
  filenames = glob.glob('image*.png')
  filenames = sorted(filenames)
  for filename in filenames:
    image = imageio.imread(filename)
    writer.append_data(image)
  image = imageio.imread(filename)
  writer.append_data(image)
import tensorflow_docs.vis.embed as embed
embed.embed_file(anim_file)

gif

Nächste Schritte

Dieses Tutorial hat den vollständigen Code gezeigt, der zum Schreiben und Trainieren eines GAN erforderlich ist. Als nächsten Schritt möchten Sie vielleicht mit einem anderen Datensatz experimentieren, zum Beispiel mit dem auf Kaggle verfügbaren Large-scale Celeb Faces Attributes (CelebA) -Datensatz . Weitere Informationen zu GANs finden Sie im NIPS 2016-Tutorial: Generative Adversarial Networks .