Merken Sie den Termin vor! Google I / O kehrt vom 18. bis 20. Mai zurück Registrieren Sie sich jetzt
Diese Seite wurde von der Cloud Translation API übersetzt.
Switch to English

Deep Convolutional Generative Adversarial Network

Ansicht auf TensorFlow.org In Google Colab ausführen Quelle auf GitHub anzeigen Notizbuch herunterladen

Dieses Tutorial zeigt, wie Bilder von handgeschriebenen Ziffern mithilfe eines Deep Convolutional Generative Adversarial Network (DCGAN) generiert werden. Der Code wird mithilfe der Keras Sequential API mit einertf.GradientTape Trainingsschleife geschrieben.

Was sind GANs?

Generative Adversarial Networks (GANs) sind heute eine der interessantesten Ideen in der Informatik. Zwei Modelle werden gleichzeitig durch einen kontroversen Prozess trainiert. Ein Generator ("der Künstler") lernt, Bilder zu erstellen, die real aussehen, während ein Diskriminator ("der Kunstkritiker") lernt, reale Bilder von Fälschungen zu unterscheiden.

Ein Diagramm eines Generators und eines Diskriminators

Während des Trainings kann der Generator nach und nach Bilder erstellen, die real aussehen, während der Diskriminator sie besser voneinander unterscheiden kann. Der Prozess erreicht ein Gleichgewicht, wenn der Diskriminator reale Bilder nicht mehr von Fälschungen unterscheiden kann.

Ein zweites Diagramm eines Generators und eines Diskriminators

Dieses Notizbuch demonstriert diesen Prozess im MNIST-Dataset. Die folgende Animation zeigt eine Reihe von Bildern, die vom Generator erzeugt wurden, als er für 50 Epochen trainiert wurde. Die Bilder beginnen als zufälliges Rauschen und ähneln im Laufe der Zeit zunehmend handgeschriebenen Ziffern.

Beispielausgabe

Weitere Informationen zu GANs finden Sie im MIT Intro to Deep Learning- Kurs.

Einrichten

import tensorflow as tf
tf.__version__
'2.4.1'
# To generate GIFs
pip install -q imageio
pip install -q git+https://github.com/tensorflow/docs
import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
from tensorflow.keras import layers
import time

from IPython import display

Laden Sie den Datensatz und bereiten Sie ihn vor

Sie verwenden den MNIST-Datensatz, um den Generator und den Diskriminator zu trainieren. Der Generator generiert handschriftliche Ziffern, die den MNIST-Daten ähneln.

(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5  # Normalize the images to [-1, 1]
BUFFER_SIZE = 60000
BATCH_SIZE = 256
# Batch and shuffle the data
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

Erstellen Sie die Modelle

Sowohl der Generator als auch der Diskriminator werden mithilfe der Keras Sequential API definiert .

Der Generator

Der Generator verwendet tf.keras.layers.Conv2DTranspose (Upsampling), um ein Bild aus einem tf.keras.layers.Conv2DTranspose (zufälliges Rauschen) zu erzeugen. Beginnen Sie mit einer Dense Ebene, die diesen Startwert als Eingabe verwendet, und führen Sie dann mehrere Upsamples durch, bis Sie die gewünschte Bildgröße von 28 x 28 x 1 erreicht haben. Beachten Sie die Aktivierung von tf.keras.layers.LeakyReLU für jede Ebene, mit Ausnahme der Ausgabeebene, die tanh verwendet.

def make_generator_model():
    model = tf.keras.Sequential()
    model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Reshape((7, 7, 256)))
    assert model.output_shape == (None, 7, 7, 256)  # Note: None is the batch size

    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    assert model.output_shape == (None, 7, 7, 128)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 14, 14, 64)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    assert model.output_shape == (None, 28, 28, 1)

    return model

Verwenden Sie den (noch nicht geschulten) Generator, um ein Bild zu erstellen.

generator = make_generator_model()

noise = tf.random.normal([1, 100])
generated_image = generator(noise, training=False)

plt.imshow(generated_image[0, :, :, 0], cmap='gray')
<matplotlib.image.AxesImage at 0x7f3740747390>

png

Der Diskriminator

Der Diskriminator ist ein CNN-basierter Bildklassifizierer.

def make_discriminator_model():
    model = tf.keras.Sequential()
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                                     input_shape=[28, 28, 1]))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Flatten())
    model.add(layers.Dense(1))

    return model

Verwenden Sie den (noch nicht geschulten) Diskriminator, um die generierten Bilder als echt oder falsch zu klassifizieren. Das Modell wird darauf trainiert, positive Werte für reale Bilder und negative Werte für gefälschte Bilder auszugeben.

discriminator = make_discriminator_model()
decision = discriminator(generated_image)
print (decision)
tf.Tensor([[-0.00033125]], shape=(1, 1), dtype=float32)

Definieren Sie den Verlust und die Optimierer

Definieren Sie Verlustfunktionen und Optimierer für beide Modelle.

# This method returns a helper function to compute cross entropy loss
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

Diskriminatorverlust

Diese Methode quantifiziert, wie gut der Diskriminator in der Lage ist, reale Bilder von Fälschungen zu unterscheiden. Es vergleicht die Vorhersagen des Diskriminators für reale Bilder mit einem Array von 1s und die Vorhersagen des Diskriminators für gefälschte (erzeugte) Bilder mit einem Array von 0s.

def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

Generatorverlust

Der Verlust des Generators quantifiziert, wie gut er den Diskriminator austricksen konnte. Wenn der Generator eine gute Leistung erbringt, klassifiziert der Diskriminator die gefälschten Bilder intuitiv als real (oder 1). Vergleichen Sie hier die Diskriminatorentscheidungen für die erzeugten Bilder mit einem Array von 1s.

def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)

Der Diskriminator und der Generatoroptimierer unterscheiden sich, da Sie zwei Netzwerke getrennt trainieren.

generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

Kontrollpunkte speichern

Dieses Notizbuch zeigt auch, wie Modelle gespeichert und wiederhergestellt werden. Dies kann hilfreich sein, wenn eine lange Trainingsaufgabe unterbrochen wird.

checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

Definieren Sie die Trainingsschleife

EPOCHS = 50
noise_dim = 100
num_examples_to_generate = 16

# You will reuse this seed overtime (so it's easier)
# to visualize progress in the animated GIF)
seed = tf.random.normal([num_examples_to_generate, noise_dim])

Die Trainingsschleife beginnt damit, dass der Generator einen zufälligen Startwert als Eingabe empfängt. Dieser Samen wird verwendet, um ein Bild zu erzeugen. Der Diskriminator wird dann verwendet, um reale Bilder (aus dem Trainingssatz gezogen) zu klassifizieren und Bilder zu fälschen (vom Generator erzeugt). Der Verlust wird für jedes dieser Modelle berechnet und die Gradienten werden verwendet, um den Generator und den Diskriminator zu aktualisieren.

# Notice the use of `tf.function`
# This annotation causes the function to be "compiled".
@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
      generated_images = generator(noise, training=True)

      real_output = discriminator(images, training=True)
      fake_output = discriminator(generated_images, training=True)

      gen_loss = generator_loss(fake_output)
      disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
def train(dataset, epochs):
  for epoch in range(epochs):
    start = time.time()

    for image_batch in dataset:
      train_step(image_batch)

    # Produce images for the GIF as you go
    display.clear_output(wait=True)
    generate_and_save_images(generator,
                             epoch + 1,
                             seed)

    # Save the model every 15 epochs
    if (epoch + 1) % 15 == 0:
      checkpoint.save(file_prefix = checkpoint_prefix)

    print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))

  # Generate after the final epoch
  display.clear_output(wait=True)
  generate_and_save_images(generator,
                           epochs,
                           seed)

Bilder generieren und speichern

def generate_and_save_images(model, epoch, test_input):
  # Notice `training` is set to False.
  # This is so all layers run in inference mode (batchnorm).
  predictions = model(test_input, training=False)

  fig = plt.figure(figsize=(4, 4))

  for i in range(predictions.shape[0]):
      plt.subplot(4, 4, i+1)
      plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
      plt.axis('off')

  plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
  plt.show()

Trainiere das Modell

Rufen Sie die oben definierte train() -Methode auf, um den Generator und den Diskriminator gleichzeitig zu trainieren. Beachten Sie, dass das Trainieren von GANs schwierig sein kann. Es ist wichtig, dass sich Generator und Diskriminator nicht gegenseitig überwältigen (z. B. dass sie mit einer ähnlichen Geschwindigkeit trainieren).

Zu Beginn des Trainings sehen die erzeugten Bilder wie zufälliges Rauschen aus. Mit fortschreitendem Training werden die generierten Ziffern immer realer. Nach etwa 50 Epochen ähneln sie MNIST-Ziffern. Dies kann bei den Standardeinstellungen von Colab etwa eine Minute / Epoche dauern.

train(train_dataset, EPOCHS)

png

Stellen Sie den neuesten Prüfpunkt wieder her.

checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f371f792c88>

Erstellen Sie ein GIF

# Display a single image using the epoch number
def display_image(epoch_no):
  return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))
display_image(EPOCHS)

png

Verwenden Sie imageio , um ein animiertes GIF mit den während des Trainings gespeicherten Bildern zu erstellen.

anim_file = 'dcgan.gif'

with imageio.get_writer(anim_file, mode='I') as writer:
  filenames = glob.glob('image*.png')
  filenames = sorted(filenames)
  for filename in filenames:
    image = imageio.imread(filename)
    writer.append_data(image)
  image = imageio.imread(filename)
  writer.append_data(image)
import tensorflow_docs.vis.embed as embed
embed.embed_file(anim_file)

gif

Nächste Schritte

Dieses Tutorial zeigt den vollständigen Code, der zum Schreiben und Trainieren eines GAN erforderlich ist. Als nächsten Schritt möchten Sie möglicherweise mit einem anderen Datensatz experimentieren, z. B. dem auf Kaggle verfügbaren CelebA- Datensatz (Large-Scale Celeb Faces Attributes). Weitere Informationen zu GANs finden Sie im NIPS 2016-Lernprogramm: Generative Adversarial Networks .