pix2pix: सशर्त GAN के साथ छवि-से-छवि अनुवाद

संग्रह की मदद से व्यवस्थित रहें अपनी प्राथमिकताओं के आधार पर, कॉन्टेंट को सेव करें और कैटगरी में बांटें.

TensorFlow.org पर देखें Google Colab में चलाएं GitHub पर स्रोत देखें नोटबुक डाउनलोड करें

यह ट्यूटोरियल दर्शाता है कि कैसे पिक्स2पिक्स नामक एक सशर्त जनरेटिव एडवरसैरियल नेटवर्क (सीजीएएन) का निर्माण और प्रशिक्षण किया जाता है, जो इनपुट इमेज से आउटपुट इमेज में मैपिंग सीखता है, जैसा कि आइसोला एट अल द्वारा सशर्त प्रतिकूल नेटवर्क के साथ इमेज-टू-इमेज अनुवाद में वर्णित है। (2017)। pix2pix एप्लिकेशन विशिष्ट नहीं है—इसे कार्यों की एक विस्तृत श्रृंखला पर लागू किया जा सकता है, जिसमें लेबल मानचित्रों से फ़ोटो को संश्लेषित करना, श्वेत और श्याम छवियों से रंगीन फ़ोटो बनाना, Google मानचित्र फ़ोटो को हवाई छवियों में बदलना और यहां तक ​​कि स्केच को फ़ोटो में बदलना शामिल है।

इस उदाहरण में, आपका नेटवर्क प्राग में चेक तकनीकी विश्वविद्यालय में सेंटर फॉर मशीन परसेप्शन द्वारा प्रदान किए गए सीएमपी फेकाडे डेटाबेस का उपयोग करके भवन के अग्रभाग की छवियां उत्पन्न करेगा। इसे संक्षिप्त रखने के लिए, आप इस डेटासेट की एक पूर्व-संसाधित प्रतिलिपि का उपयोग करेंगे जो कि pix2pix लेखकों द्वारा बनाई गई है।

Pix2pix cGAN में, आप इनपुट इमेज पर कंडीशन करते हैं और संबंधित आउटपुट इमेज जेनरेट करते हैं। cGAN को सबसे पहले कंडीशनल जनरेटिव एडवरसैरियल नेट (मिर्जा और ओसिन्दरो, 2014) में प्रस्तावित किया गया था।

आपके नेटवर्क की संरचना में निम्न शामिल होंगे:

  • यू-नेट- आधारित आर्किटेक्चर वाला जनरेटर।
  • एक विभेदक जो एक दृढ़ पैचगैन क्लासिफायरियर ( pix2pix पेपर में प्रस्तावित) द्वारा दर्शाया गया है।

ध्यान दें कि प्रत्येक युग एक V100 GPU पर लगभग 15 सेकंड का समय ले सकता है।

नीचे 200 युगों के लिए फ़ेडसेट डेटासेट (80k चरणों) पर प्रशिक्षण के बाद pix2pix cGAN द्वारा उत्पन्न आउटपुट के कुछ उदाहरण दिए गए हैं।

नमूना आउटपुट_1नमूना आउटपुट_2

TensorFlow और अन्य पुस्तकालयों को आयात करें

import tensorflow as tf

import os
import pathlib
import time
import datetime

from matplotlib import pyplot as plt
from IPython import display

डेटासेट लोड करें

CMP Facade डेटाबेस डेटा (30MB) डाउनलोड करें। अतिरिक्त डेटासेट यहां उसी प्रारूप में उपलब्ध हैं। Colab में आप ड्रॉप-डाउन मेनू से अन्य डेटासेट चुन सकते हैं। ध्यान दें कि कुछ अन्य डेटासेट काफी बड़े हैं ( edges2handbags 8GB है)।

dataset_name = "facades"
_URL = f'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/{dataset_name}.tar.gz'

path_to_zip = tf.keras.utils.get_file(
    fname=f"{dataset_name}.tar.gz",
    origin=_URL,
    extract=True)

path_to_zip  = pathlib.Path(path_to_zip)

PATH = path_to_zip.parent/dataset_name
Downloading data from http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/facades.tar.gz
30171136/30168306 [==============================] - 19s 1us/step
30179328/30168306 [==============================] - 19s 1us/step
list(PATH.parent.iterdir())
[PosixPath('/home/kbuilder/.keras/datasets/facades.tar.gz'),
 PosixPath('/home/kbuilder/.keras/datasets/YellowLabradorLooking_new.jpg'),
 PosixPath('/home/kbuilder/.keras/datasets/facades'),
 PosixPath('/home/kbuilder/.keras/datasets/mnist.npz')]

प्रत्येक मूल छवि का आकार 256 x 512 है जिसमें दो 256 x 256 चित्र हैं:

sample_image = tf.io.read_file(str(PATH / 'train/1.jpg'))
sample_image = tf.io.decode_jpeg(sample_image)
print(sample_image.shape)
(256, 512, 3)
plt.figure()
plt.imshow(sample_image)
<matplotlib.image.AxesImage at 0x7f35a3653c90>

पीएनजी

आपको वास्तविक भवन के अग्रभाग की छवियों को आर्किटेक्चर लेबल छवियों से अलग करने की आवश्यकता है—जिनमें से सभी का आकार 256 x 256 होगा।

एक फ़ंक्शन को परिभाषित करें जो छवि फ़ाइलों को लोड करता है और दो छवि टेंसर आउटपुट करता है:

def load(image_file):
  # Read and decode an image file to a uint8 tensor
  image = tf.io.read_file(image_file)
  image = tf.io.decode_jpeg(image)

  # Split each image tensor into two tensors:
  # - one with a real building facade image
  # - one with an architecture label image 
  w = tf.shape(image)[1]
  w = w // 2
  input_image = image[:, w:, :]
  real_image = image[:, :w, :]

  # Convert both images to float32 tensors
  input_image = tf.cast(input_image, tf.float32)
  real_image = tf.cast(real_image, tf.float32)

  return input_image, real_image

इनपुट (वास्तुकला लेबल छवि) और वास्तविक (भवन मुखौटा फोटो) छवियों का एक नमूना प्लॉट करें:

inp, re = load(str(PATH / 'train/100.jpg'))
# Casting to int for matplotlib to display the images
plt.figure()
plt.imshow(inp / 255.0)
plt.figure()
plt.imshow(re / 255.0)
<matplotlib.image.AxesImage at 0x7f35981a4910>

पीएनजी

पीएनजी

जैसा कि pix2pix पेपर में वर्णित है, आपको प्रशिक्षण सेट को प्रीप्रोसेस करने के लिए रैंडम जिटरिंग और मिररिंग लागू करने की आवश्यकता है।

कई कार्यों को परिभाषित करें जो:

  1. प्रत्येक 256 x 256 छवि को एक बड़ी ऊंचाई और चौड़ाई में आकार दें- 286 x 286
  2. इसे यादृच्छिक रूप से वापस 256 x 256 पर क्रॉप करें।
  3. छवि को क्षैतिज रूप से फ़्लिप करें अर्थात बाएं से दाएं (यादृच्छिक मिररिंग)।
  4. छवियों को [-1, 1] श्रेणी में सामान्य करें।
# The facade training set consist of 400 images
BUFFER_SIZE = 400
# The batch size of 1 produced better results for the U-Net in the original pix2pix experiment
BATCH_SIZE = 1
# Each image is 256x256 in size
IMG_WIDTH = 256
IMG_HEIGHT = 256
def resize(input_image, real_image, height, width):
  input_image = tf.image.resize(input_image, [height, width],
                                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
  real_image = tf.image.resize(real_image, [height, width],
                               method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

  return input_image, real_image
def random_crop(input_image, real_image):
  stacked_image = tf.stack([input_image, real_image], axis=0)
  cropped_image = tf.image.random_crop(
      stacked_image, size=[2, IMG_HEIGHT, IMG_WIDTH, 3])

  return cropped_image[0], cropped_image[1]
# Normalizing the images to [-1, 1]
def normalize(input_image, real_image):
  input_image = (input_image / 127.5) - 1
  real_image = (real_image / 127.5) - 1

  return input_image, real_image
@tf.function()
def random_jitter(input_image, real_image):
  # Resizing to 286x286
  input_image, real_image = resize(input_image, real_image, 286, 286)

  # Random cropping back to 256x256
  input_image, real_image = random_crop(input_image, real_image)

  if tf.random.uniform(()) > 0.5:
    # Random mirroring
    input_image = tf.image.flip_left_right(input_image)
    real_image = tf.image.flip_left_right(real_image)

  return input_image, real_image

आप कुछ प्रीप्रोसेस्ड आउटपुट का निरीक्षण कर सकते हैं:

plt.figure(figsize=(6, 6))
for i in range(4):
  rj_inp, rj_re = random_jitter(inp, re)
  plt.subplot(2, 2, i + 1)
  plt.imshow(rj_inp / 255.0)
  plt.axis('off')
plt.show()

पीएनजी

यह जाँचने के बाद कि लोडिंग और प्रीप्रोसेसिंग काम करता है, आइए कुछ सहायक कार्यों को परिभाषित करें जो प्रशिक्षण और परीक्षण सेट को लोड और प्रीप्रोसेस करते हैं:

def load_image_train(image_file):
  input_image, real_image = load(image_file)
  input_image, real_image = random_jitter(input_image, real_image)
  input_image, real_image = normalize(input_image, real_image)

  return input_image, real_image
def load_image_test(image_file):
  input_image, real_image = load(image_file)
  input_image, real_image = resize(input_image, real_image,
                                   IMG_HEIGHT, IMG_WIDTH)
  input_image, real_image = normalize(input_image, real_image)

  return input_image, real_image

tf.data के साथ एक इनपुट पाइपलाइन बनाएं

train_dataset = tf.data.Dataset.list_files(str(PATH / 'train/*.jpg'))
train_dataset = train_dataset.map(load_image_train,
                                  num_parallel_calls=tf.data.AUTOTUNE)
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
train_dataset = train_dataset.batch(BATCH_SIZE)
try:
  test_dataset = tf.data.Dataset.list_files(str(PATH / 'test/*.jpg'))
except tf.errors.InvalidArgumentError:
  test_dataset = tf.data.Dataset.list_files(str(PATH / 'val/*.jpg'))
test_dataset = test_dataset.map(load_image_test)
test_dataset = test_dataset.batch(BATCH_SIZE)

जनरेटर का निर्माण

आपके pix2pix cGAN का जनरेटर एक संशोधित U-Net है। यू-नेट में एक एनकोडर (डाउनसैंपलर) और डिकोडर (अपसैंपलर) होते हैं। (आप इसके बारे में इमेज सेगमेंटेशन ट्यूटोरियल और यू-नेट प्रोजेक्ट वेबसाइट पर अधिक जानकारी प्राप्त कर सकते हैं।)

  • एन्कोडर में प्रत्येक ब्लॉक है: कनवल्शन -> बैच सामान्यीकरण -> लीकी ReLU
  • डिकोडर में प्रत्येक ब्लॉक है: ट्रांसपोज़्ड कनवल्शन -> बैच नॉर्मलाइज़ेशन -> ड्रॉपआउट (पहले 3 ब्लॉक्स पर लागू) -> ReLU
  • एन्कोडर और डिकोडर के बीच स्किप कनेक्शन हैं (जैसा कि यू-नेट में है)।

डाउनसैंपलर (एनकोडर) को परिभाषित करें:

OUTPUT_CHANNELS = 3
def downsample(filters, size, apply_batchnorm=True):
  initializer = tf.random_normal_initializer(0., 0.02)

  result = tf.keras.Sequential()
  result.add(
      tf.keras.layers.Conv2D(filters, size, strides=2, padding='same',
                             kernel_initializer=initializer, use_bias=False))

  if apply_batchnorm:
    result.add(tf.keras.layers.BatchNormalization())

  result.add(tf.keras.layers.LeakyReLU())

  return result
down_model = downsample(3, 4)
down_result = down_model(tf.expand_dims(inp, 0))
print (down_result.shape)
(1, 128, 128, 3)

अपसैंपलर (डिकोडर) को परिभाषित करें:

def upsample(filters, size, apply_dropout=False):
  initializer = tf.random_normal_initializer(0., 0.02)

  result = tf.keras.Sequential()
  result.add(
    tf.keras.layers.Conv2DTranspose(filters, size, strides=2,
                                    padding='same',
                                    kernel_initializer=initializer,
                                    use_bias=False))

  result.add(tf.keras.layers.BatchNormalization())

  if apply_dropout:
      result.add(tf.keras.layers.Dropout(0.5))

  result.add(tf.keras.layers.ReLU())

  return result
up_model = upsample(3, 4)
up_result = up_model(down_result)
print (up_result.shape)
(1, 256, 256, 3)

डाउनसैंपलर और अपसैंपलर के साथ जनरेटर को परिभाषित करें:

def Generator():
  inputs = tf.keras.layers.Input(shape=[256, 256, 3])

  down_stack = [
    downsample(64, 4, apply_batchnorm=False),  # (batch_size, 128, 128, 64)
    downsample(128, 4),  # (batch_size, 64, 64, 128)
    downsample(256, 4),  # (batch_size, 32, 32, 256)
    downsample(512, 4),  # (batch_size, 16, 16, 512)
    downsample(512, 4),  # (batch_size, 8, 8, 512)
    downsample(512, 4),  # (batch_size, 4, 4, 512)
    downsample(512, 4),  # (batch_size, 2, 2, 512)
    downsample(512, 4),  # (batch_size, 1, 1, 512)
  ]

  up_stack = [
    upsample(512, 4, apply_dropout=True),  # (batch_size, 2, 2, 1024)
    upsample(512, 4, apply_dropout=True),  # (batch_size, 4, 4, 1024)
    upsample(512, 4, apply_dropout=True),  # (batch_size, 8, 8, 1024)
    upsample(512, 4),  # (batch_size, 16, 16, 1024)
    upsample(256, 4),  # (batch_size, 32, 32, 512)
    upsample(128, 4),  # (batch_size, 64, 64, 256)
    upsample(64, 4),  # (batch_size, 128, 128, 128)
  ]

  initializer = tf.random_normal_initializer(0., 0.02)
  last = tf.keras.layers.Conv2DTranspose(OUTPUT_CHANNELS, 4,
                                         strides=2,
                                         padding='same',
                                         kernel_initializer=initializer,
                                         activation='tanh')  # (batch_size, 256, 256, 3)

  x = inputs

  # Downsampling through the model
  skips = []
  for down in down_stack:
    x = down(x)
    skips.append(x)

  skips = reversed(skips[:-1])

  # Upsampling and establishing the skip connections
  for up, skip in zip(up_stack, skips):
    x = up(x)
    x = tf.keras.layers.Concatenate()([x, skip])

  x = last(x)

  return tf.keras.Model(inputs=inputs, outputs=x)

जनरेटर मॉडल आर्किटेक्चर की कल्पना करें:

generator = Generator()
tf.keras.utils.plot_model(generator, show_shapes=True, dpi=64)

पीएनजी

जनरेटर का परीक्षण करें:

gen_output = generator(inp[tf.newaxis, ...], training=False)
plt.imshow(gen_output[0, ...])
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
<matplotlib.image.AxesImage at 0x7f35cfd20610>

पीएनजी

जनरेटर हानि को परिभाषित करें

GAN एक नुकसान सीखते हैं जो डेटा के अनुकूल होता है, जबकि cGAN एक संरचित नुकसान सीखते हैं जो एक संभावित संरचना को दंडित करता है जो नेटवर्क आउटपुट और लक्ष्य छवि से भिन्न होता है, जैसा कि pix2pix पेपर में वर्णित है।

  • जेनरेटर लॉस उत्पन्न छवियों का एक सिग्मॉइड क्रॉस-एन्ट्रॉपी नुकसान है और लोगों की एक सरणी है
  • Pix2pix पेपर में L1 हानि का भी उल्लेख है, जो उत्पन्न छवि और लक्ष्य छवि के बीच एक MAE (मतलब पूर्ण त्रुटि) है।
  • यह उत्पन्न छवि को लक्ष्य छवि के समान संरचनात्मक रूप से बनने की अनुमति देता है।
  • कुल जनरेटर हानि की गणना करने का सूत्र gan_loss + LAMBDA * l1_loss है, जहाँ LAMBDA = 100 है। यह मूल्य कागज के लेखकों द्वारा तय किया गया था।
LAMBDA = 100
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
def generator_loss(disc_generated_output, gen_output, target):
  gan_loss = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)

  # Mean absolute error
  l1_loss = tf.reduce_mean(tf.abs(target - gen_output))

  total_gen_loss = gan_loss + (LAMBDA * l1_loss)

  return total_gen_loss, gan_loss, l1_loss

जनरेटर के लिए प्रशिक्षण प्रक्रिया इस प्रकार है:

जेनरेटर अद्यतन छवि

विभेदक का निर्माण करें

Pix2pix cGAN में विवेचक एक दृढ़ पैचगैन क्लासिफायरियर है - यह वर्गीकृत करने का प्रयास करता है कि प्रत्येक छवि पैच वास्तविक है या नहीं, जैसा कि pix2pix पेपर में वर्णित है।

  • विवेचक में प्रत्येक ब्लॉक है: कनवल्शन -> बैच नॉर्मलाइज़ेशन -> लीक रेएलयू।
  • अंतिम परत के बाद आउटपुट का आकार (batch_size, 30, 30, 1) है।
  • आउटपुट का प्रत्येक 30 x 30 छवि पैच इनपुट छवि के 70 x 70 भाग को वर्गीकृत करता है।
  • विवेचक को 2 इनपुट प्राप्त होते हैं:
    • इनपुट छवि और लक्ष्य छवि, जिसे इसे वास्तविक के रूप में वर्गीकृत करना चाहिए।
    • इनपुट छवि और उत्पन्न छवि (जनरेटर का आउटपुट), जिसे इसे नकली के रूप में वर्गीकृत करना चाहिए।
    • इन 2 इनपुट को एक साथ जोड़ने के लिए tf.concat([inp, tar], axis=-1) का उपयोग करें।

आइए विवेचक को परिभाषित करें:

def Discriminator():
  initializer = tf.random_normal_initializer(0., 0.02)

  inp = tf.keras.layers.Input(shape=[256, 256, 3], name='input_image')
  tar = tf.keras.layers.Input(shape=[256, 256, 3], name='target_image')

  x = tf.keras.layers.concatenate([inp, tar])  # (batch_size, 256, 256, channels*2)

  down1 = downsample(64, 4, False)(x)  # (batch_size, 128, 128, 64)
  down2 = downsample(128, 4)(down1)  # (batch_size, 64, 64, 128)
  down3 = downsample(256, 4)(down2)  # (batch_size, 32, 32, 256)

  zero_pad1 = tf.keras.layers.ZeroPadding2D()(down3)  # (batch_size, 34, 34, 256)
  conv = tf.keras.layers.Conv2D(512, 4, strides=1,
                                kernel_initializer=initializer,
                                use_bias=False)(zero_pad1)  # (batch_size, 31, 31, 512)

  batchnorm1 = tf.keras.layers.BatchNormalization()(conv)

  leaky_relu = tf.keras.layers.LeakyReLU()(batchnorm1)

  zero_pad2 = tf.keras.layers.ZeroPadding2D()(leaky_relu)  # (batch_size, 33, 33, 512)

  last = tf.keras.layers.Conv2D(1, 4, strides=1,
                                kernel_initializer=initializer)(zero_pad2)  # (batch_size, 30, 30, 1)

  return tf.keras.Model(inputs=[inp, tar], outputs=last)

विभेदक मॉडल वास्तुकला की कल्पना करें:

discriminator = Discriminator()
tf.keras.utils.plot_model(discriminator, show_shapes=True, dpi=64)

पीएनजी

विभेदक का परीक्षण करें:

disc_out = discriminator([inp[tf.newaxis, ...], gen_output], training=False)
plt.imshow(disc_out[0, ..., -1], vmin=-20, vmax=20, cmap='RdBu_r')
plt.colorbar()
<matplotlib.colorbar.Colorbar at 0x7f35cec82c50>

पीएनजी

विभेदक हानि को परिभाषित करें

  • discriminator_loss फ़ंक्शन 2 इनपुट लेता है: वास्तविक छवियां और उत्पन्न छवियां
  • real_loss वास्तविक छवियों का एक सिग्मॉइड क्रॉस-एन्ट्रॉपी नुकसान है और लोगों की एक सरणी है (क्योंकि ये वास्तविक छवियां हैं)
  • जेनरेट_लॉस generated_loss छवियों का एक सिग्मॉइड क्रॉस-एन्ट्रॉपी नुकसान और शून्य की एक सरणी है (क्योंकि ये नकली छवियां हैं)
  • total_loss real_loss और generated_loss का योग है।
def discriminator_loss(disc_real_output, disc_generated_output):
  real_loss = loss_object(tf.ones_like(disc_real_output), disc_real_output)

  generated_loss = loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)

  total_disc_loss = real_loss + generated_loss

  return total_disc_loss

विवेचक के लिए प्रशिक्षण प्रक्रिया नीचे दिखाई गई है।

आर्किटेक्चर और हाइपरपैरामीटर के बारे में अधिक जानने के लिए आप pix2pix पेपर का संदर्भ ले सकते हैं।

विभेदक अद्यतन छवि

ऑप्टिमाइज़र और चेकपॉइंट-सेवर को परिभाषित करें

generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

छवियां उत्पन्न करें

प्रशिक्षण के दौरान कुछ छवियों को प्लॉट करने के लिए एक फ़ंक्शन लिखें।

  • परीक्षण सेट से छवियों को जनरेटर में पास करें।
  • जनरेटर तब इनपुट इमेज को आउटपुट में ट्रांसलेट करेगा।
  • अंतिम चरण भविष्यवाणियों और वॉइला को प्लॉट करना है!
def generate_images(model, test_input, tar):
  prediction = model(test_input, training=True)
  plt.figure(figsize=(15, 15))

  display_list = [test_input[0], tar[0], prediction[0]]
  title = ['Input Image', 'Ground Truth', 'Predicted Image']

  for i in range(3):
    plt.subplot(1, 3, i+1)
    plt.title(title[i])
    # Getting the pixel values in the [0, 1] range to plot.
    plt.imshow(display_list[i] * 0.5 + 0.5)
    plt.axis('off')
  plt.show()

फ़ंक्शन का परीक्षण करें:

for example_input, example_target in test_dataset.take(1):
  generate_images(generator, example_input, example_target)

पीएनजी

प्रशिक्षण

  • प्रत्येक उदाहरण के लिए इनपुट एक आउटपुट उत्पन्न करता है।
  • विवेचक को पहले इनपुट के रूप में input_image और उत्पन्न छवि प्राप्त होती है। दूसरा इनपुट input_image और target_image है।
  • इसके बाद, जनरेटर और विवेचक हानि की गणना करें।
  • फिर, जनरेटर और विवेचक चर (इनपुट) दोनों के संबंध में नुकसान के ग्रेडिएंट की गणना करें और उन्हें ऑप्टिमाइज़र पर लागू करें।
  • अंत में, नुकसान को TensorBoard में दर्ज करें।
log_dir="logs/"

summary_writer = tf.summary.create_file_writer(
  log_dir + "fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
@tf.function
def train_step(input_image, target, step):
  with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
    gen_output = generator(input_image, training=True)

    disc_real_output = discriminator([input_image, target], training=True)
    disc_generated_output = discriminator([input_image, gen_output], training=True)

    gen_total_loss, gen_gan_loss, gen_l1_loss = generator_loss(disc_generated_output, gen_output, target)
    disc_loss = discriminator_loss(disc_real_output, disc_generated_output)

  generator_gradients = gen_tape.gradient(gen_total_loss,
                                          generator.trainable_variables)
  discriminator_gradients = disc_tape.gradient(disc_loss,
                                               discriminator.trainable_variables)

  generator_optimizer.apply_gradients(zip(generator_gradients,
                                          generator.trainable_variables))
  discriminator_optimizer.apply_gradients(zip(discriminator_gradients,
                                              discriminator.trainable_variables))

  with summary_writer.as_default():
    tf.summary.scalar('gen_total_loss', gen_total_loss, step=step//1000)
    tf.summary.scalar('gen_gan_loss', gen_gan_loss, step=step//1000)
    tf.summary.scalar('gen_l1_loss', gen_l1_loss, step=step//1000)
    tf.summary.scalar('disc_loss', disc_loss, step=step//1000)

वास्तविक प्रशिक्षण पाश। चूंकि यह ट्यूटोरियल एक से अधिक डेटासेट पर चल सकता है, और डेटासेट आकार में बहुत भिन्न होते हैं, इसलिए प्रशिक्षण लूप को युगों के बजाय चरणों में काम करने के लिए सेटअप किया जाता है।

  • चरणों की संख्या पर पुनरावृति।
  • हर 10 कदम पर एक बिंदु ( . ) प्रिंट करें।
  • प्रत्येक 1k चरण: प्रदर्शन साफ़ करें और प्रगति दिखाने के लिए generate_images चलाएँ।
  • हर 5k कदम: एक चौकी बचाओ।
def fit(train_ds, test_ds, steps):
  example_input, example_target = next(iter(test_ds.take(1)))
  start = time.time()

  for step, (input_image, target) in train_ds.repeat().take(steps).enumerate():
    if (step) % 1000 == 0:
      display.clear_output(wait=True)

      if step != 0:
        print(f'Time taken for 1000 steps: {time.time()-start:.2f} sec\n')

      start = time.time()

      generate_images(generator, example_input, example_target)
      print(f"Step: {step//1000}k")

    train_step(input_image, target, step)

    # Training step
    if (step+1) % 10 == 0:
      print('.', end='', flush=True)


    # Save (checkpoint) the model every 5k steps
    if (step + 1) % 5000 == 0:
      checkpoint.save(file_prefix=checkpoint_prefix)

यह प्रशिक्षण लूप उन लॉग को सहेजता है जिन्हें आप प्रशिक्षण प्रगति की निगरानी के लिए TensorBoard में देख सकते हैं।

यदि आप स्थानीय मशीन पर काम करते हैं, तो आप एक अलग TensorBoard प्रक्रिया शुरू करेंगे। नोटबुक में काम करते समय, TensorBoard के साथ निगरानी के लिए प्रशिक्षण शुरू करने से पहले व्यूअर को लॉन्च करें।

व्यूअर को लॉन्च करने के लिए निम्नलिखित को कोड-सेल में पेस्ट करें:

%load_ext tensorboard
%tensorboard --logdir {log_dir}

अंत में, प्रशिक्षण लूप चलाएँ:

fit(train_dataset, test_dataset, steps=40000)
Time taken for 1000 steps: 36.53 sec

पीएनजी

Step: 39k
....................................................................................................

यदि आप TensorBoard परिणामों को सार्वजनिक रूप से साझा करना चाहते हैं, तो आप निम्न को कोड-सेल में कॉपी करके TensorBoard.dev पर लॉग अपलोड कर सकते हैं।

tensorboard dev upload --logdir {log_dir}

आप इस नोटबुक के पिछले रन के परिणाम TensorBoard.dev पर देख सकते हैं।

TensorBoard.dev सभी के साथ ML प्रयोगों को होस्ट करने, ट्रैक करने और साझा करने का एक प्रबंधित अनुभव है।

इसमें <iframe> का उपयोग करके इनलाइन भी शामिल किया जा सकता है:

display.IFrame(
    src="https://tensorboard.dev/experiment/lZ0C6FONROaUMfjYkVyJqw",
    width="100%",
    height="1000px")

एक साधारण वर्गीकरण या प्रतिगमन मॉडल की तुलना में एक GAN (या एक cGAN जैसे pix2pix) को प्रशिक्षित करते समय लॉग की व्याख्या करना अधिक सूक्ष्म होता है। देखने के लिए चीजें:

  • जांचें कि न तो जनरेटर और न ही विवेचक मॉडल "जीता" है। यदि या तो gen_gan_loss या disc_loss बहुत कम हो जाता है, तो यह एक संकेतक है कि यह मॉडल दूसरे पर हावी है, और आप संयुक्त मॉडल को सफलतापूर्वक प्रशिक्षण नहीं दे रहे हैं।
  • मान log(2) = 0.69 इन नुकसानों के लिए एक अच्छा संदर्भ बिंदु है, क्योंकि यह 2 की जटिलता को इंगित करता है - विवेचक, औसतन, दो विकल्पों के बारे में समान रूप से अनिश्चित है।
  • disc_loss के लिए, 0.69 से नीचे के मान का मतलब है कि विवेचक वास्तविक और उत्पन्न छवियों के संयुक्त सेट पर यादृच्छिक से बेहतर कर रहा है।
  • gen_gan_loss के लिए, 0.69 से नीचे के मान का अर्थ है कि जनरेटर विवेचक को बेवकूफ बनाने में यादृच्छिक से बेहतर कर रहा है।
  • जैसे-जैसे प्रशिक्षण आगे बढ़ता है, gen_l1_loss होना चाहिए।

नवीनतम चेकपॉइंट को पुनर्स्थापित करें और नेटवर्क का परीक्षण करें

ls {checkpoint_dir}
checkpoint          ckpt-5.data-00000-of-00001
ckpt-1.data-00000-of-00001  ckpt-5.index
ckpt-1.index            ckpt-6.data-00000-of-00001
ckpt-2.data-00000-of-00001  ckpt-6.index
ckpt-2.index            ckpt-7.data-00000-of-00001
ckpt-3.data-00000-of-00001  ckpt-7.index
ckpt-3.index            ckpt-8.data-00000-of-00001
ckpt-4.data-00000-of-00001  ckpt-8.index
ckpt-4.index
# Restoring the latest checkpoint in checkpoint_dir
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f35cfd6b8d0>

परीक्षण सेट का उपयोग करके कुछ छवियां बनाएं

# Run the trained model on a few examples from the test set
for inp, tar in test_dataset.take(5):
  generate_images(generator, inp, tar)

पीएनजी

पीएनजी

पीएनजी

पीएनजी

पीएनजी