![]() | ![]() | ![]() | ![]() |
הדרכה זו מדגימה כיצד ליצור תמונות של ספרות בכתב יד באמצעות רשת יריבה עמוקה (Generation Adversarial Generative) (DCGAN). הקוד נכתב באמצעות ה- Keras Sequential API עם לולאת אימוןtf.GradientTape
.
מהם GANs?
רשתות יריבות גנריות (GAN) הן אחד הרעיונות המעניינים ביותר במדעי המחשב כיום. שני מודלים מאומנים בו זמנית על ידי תהליך יריב. גנרטור ("האמן") לומד ליצור תמונות שנראות אמיתיות, ואילו מפלה ("מבקר האמנות") לומד לספר תמונות אמיתיות מלבד זיופים.
במהלך האימון, המחולל הופך בהדרגה לטוב יותר ביצירת תמונות שנראות אמיתיות, ואילו המפלה הופך להיות טוב יותר בלבדל ביניהן. התהליך מגיע לשיווי משקל כאשר המפלה אינו יכול עוד להבדיל בין דימויים אמיתיים לבין זיופים.
מחברת זו מדגימה את התהליך הזה במערך הנתונים של MNIST. האנימציה הבאה מציגה סדרת תמונות שהופקה על ידי הגנרטור תוך כדי הכשרה של 50 תקופות. התמונות מתחילות כרעש אקראי, ודומות יותר ויותר לספרות בכתב יד לאורך זמן.
למידע נוסף על GAN, אנו ממליצים על קורס מבוא ללמידה עמוקה של MIT.
להכין
import tensorflow as tf
tf.__version__
'2.3.0'
# To generate GIFs
pip install -q imageio
pip install -q git+https://github.com/tensorflow/docs
import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
from tensorflow.keras import layers
import time
from IPython import display
טען והכין את מערך הנתונים
תשתמש במערך MNIST כדי להכשיר את הגנרטור ואת האפליה. הגנרטור יפיק ספרות בכתב יד הדומות לנתוני MNIST.
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5 # Normalize the images to [-1, 1]
BUFFER_SIZE = 60000
BATCH_SIZE = 256
# Batch and shuffle the data
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
צור את הדגמים
הן הגנרטור והן המפלה מוגדרים באמצעות ממשק ה- API של Keras Sequential .
המחולל
הגנרטור משתמש tf.keras.layers.Conv2DTranspose
tf.keras.layers.Conv2DTranspose ( tf.keras.layers.Conv2DTranspose
ups) כדי לייצר תמונה מזרע (רעש אקראי). התחל עם שכבה Dense
שלוקחת את הזרע הזה כקלט, ואז מעלה מספר פעמים עד שתגיע לגודל התמונה הרצוי של 28x28x1. שימו לב להפעלת tf.keras.layers.LeakyReLU
עבור כל שכבה, למעט שכבת הפלט המשתמשת ב- tanh.
def make_generator_model():
model = tf.keras.Sequential()
model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
model.add(layers.BatchNormalization())
model.add(layers.LeakyReLU())
model.add(layers.Reshape((7, 7, 256)))
assert model.output_shape == (None, 7, 7, 256) # Note: None is the batch size
model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
assert model.output_shape == (None, 7, 7, 128)
model.add(layers.BatchNormalization())
model.add(layers.LeakyReLU())
model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
assert model.output_shape == (None, 14, 14, 64)
model.add(layers.BatchNormalization())
model.add(layers.LeakyReLU())
model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
assert model.output_shape == (None, 28, 28, 1)
return model
השתמש בגנרטור (שעדיין לא מאומן) ליצירת תמונה.
generator = make_generator_model()
noise = tf.random.normal([1, 100])
generated_image = generator(noise, training=False)
plt.imshow(generated_image[0, :, :, 0], cmap='gray')
<matplotlib.image.AxesImage at 0x7f2729b9f6d8>
המפלה
המפלה הוא מסווג תמונות מבוסס CNN.
def make_discriminator_model():
model = tf.keras.Sequential()
model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
input_shape=[28, 28, 1]))
model.add(layers.LeakyReLU())
model.add(layers.Dropout(0.3))
model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
model.add(layers.LeakyReLU())
model.add(layers.Dropout(0.3))
model.add(layers.Flatten())
model.add(layers.Dense(1))
return model
השתמש במפלה (שעדיין לא מאומן) כדי לסווג את התמונות שנוצרו כאמיתיות או מזויפות. המודל יוכשר להפקת ערכים חיוביים לתמונות אמיתיות, וערכים שליליים לתמונות מזויפות.
discriminator = make_discriminator_model()
decision = discriminator(generated_image)
print (decision)
tf.Tensor([[0.0003284]], shape=(1, 1), dtype=float32)
הגדר את האובדן והמיעול
הגדר פונקציות אובדן ואופטימיזציה לשני הדגמים.
# This method returns a helper function to compute cross entropy loss
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
אובדן מפלה
שיטה זו מכמתת עד כמה המפלה מסוגל להבחין בין תמונות אמיתיות לזיופים. הוא משווה את תחזיות המפלה על תמונות אמיתיות למערך של 1s, ואת תחזיות המפלה על תמונות מזויפות (שנוצרו) למערך של 0s.
def discriminator_loss(real_output, fake_output):
real_loss = cross_entropy(tf.ones_like(real_output), real_output)
fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
total_loss = real_loss + fake_loss
return total_loss
אובדן גנרטור
אובדן הגנרטור מכמת עד כמה הוא הצליח להונות את האפליה. באופן אינטואיטיבי, אם המחולל מתפקד היטב, המפלה יסווג את התמונות המזויפות לאמיתיות (או 1). כאן נשווה את החלטות המפלים על התמונות שנוצרו למערך של 1s.
def generator_loss(fake_output):
return cross_entropy(tf.ones_like(fake_output), fake_output)
המפלה ומיטוב הגנרטורים שונים מכיוון שנכשיר שתי רשתות בנפרד.
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
שמור מחסומים
מחברת זו מדגימה גם כיצד לשמור ולשחזר דגמים, דבר שיכול להועיל במקרה שמשימת אימונים ריצה ארוכה תיפסק.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
discriminator_optimizer=discriminator_optimizer,
generator=generator,
discriminator=discriminator)
הגדר את לולאת האימון
EPOCHS = 50
noise_dim = 100
num_examples_to_generate = 16
# We will reuse this seed overtime (so it's easier)
# to visualize progress in the animated GIF)
seed = tf.random.normal([num_examples_to_generate, noise_dim])
לולאת האימונים מתחילה בגנרטור המקבל זרע אקראי כקלט. זרע זה משמש להפקת תמונה. לאחר מכן משתמשים באבחון לסיווג תמונות אמיתיות (שנשאבו ממערך האימונים) ומזייף תמונות (שהופקו על ידי הגנרטור). ההפסד מחושב עבור כל אחד מהמודלים הללו, והשיפועים משמשים לעדכון הגנרטור והמפלה.
# Notice the use of `tf.function`
# This annotation causes the function to be "compiled".
@tf.function
def train_step(images):
noise = tf.random.normal([BATCH_SIZE, noise_dim])
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
generated_images = generator(noise, training=True)
real_output = discriminator(images, training=True)
fake_output = discriminator(generated_images, training=True)
gen_loss = generator_loss(fake_output)
disc_loss = discriminator_loss(real_output, fake_output)
gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
def train(dataset, epochs):
for epoch in range(epochs):
start = time.time()
for image_batch in dataset:
train_step(image_batch)
# Produce images for the GIF as we go
display.clear_output(wait=True)
generate_and_save_images(generator,
epoch + 1,
seed)
# Save the model every 15 epochs
if (epoch + 1) % 15 == 0:
checkpoint.save(file_prefix = checkpoint_prefix)
print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
# Generate after the final epoch
display.clear_output(wait=True)
generate_and_save_images(generator,
epochs,
seed)
צור ושמור תמונות
def generate_and_save_images(model, epoch, test_input):
# Notice `training` is set to False.
# This is so all layers run in inference mode (batchnorm).
predictions = model(test_input, training=False)
fig = plt.figure(figsize=(4,4))
for i in range(predictions.shape[0]):
plt.subplot(4, 4, i+1)
plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
plt.axis('off')
plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
plt.show()
תאמן את המודל
התקשר לשיטת train()
שהוגדרה לעיל כדי להכשיר את הגנרטור ואת המפלה בו זמנית. שימו לב, אימון GAN יכול להיות מסובך. חשוב שהגנרטור והמפלה לא יכריעו זה את זה (למשל, שהם יתאמנו בקצב דומה).
בתחילת האימון, התמונות שנוצרו נראות כמו רעש אקראי. ככל שמתקדמים באימונים, הספרות שנוצרו ייראו יותר ויותר אמיתיות. לאחר כ- 50 תקופות, הם דומים לספרות MNIST. פעולה זו עשויה להימשך בערך דקה / תקופה עם הגדרות ברירת המחדל ב- Colab.
train(train_dataset, EPOCHS)
שחזר את המחסום האחרון.
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f2729bc3128>
צור קובץ GIF
# Display a single image using the epoch number
def display_image(epoch_no):
return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))
display_image(EPOCHS)
השתמש ב- imageio
כדי ליצור GIF מונפש באמצעות התמונות שנשמרו במהלך האימון.
anim_file = 'dcgan.gif'
with imageio.get_writer(anim_file, mode='I') as writer:
filenames = glob.glob('image*.png')
filenames = sorted(filenames)
for filename in filenames:
image = imageio.imread(filename)
writer.append_data(image)
image = imageio.imread(filename)
writer.append_data(image)
import tensorflow_docs.vis.embed as embed
embed.embed_file(anim_file)
הצעדים הבאים
מדריך זה הראה את הקוד המלא הדרוש לכתיבה והכשרה של GAN. כשלב הבא, אולי תרצה להתנסות במערך נתונים אחר, למשל מערך המאפיינים של Celeb Faces בקנה מידה גדול (CelebA) הזמין ב- Kaggle . למידע נוסף על GAN אנו ממליצים על הדרכת NIPS 2016: רשתות יריבות גנריות.