דף זה תורגם על ידי Cloud Translation API.
Switch to English

רשת יריבות גנראטיבית עולמית עמוקה

צפה ב- TensorFlow.org הפעל בגוגל קולאב צפה במקור ב- GitHub הורד מחברת

הדרכה זו מדגימה כיצד ליצור תמונות של ספרות בכתב יד באמצעות רשת יריבה עמוקה (Generation Adversarial Generative) (DCGAN). הקוד נכתב באמצעות ה- Keras Sequential API עם לולאת אימוןtf.GradientTape .

מהם GANs?

רשתות יריבות גנריות (GAN) הן אחד הרעיונות המעניינים ביותר במדעי המחשב כיום. שני מודלים מאומנים בו זמנית על ידי תהליך יריב. גנרטור ("האמן") לומד ליצור תמונות שנראות אמיתיות, ואילו מפלה ("מבקר האמנות") לומד לספר תמונות אמיתיות מלבד זיופים.

תרשים של גנרטור ומפלה

במהלך האימון הגנרטור הופך בהדרגה לטוב יותר ביצירת תמונות שנראות אמיתיות, ואילו המפלה הופך טוב יותר בלבדל ביניהן. התהליך מגיע לשיווי משקל כאשר המפלה אינו יכול עוד להבדיל בין דימויים אמיתיים לבין זיופים.

תרשים שני של גנרטור ומפלה

מחברת זו מדגימה את התהליך הזה במערך הנתונים של MNIST. האנימציה הבאה מציגה סדרת תמונות שהופקה על ידי הגנרטור תוך כדי הכשרה של 50 תקופות. התמונות מתחילות כרעש אקראי, ודומות יותר ויותר לספרות בכתב יד לאורך זמן.

פלט לדוגמא

למידע נוסף על GAN, אנו ממליצים על קורס מבוא ללמידה עמוקה של MIT.

להכין

import tensorflow as tf
tf.__version__
'2.3.0'
# To generate GIFs
pip install -q imageio
pip install -q git+https://github.com/tensorflow/docs
WARNING: You are using pip version 20.2.2; however, version 20.2.3 is available.
You should consider upgrading via the '/tmpfs/src/tf_docs_env/bin/python -m pip install --upgrade pip' command.
WARNING: You are using pip version 20.2.2; however, version 20.2.3 is available.
You should consider upgrading via the '/tmpfs/src/tf_docs_env/bin/python -m pip install --upgrade pip' command.

import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
from tensorflow.keras import layers
import time

from IPython import display

טען והכין את מערך הנתונים

תשתמש במערך MNIST כדי להכשיר את הגנרטור ואת האפליה. הגנרטור יפיק ספרות בכתב יד הדומות לנתוני MNIST.

(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5 # Normalize the images to [-1, 1]
BUFFER_SIZE = 60000
BATCH_SIZE = 256
# Batch and shuffle the data
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

צור את הדגמים

הן הגנרטור והן המפלה מוגדרים באמצעות ממשק ה- API של Keras Sequential .

המחולל

הגנרטור משתמש tf.keras.layers.Conv2DTranspose tf.keras.layers.Conv2DTranspose ( tf.keras.layers.Conv2DTranspose ups) כדי לייצר תמונה מזרע (רעש אקראי). התחל בשכבה Dense שלוקחת את הזרע הזה כקלט, ואז מעלה מספר פעמים עד שתגיע לגודל התמונה הרצוי של 28x28x1. שימו לב להפעלת tf.keras.layers.LeakyReLU עבור כל שכבה, למעט שכבת הפלט המשתמשת ב- tanh.

def make_generator_model():
    model = tf.keras.Sequential()
    model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Reshape((7, 7, 256)))
    assert model.output_shape == (None, 7, 7, 256) # Note: None is the batch size

    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    assert model.output_shape == (None, 7, 7, 128)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 14, 14, 64)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    assert model.output_shape == (None, 28, 28, 1)

    return model

השתמש בגנרטור (שעדיין לא מאומן) ליצירת תמונה.

generator = make_generator_model()

noise = tf.random.normal([1, 100])
generated_image = generator(noise, training=False)

plt.imshow(generated_image[0, :, :, 0], cmap='gray')
<matplotlib.image.AxesImage at 0x7f2729b9f6d8>

png

המפלה

המפלה הוא מסווג תמונות מבוסס CNN.

def make_discriminator_model():
    model = tf.keras.Sequential()
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                                     input_shape=[28, 28, 1]))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Flatten())
    model.add(layers.Dense(1))

    return model

השתמש במפלה (שעדיין לא מאומן) כדי לסווג את התמונות שנוצרו כאמיתיות או מזויפות. המודל יוכשר להפקת ערכים חיוביים לתמונות אמיתיות, וערכים שליליים לתמונות מזויפות.

discriminator = make_discriminator_model()
decision = discriminator(generated_image)
print (decision)
tf.Tensor([[0.0003284]], shape=(1, 1), dtype=float32)

הגדר את האובדן והמיעול

הגדר פונקציות אובדן ואופטימיזציה לשני הדגמים.

# This method returns a helper function to compute cross entropy loss
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

אובדן מפלה

שיטה זו מכמתת עד כמה המפלה מסוגל להבחין בין תמונות אמיתיות לזיופים. הוא משווה את תחזיות המפלה על תמונות אמיתיות למערך של 1s, ואת תחזיות המפלה על תמונות מזויפות (נוצרות) למערך של 0s.

def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

אובדן גנרטור

אובדן הגנרטור מכמת עד כמה הוא הצליח להונות את האפליה. באופן אינטואיטיבי, אם הגנרטור מתפקד היטב, המפלה יסווג את התמונות המזויפות לאמיתיות (או 1). כאן נשווה את החלטות המפלים על התמונות שנוצרו למערך של 1s.

def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)

המפלה ומיטוב הגנרטורים שונים מכיוון שנכשיר שתי רשתות בנפרד.

generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

שמור מחסומים

מחברת זו מדגימה גם כיצד לשמור ולשחזר דגמים, דבר שיכול להועיל במקרה שמשימת אימונים ריצה ארוכה תיפסק.

checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

הגדר את לולאת האימון

EPOCHS = 50
noise_dim = 100
num_examples_to_generate = 16

# We will reuse this seed overtime (so it's easier)
# to visualize progress in the animated GIF)
seed = tf.random.normal([num_examples_to_generate, noise_dim])

לולאת האימונים מתחילה בגנרטור המקבל זרע אקראי כקלט. זרע זה משמש להפקת תמונה. לאחר מכן משתמשים באבחון לסיווג תמונות אמיתיות (שואבות מערכת האימונים) ומזייפות תמונות (שהופקו על ידי הגנרטור). ההפסד מחושב עבור כל אחד מהמודלים הללו, והשיפועים משמשים לעדכון הגנרטור והמפלה.

# Notice the use of `tf.function`
# This annotation causes the function to be "compiled".
@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
      generated_images = generator(noise, training=True)

      real_output = discriminator(images, training=True)
      fake_output = discriminator(generated_images, training=True)

      gen_loss = generator_loss(fake_output)
      disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
def train(dataset, epochs):
  for epoch in range(epochs):
    start = time.time()

    for image_batch in dataset:
      train_step(image_batch)

    # Produce images for the GIF as we go
    display.clear_output(wait=True)
    generate_and_save_images(generator,
                             epoch + 1,
                             seed)

    # Save the model every 15 epochs
    if (epoch + 1) % 15 == 0:
      checkpoint.save(file_prefix = checkpoint_prefix)

    print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))

  # Generate after the final epoch
  display.clear_output(wait=True)
  generate_and_save_images(generator,
                           epochs,
                           seed)

צור ושמור תמונות

def generate_and_save_images(model, epoch, test_input):
  # Notice `training` is set to False.
  # This is so all layers run in inference mode (batchnorm).
  predictions = model(test_input, training=False)

  fig = plt.figure(figsize=(4,4))

  for i in range(predictions.shape[0]):
      plt.subplot(4, 4, i+1)
      plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
      plt.axis('off')

  plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
  plt.show()

תאמן את המודל

התקשר לשיטת train() שהוגדרה לעיל כדי להכשיר את הגנרטור ואת המפלה בו זמנית. שימו לב, אימון GAN יכול להיות מסובך. חשוב שהגנרטור והמפלה לא יכריעו זה את זה (למשל, שהם יתאמנו בקצב דומה).

בתחילת האימון, התמונות שנוצרו נראות כמו רעש אקראי. ככל שמתקדמים באימונים, הספרות שנוצרו ייראו יותר ויותר אמיתיות. לאחר כ- 50 תקופות, הם דומים לספרות MNIST. פעולה זו עשויה לארוך כדקה / תקופה עם הגדרות ברירת המחדל ב- Colab.

train(train_dataset, EPOCHS)

png

שחזר את המחסום האחרון.

checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f2729bc3128>

צור קובץ GIF

# Display a single image using the epoch number
def display_image(epoch_no):
  return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))
display_image(EPOCHS)

png

השתמש ב- imageio כדי ליצור GIF מונפש באמצעות התמונות שנשמרו במהלך האימון.

anim_file = 'dcgan.gif'

with imageio.get_writer(anim_file, mode='I') as writer:
  filenames = glob.glob('image*.png')
  filenames = sorted(filenames)
  for filename in filenames:
    image = imageio.imread(filename)
    writer.append_data(image)
  image = imageio.imread(filename)
  writer.append_data(image)
import tensorflow_docs.vis.embed as embed
embed.embed_file(anim_file)

gif

הצעדים הבאים

מדריך זה הראה את הקוד המלא הדרוש לכתיבה והכשרה של GAN. כשלב הבא, ייתכן שתרצה להתנסות במערך נתונים אחר, למשל מערך המאפיינים הגדול של Celeb Faces (CelebA) הזמין ב- Kaggle . למידע נוסף על GAN אנו ממליצים על הדרכת NIPS 2016: רשתות יריבות גנריות.