רשת יריבות גנראטיבית עולמית

צפה ב- TensorFlow.org הפעל בגוגל קולאב צפה במקור ב- GitHub הורד מחברת

מדריך זה מדגים כיצד ליצור תמונות של ספרות בכתב יד באמצעות רשת יריבה עמוקה (Generation Adversarial Generative) (DCGAN). הקוד נכתב באמצעות ממשק ה- API של Keras Sequential עם לולאת אימוןtf.GradientTape .

מהם GANs?

רשתות יריבות גנריות (GAN) הן אחד הרעיונות המעניינים ביותר במדעי המחשב כיום. שני מודלים מאומנים בו זמנית על ידי תהליך יריב. גנרטור ("האמן") לומד ליצור תמונות שנראות אמיתיות, ואילו מפלה ("מבקר האמנות") לומד לספר תמונות אמיתיות מלבד זיופים.

תרשים של גנרטור ומפלה

במהלך האימון, המחולל הופך בהדרגה ליצירת תמונות שנראות אמיתיות, ואילו המפלה הופך טוב יותר בלבדל ביניהן. התהליך מגיע לשיווי משקל כאשר המפלה אינו יכול עוד להבדיל בין דימויים אמיתיים לבין זיופים.

תרשים שני של גנרטור ומפלה

מחברת זו מדגימה תהליך זה במערך הנתונים MNIST. האנימציה הבאה מציגה סדרת תמונות שהופקה על ידי הגנרטור תוך כדי הכשרה של 50 עידנים. התמונות מתחילות כרעש אקראי, ודומות יותר ויותר לספרות בכתב יד לאורך זמן.

פלט לדוגמא

למידע נוסף על GAN, עיין בקורס מבוא ללמידה עמוקה של MIT.

להכין

import tensorflow as tf
tf.__version__
'2.5.0'
# To generate GIFs
pip install imageio
pip install git+https://github.com/tensorflow/docs
import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
from tensorflow.keras import layers
import time

from IPython import display

טען והכין את מערך הנתונים

תשתמש במערך MNIST כדי להכשיר את הגנרטור ואת האפליה. הגנרטור יפיק ספרות בכתב יד הדומות לנתוני MNIST.

(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5  # Normalize the images to [-1, 1]
BUFFER_SIZE = 60000
BATCH_SIZE = 256
# Batch and shuffle the data
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

צור את הדגמים

הן הגנרטור והן המפלה מוגדרים באמצעות ה- API של Keras Sequential .

המחולל

הגנרטור משתמש tf.keras.layers.Conv2DTranspose tf.keras.layers.Conv2DTranspose ( tf.keras.layers.Conv2DTranspose אפס) כדי לייצר תמונה מזרע (רעש אקראי). התחל עם שכבה Dense שלוקחת את הזרע הזה כקלט, ואז מעלה מספר פעמים עד שתגיע לגודל התמונה הרצוי של 28x28x1. שימו לב להפעלת tf.keras.layers.LeakyReLU עבור כל שכבה, למעט שכבת הפלט המשתמשת ב- tanh.

def make_generator_model():
    model = tf.keras.Sequential()
    model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Reshape((7, 7, 256)))
    assert model.output_shape == (None, 7, 7, 256)  # Note: None is the batch size

    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    assert model.output_shape == (None, 7, 7, 128)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 14, 14, 64)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    assert model.output_shape == (None, 28, 28, 1)

    return model

השתמש בגנרטור (שעדיין לא מאומן) ליצירת תמונה.

generator = make_generator_model()

noise = tf.random.normal([1, 100])
generated_image = generator(noise, training=False)

plt.imshow(generated_image[0, :, :, 0], cmap='gray')
<matplotlib.image.AxesImage at 0x7f7322b54fd0>

png

המפלה

המפלה הוא מסווג תמונות מבוסס CNN.

def make_discriminator_model():
    model = tf.keras.Sequential()
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                                     input_shape=[28, 28, 1]))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Flatten())
    model.add(layers.Dense(1))

    return model

השתמש במפלה (שעדיין לא מאומן) כדי לסווג את התמונות שנוצרו כאמיתיות או מזויפות. המודל יוכשר להפקת ערכים חיוביים לתמונות אמיתיות, וערכים שליליים לתמונות מזויפות.

discriminator = make_discriminator_model()
decision = discriminator(generated_image)
print (decision)
tf.Tensor([[-0.00085865]], shape=(1, 1), dtype=float32)

הגדר את האובדן והמיעול

הגדר פונקציות אובדן ומייטיבים לשני הדגמים.

# This method returns a helper function to compute cross entropy loss
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

אובדן מפלה

שיטה זו מכמתת עד כמה המפלה מסוגל להבחין בין תמונות אמיתיות לזיופים. הוא משווה את תחזיות המפלה על תמונות אמיתיות למערך של 1s, ואת תחזיות המפלה על תמונות מזויפות (שנוצרו) למערך של 0s.

def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

אובדן גנרטור

אובדן הגנרטור מכמת עד כמה הוא הצליח להונות את האפליה. באופן אינטואיטיבי, אם המחולל מתפקד היטב, המפלה יסווג את התמונות המזויפות לאמיתיות (או 1). כאן, השווה את החלטות המפלים על התמונות שנוצרו למערך של 1s.

def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)

המפלה ומיטוב הגנרטורים שונים מכיוון שתאמן שתי רשתות בנפרד.

generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

שמור מחסומים

מחברת זו מדגימה גם כיצד לשמור ולשחזר דגמים, דבר שיכול להועיל במקרה שמשימת אימון ארוכה תיפסק.

checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

הגדר את לולאת האימון

EPOCHS = 50
noise_dim = 100
num_examples_to_generate = 16

# You will reuse this seed overtime (so it's easier)
# to visualize progress in the animated GIF)
seed = tf.random.normal([num_examples_to_generate, noise_dim])

לולאת האימון מתחילה בגנרטור המקבל זרע אקראי כקלט. זרע זה משמש להפקת תמונה. לאחר מכן משתמשים באבחון לסיווג תמונות אמיתיות (שנלקחו מערכת האימונים) ומזייף תמונות (שהופקו על ידי הגנרטור). ההפסד מחושב עבור כל אחד מהמודלים הללו, והשיפועים משמשים לעדכון הגנרטור והמפלה.

# Notice the use of `tf.function`
# This annotation causes the function to be "compiled".
@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
      generated_images = generator(noise, training=True)

      real_output = discriminator(images, training=True)
      fake_output = discriminator(generated_images, training=True)

      gen_loss = generator_loss(fake_output)
      disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
def train(dataset, epochs):
  for epoch in range(epochs):
    start = time.time()

    for image_batch in dataset:
      train_step(image_batch)

    # Produce images for the GIF as you go
    display.clear_output(wait=True)
    generate_and_save_images(generator,
                             epoch + 1,
                             seed)

    # Save the model every 15 epochs
    if (epoch + 1) % 15 == 0:
      checkpoint.save(file_prefix = checkpoint_prefix)

    print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))

  # Generate after the final epoch
  display.clear_output(wait=True)
  generate_and_save_images(generator,
                           epochs,
                           seed)

צור ושמור תמונות

def generate_and_save_images(model, epoch, test_input):
  # Notice `training` is set to False.
  # This is so all layers run in inference mode (batchnorm).
  predictions = model(test_input, training=False)

  fig = plt.figure(figsize=(4, 4))

  for i in range(predictions.shape[0]):
      plt.subplot(4, 4, i+1)
      plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
      plt.axis('off')

  plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
  plt.show()

תאמן את המודל

התקשר לשיטת train() שהוגדרה לעיל כדי להכשיר את הגנרטור ואת המפלה בו זמנית. שימו לב, אימוני GAN יכולים להיות מסובכים. חשוב שהגנרטור והמפלה לא יכריעו זה את זה (למשל, שהם יתאמנו בקצב דומה).

בתחילת האימון, התמונות שנוצרו נראות כמו רעש אקראי. עם התקדמות האימונים, הספרות שנוצרו ייראו יותר ויותר אמיתיות. לאחר כ- 50 תקופות, הם דומים לספרות MNIST. פעולה זו עשויה לקחת בערך דקה / תקופה עם הגדרות ברירת המחדל ב- Colab.

train(train_dataset, EPOCHS)

png

שחזר את המחסום האחרון.

checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f72983d2bd0>

צור קובץ GIF

# Display a single image using the epoch number
def display_image(epoch_no):
  return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))
display_image(EPOCHS)

png

השתמש ב- imageio כדי ליצור GIF מונפש באמצעות התמונות שנשמרו במהלך האימון.

anim_file = 'dcgan.gif'

with imageio.get_writer(anim_file, mode='I') as writer:
  filenames = glob.glob('image*.png')
  filenames = sorted(filenames)
  for filename in filenames:
    image = imageio.imread(filename)
    writer.append_data(image)
  image = imageio.imread(filename)
  writer.append_data(image)
import tensorflow_docs.vis.embed as embed
embed.embed_file(anim_file)

gif

הצעדים הבאים

מדריך זה הראה את הקוד המלא הדרוש לכתיבה והכשרה של GAN. כשלב הבא, ייתכן שתרצה להתנסות עם מערך נתונים אחר, למשל מערך המאפיינים של Celeb Faces בקנה מידה גדול (CelebA) הזמין ב- Kaggle . למידע נוסף על GAN, עיין במדריך NIPS 2016: רשתות יריבות גנריות.