บันทึกวันที่! Google I / O ส่งคืนวันที่ 18-20 พฤษภาคม ลงทะเบียนตอนนี้
หน้านี้ได้รับการแปลโดย Cloud Translation API
Switch to English

Deep Convolutional Generative Adversarial Network

ดูใน TensorFlow.org เรียกใช้ใน Google Colab ดูแหล่งที่มาบน GitHub ดาวน์โหลดสมุดบันทึก

บทช่วยสอนนี้สาธิตวิธีการสร้างภาพของตัวเลขที่เขียนด้วยลายมือโดยใช้ Deep Convolutional Generative Adversarial Network (DCGAN) โค้ดนี้เขียนโดยใช้ Keras Sequential API พร้อมกับtf.GradientTape training loop

GAN คืออะไร?

Generative Adversarial Networks (GAN) เป็นหนึ่งในแนวคิดที่น่าสนใจที่สุดในวิทยาการคอมพิวเตอร์ในปัจจุบัน สองรุ่นได้รับการฝึกฝนพร้อมกันโดยกระบวนการที่เป็นปฏิปักษ์ เครื่องกำเนิดไฟฟ้า ("ศิลปิน") เรียนรู้ที่จะสร้างภาพที่ดูเหมือนจริงในขณะที่ผู้ แยกแยะ ("นักวิจารณ์ศิลปะ") เรียนรู้ที่จะบอกภาพจริงนอกเหนือจากของปลอม

แผนภาพของเครื่องกำเนิดและตัวแยกแยะ

ในระหว่างการฝึกอบรม เครื่องกำเนิดไฟฟ้า จะพัฒนาภาพที่ดูเหมือนจริงได้ดีขึ้นเรื่อย ๆ ในขณะที่ผู้ แยกแยะ จะ แยกแยะได้ ดีขึ้น กระบวนการนี้ถึงจุดสมดุลเมื่อผู้ แยกแยะ ไม่สามารถแยกแยะภาพจริงจากของปลอมได้อีกต่อไป

แผนภาพที่สองของเครื่องกำเนิดไฟฟ้าและตัวแยกแยะ

สมุดบันทึกนี้สาธิตกระบวนการนี้บนชุดข้อมูล MNIST ภาพเคลื่อนไหวต่อไปนี้แสดงชุดภาพที่สร้างโดย เครื่องกำเนิดไฟฟ้า ซึ่งได้รับการฝึกฝนมาเป็นเวลา 50 ยุค ภาพเริ่มเป็นสัญญาณรบกวนแบบสุ่มและมีลักษณะคล้ายตัวเลขที่เขียนด้วยมือมากขึ้นเรื่อย ๆ เมื่อเวลาผ่านไป

เอาท์พุทตัวอย่าง

หากต้องการเรียนรู้เพิ่มเติมเกี่ยวกับ GAN โปรดดูหลักสูตร Intro to Deep Learning ของ MIT

ติดตั้ง

import tensorflow as tf
tf.__version__
'2.4.1'
# To generate GIFs
pip install -q imageio
pip install -q git+https://github.com/tensorflow/docs
import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
from tensorflow.keras import layers
import time

from IPython import display

โหลดและเตรียมชุดข้อมูล

คุณจะใช้ชุดข้อมูล MNIST เพื่อฝึกเครื่องกำเนิดไฟฟ้าและตัวแยกแยะ เครื่องกำเนิดไฟฟ้าจะสร้างตัวเลขที่เขียนด้วยลายมือซึ่งคล้ายกับข้อมูล MNIST

(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5  # Normalize the images to [-1, 1]
BUFFER_SIZE = 60000
BATCH_SIZE = 256
# Batch and shuffle the data
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

สร้างแบบจำลอง

ทั้งตัวสร้างและตัวแยกแยะถูกกำหนดโดยใช้ Keras Sequential API

เครื่องกำเนิดไฟฟ้า

เครื่องกำเนิดไฟฟ้าใช้ tf.keras.layers.Conv2DTranspose (การ tf.keras.layers.Conv2DTranspose ) เพื่อสร้างภาพจากเมล็ดพันธุ์ (สัญญาณรบกวนแบบสุ่ม) เริ่มต้นด้วยเลเยอร์ Dense ที่รับเมล็ดพันธุ์นี้เป็นอินพุตจากนั้นเพิ่มตัวอย่างหลาย ๆ ครั้งจนกว่าคุณจะได้ขนาดภาพที่ต้องการ 28x28x1 สังเกตการเปิดใช้งาน tf.keras.layers.LeakyReLU สำหรับแต่ละเลเยอร์ยกเว้นเลเยอร์เอาต์พุตที่ใช้ tanh

def make_generator_model():
    model = tf.keras.Sequential()
    model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Reshape((7, 7, 256)))
    assert model.output_shape == (None, 7, 7, 256)  # Note: None is the batch size

    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    assert model.output_shape == (None, 7, 7, 128)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 14, 14, 64)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    assert model.output_shape == (None, 28, 28, 1)

    return model

ใช้เครื่องกำเนิดไฟฟ้า (ที่ยังไม่ได้รับการฝึกฝน) เพื่อสร้างภาพ

generator = make_generator_model()

noise = tf.random.normal([1, 100])
generated_image = generator(noise, training=False)

plt.imshow(generated_image[0, :, :, 0], cmap='gray')
<matplotlib.image.AxesImage at 0x7f3740747390>

png

ผู้เลือกปฏิบัติ

ตัวจำแนกเป็นตัวจำแนกรูปภาพที่ใช้ CNN

def make_discriminator_model():
    model = tf.keras.Sequential()
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                                     input_shape=[28, 28, 1]))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Flatten())
    model.add(layers.Dense(1))

    return model

ใช้ตัวเลือก (ที่ยังไม่ได้รับการฝึกฝน) เพื่อจำแนกภาพที่สร้างขึ้นว่าเป็นของจริงหรือของปลอม โมเดลจะได้รับการฝึกให้แสดงค่าบวกสำหรับรูปภาพจริงและค่าลบสำหรับรูปภาพปลอม

discriminator = make_discriminator_model()
decision = discriminator(generated_image)
print (decision)
tf.Tensor([[-0.00033125]], shape=(1, 1), dtype=float32)

กำหนดความสูญเสียและเครื่องมือเพิ่มประสิทธิภาพ

กำหนดฟังก์ชันการสูญเสียและเครื่องมือเพิ่มประสิทธิภาพสำหรับทั้งสองรุ่น

# This method returns a helper function to compute cross entropy loss
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

การสูญเสียการเลือกปฏิบัติ

วิธีนี้จะวัดว่าผู้แยกแยะสามารถแยกแยะภาพจริงจากของปลอมได้ดีเพียงใด มันเปรียบเทียบการคาดการณ์ของผู้แยกแยะเกี่ยวกับภาพจริงกับอาร์เรย์ 1 วินาทีและการคาดการณ์ของผู้แยกแยะเกี่ยวกับภาพปลอม (สร้างขึ้น) กับอาร์เรย์ของ 0

def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

การสูญเสียเครื่องกำเนิดไฟฟ้า

การสูญเสียของเครื่องกำเนิดไฟฟ้าจะวัดว่ามันสามารถหลอกล่อผู้เลือกปฏิบัติได้ดีเพียงใด โดยสัญชาตญาณหากเครื่องกำเนิดไฟฟ้าทำงานได้ดีผู้แยกแยะจะจำแนกภาพปลอมว่าเป็นของจริง (หรือ 1) ที่นี่เปรียบเทียบการตัดสินใจของผู้แยกแยะเกี่ยวกับภาพที่สร้างขึ้นกับอาร์เรย์ 1s

def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)

ตัวเลือกและเครื่องมือเพิ่มประสิทธิภาพตัวสร้างนั้นแตกต่างกันเนื่องจากคุณจะฝึกสองเครือข่ายแยกกัน

generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

บันทึกจุดตรวจ

สมุดบันทึกนี้ยังสาธิตวิธีการบันทึกและกู้คืนโมเดลซึ่งจะเป็นประโยชน์ในกรณีที่งานฝึกอบรมที่ใช้งานเป็นเวลานานถูกขัดจังหวะ

checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

กำหนดลูปการฝึกอบรม

EPOCHS = 50
noise_dim = 100
num_examples_to_generate = 16

# You will reuse this seed overtime (so it's easier)
# to visualize progress in the animated GIF)
seed = tf.random.normal([num_examples_to_generate, noise_dim])

ลูปการฝึกเริ่มต้นด้วยเครื่องกำเนิดไฟฟ้าที่ได้รับเมล็ดพันธุ์แบบสุ่มเป็นอินพุต เมล็ดพันธุ์นั้นใช้ในการสร้างภาพ จากนั้นตัวเลือกจะใช้เพื่อจำแนกภาพจริง (วาดจากชุดฝึก) และภาพปลอม (ผลิตโดยเครื่องกำเนิดไฟฟ้า) การสูญเสียจะคำนวณสำหรับแต่ละรุ่นเหล่านี้และการไล่ระดับสีจะใช้เพื่ออัปเดตตัวสร้างและตัวเลือก

# Notice the use of `tf.function`
# This annotation causes the function to be "compiled".
@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
      generated_images = generator(noise, training=True)

      real_output = discriminator(images, training=True)
      fake_output = discriminator(generated_images, training=True)

      gen_loss = generator_loss(fake_output)
      disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
def train(dataset, epochs):
  for epoch in range(epochs):
    start = time.time()

    for image_batch in dataset:
      train_step(image_batch)

    # Produce images for the GIF as you go
    display.clear_output(wait=True)
    generate_and_save_images(generator,
                             epoch + 1,
                             seed)

    # Save the model every 15 epochs
    if (epoch + 1) % 15 == 0:
      checkpoint.save(file_prefix = checkpoint_prefix)

    print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))

  # Generate after the final epoch
  display.clear_output(wait=True)
  generate_and_save_images(generator,
                           epochs,
                           seed)

สร้างและบันทึกภาพ

def generate_and_save_images(model, epoch, test_input):
  # Notice `training` is set to False.
  # This is so all layers run in inference mode (batchnorm).
  predictions = model(test_input, training=False)

  fig = plt.figure(figsize=(4, 4))

  for i in range(predictions.shape[0]):
      plt.subplot(4, 4, i+1)
      plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
      plt.axis('off')

  plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
  plt.show()

ฝึกโมเดล

เรียกใช้วิธี train() กำหนดไว้ข้างต้นเพื่อฝึกเครื่องกำเนิดไฟฟ้าและตัวแยกแยะพร้อมกัน หมายเหตุการฝึก GAN อาจเป็นเรื่องยุ่งยาก สิ่งสำคัญคือผู้สร้างและผู้แบ่งแยกจะต้องไม่เอาชนะซึ่งกันและกัน (เช่นพวกเขาฝึกในอัตราที่ใกล้เคียงกัน)

ในช่วงเริ่มต้นของการฝึกอบรมภาพที่สร้างขึ้นจะดูเหมือนสัญญาณรบกวนแบบสุ่ม เมื่อการฝึกดำเนินไปตัวเลขที่สร้างขึ้นจะดูเป็นจริงมากขึ้น หลังจากผ่านไปประมาณ 50 ยุคพวกเขาจะมีลักษณะคล้ายกับตัวเลข MNIST อาจใช้เวลาประมาณหนึ่งนาที / ยุคด้วยการตั้งค่าเริ่มต้นใน Colab

train(train_dataset, EPOCHS)

png

คืนค่าด่านล่าสุด

checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f371f792c88>

สร้าง GIF

# Display a single image using the epoch number
def display_image(epoch_no):
  return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))
display_image(EPOCHS)

png

ใช้ imageio เพื่อสร้าง gif แบบเคลื่อนไหวโดยใช้ภาพที่บันทึกไว้ระหว่างการฝึก

anim_file = 'dcgan.gif'

with imageio.get_writer(anim_file, mode='I') as writer:
  filenames = glob.glob('image*.png')
  filenames = sorted(filenames)
  for filename in filenames:
    image = imageio.imread(filename)
    writer.append_data(image)
  image = imageio.imread(filename)
  writer.append_data(image)
import tensorflow_docs.vis.embed as embed
embed.embed_file(anim_file)

gif

ขั้นตอนถัดไป

บทช่วยสอนนี้ได้แสดงโค้ดทั้งหมดที่จำเป็นในการเขียนและฝึก GAN ในขั้นตอนต่อไปคุณอาจต้องการทดลองกับชุดข้อมูลอื่นตัวอย่างเช่นชุดข้อมูล Celeb Faces Attributes (CelebA) ขนาดใหญ่ที่ มีอยู่ใน Kaggle หากต้องการเรียนรู้เพิ่มเติมเกี่ยวกับ GAN โปรดดู บทช่วยสอน NIPS 2016: Generative Adversarial Networks