Cette page a été traduite par l'API Cloud Translation.
Switch to English

Pix2Pix

Voir sur TensorFlow.org Exécuter dans Google Colab Afficher la source sur GitHub Télécharger le cahier

Ce cahier montre la traduction d'image en image à l'aide de GAN conditionnels, comme décrit dans Traduction d'image à image avec des réseaux contradictoires conditionnels . En utilisant cette technique, nous pouvons coloriser des photos en noir et blanc, convertir des cartes google en google earth, etc. Ici, nous convertissons les façades de bâtiments en bâtiments réels.

Par exemple, nous utiliserons la base de données des façades CMP , fournie par le Center for Machine Perception de l' Université technique tchèque de Prague . Pour garder notre exemple court, nous utiliserons une copie prétraitée de cet ensemble de données, créée par les auteurs de l' article ci-dessus.

Chaque époque prend environ 15 secondes sur un seul GPU V100.

Vous trouverez ci-dessous la sortie générée après l'entraînement du modèle pour 200 époques.

exemple de sortie_1échantillon output_2

Importer TensorFlow et d'autres bibliothèques

 import tensorflow as tf

import os
import time

from matplotlib import pyplot as plt
from IPython import display
 
pip install -q -U tensorboard
ERROR: tensorflow 2.2.0 has requirement tensorboard<2.3.0,>=2.2.0, but you'll have tensorboard 2.3.0 which is incompatible.

Charger le jeu de données

Vous pouvez télécharger cet ensemble de données et des ensembles de données similaires à partir d' ici . Comme mentionné dans l' article, nous appliquons une gigue et une mise en miroir aléatoires à l'ensemble de données d'entraînement.

  • En cas de gigue aléatoire, l'image est redimensionnée à 286 x 286 , puis recadrée au hasard à 256 x 256
  • Dans la mise en miroir aléatoire, l'image est retournée de manière aléatoire horizontalement, c'est-à-dire de gauche à droite.
 _URL = 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/facades.tar.gz'

path_to_zip = tf.keras.utils.get_file('facades.tar.gz',
                                      origin=_URL,
                                      extract=True)

PATH = os.path.join(os.path.dirname(path_to_zip), 'facades/')
 
Downloading data from https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/facades.tar.gz
30171136/30168306 [==============================] - 2s 0us/step

 BUFFER_SIZE = 400
BATCH_SIZE = 1
IMG_WIDTH = 256
IMG_HEIGHT = 256
 
 def load(image_file):
  image = tf.io.read_file(image_file)
  image = tf.image.decode_jpeg(image)

  w = tf.shape(image)[1]

  w = w // 2
  real_image = image[:, :w, :]
  input_image = image[:, w:, :]

  input_image = tf.cast(input_image, tf.float32)
  real_image = tf.cast(real_image, tf.float32)

  return input_image, real_image
 
 inp, re = load(PATH+'train/100.jpg')
# casting to int for matplotlib to show the image
plt.figure()
plt.imshow(inp/255.0)
plt.figure()
plt.imshow(re/255.0)
 
<matplotlib.image.AxesImage at 0x7f9b575231d0>

png

png

 def resize(input_image, real_image, height, width):
  input_image = tf.image.resize(input_image, [height, width],
                                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
  real_image = tf.image.resize(real_image, [height, width],
                               method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

  return input_image, real_image
 
 def random_crop(input_image, real_image):
  stacked_image = tf.stack([input_image, real_image], axis=0)
  cropped_image = tf.image.random_crop(
      stacked_image, size=[2, IMG_HEIGHT, IMG_WIDTH, 3])

  return cropped_image[0], cropped_image[1]
 
 # normalizing the images to [-1, 1]

def normalize(input_image, real_image):
  input_image = (input_image / 127.5) - 1
  real_image = (real_image / 127.5) - 1

  return input_image, real_image
 
 @tf.function()
def random_jitter(input_image, real_image):
  # resizing to 286 x 286 x 3
  input_image, real_image = resize(input_image, real_image, 286, 286)

  # randomly cropping to 256 x 256 x 3
  input_image, real_image = random_crop(input_image, real_image)

  if tf.random.uniform(()) > 0.5:
    # random mirroring
    input_image = tf.image.flip_left_right(input_image)
    real_image = tf.image.flip_left_right(real_image)

  return input_image, real_image
 

Comme vous pouvez le voir dans les images ci-dessous, ils subissent une gigue aléatoire La gigue aléatoire comme décrit dans l'article est de

  1. Redimensionner une image à une plus grande hauteur et largeur
  2. Recadrer au hasard à la taille cible
  3. Retourner l'image horizontalement au hasard
 plt.figure(figsize=(6, 6))
for i in range(4):
  rj_inp, rj_re = random_jitter(inp, re)
  plt.subplot(2, 2, i+1)
  plt.imshow(rj_inp/255.0)
  plt.axis('off')
plt.show()
 

png

 def load_image_train(image_file):
  input_image, real_image = load(image_file)
  input_image, real_image = random_jitter(input_image, real_image)
  input_image, real_image = normalize(input_image, real_image)

  return input_image, real_image
 
 def load_image_test(image_file):
  input_image, real_image = load(image_file)
  input_image, real_image = resize(input_image, real_image,
                                   IMG_HEIGHT, IMG_WIDTH)
  input_image, real_image = normalize(input_image, real_image)

  return input_image, real_image
 

Pipeline d'entrée

 train_dataset = tf.data.Dataset.list_files(PATH+'train/*.jpg')
train_dataset = train_dataset.map(load_image_train,
                                  num_parallel_calls=tf.data.experimental.AUTOTUNE)
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
train_dataset = train_dataset.batch(BATCH_SIZE)
 
 test_dataset = tf.data.Dataset.list_files(PATH+'test/*.jpg')
test_dataset = test_dataset.map(load_image_test)
test_dataset = test_dataset.batch(BATCH_SIZE)
 

Construisez le générateur

  • L'architecture du générateur est un U-Net modifié.
  • Chaque bloc de l'encodeur est (Conv -> Batchnorm -> Leaky ReLU)
  • Chaque bloc du décodeur est (Transposed Conv -> Batchnorm -> Dropout (appliqué aux 3 premiers blocs) -> ReLU)
  • Il existe des connexions de saut entre le codeur et le décodeur (comme dans U-Net).
 OUTPUT_CHANNELS = 3
 
 def downsample(filters, size, apply_batchnorm=True):
  initializer = tf.random_normal_initializer(0., 0.02)

  result = tf.keras.Sequential()
  result.add(
      tf.keras.layers.Conv2D(filters, size, strides=2, padding='same',
                             kernel_initializer=initializer, use_bias=False))

  if apply_batchnorm:
    result.add(tf.keras.layers.BatchNormalization())

  result.add(tf.keras.layers.LeakyReLU())

  return result
 
 down_model = downsample(3, 4)
down_result = down_model(tf.expand_dims(inp, 0))
print (down_result.shape)
 
(1, 128, 128, 3)

 def upsample(filters, size, apply_dropout=False):
  initializer = tf.random_normal_initializer(0., 0.02)

  result = tf.keras.Sequential()
  result.add(
    tf.keras.layers.Conv2DTranspose(filters, size, strides=2,
                                    padding='same',
                                    kernel_initializer=initializer,
                                    use_bias=False))

  result.add(tf.keras.layers.BatchNormalization())

  if apply_dropout:
      result.add(tf.keras.layers.Dropout(0.5))

  result.add(tf.keras.layers.ReLU())

  return result
 
 up_model = upsample(3, 4)
up_result = up_model(down_result)
print (up_result.shape)
 
(1, 256, 256, 3)

 def Generator():
  inputs = tf.keras.layers.Input(shape=[256,256,3])

  down_stack = [
    downsample(64, 4, apply_batchnorm=False), # (bs, 128, 128, 64)
    downsample(128, 4), # (bs, 64, 64, 128)
    downsample(256, 4), # (bs, 32, 32, 256)
    downsample(512, 4), # (bs, 16, 16, 512)
    downsample(512, 4), # (bs, 8, 8, 512)
    downsample(512, 4), # (bs, 4, 4, 512)
    downsample(512, 4), # (bs, 2, 2, 512)
    downsample(512, 4), # (bs, 1, 1, 512)
  ]

  up_stack = [
    upsample(512, 4, apply_dropout=True), # (bs, 2, 2, 1024)
    upsample(512, 4, apply_dropout=True), # (bs, 4, 4, 1024)
    upsample(512, 4, apply_dropout=True), # (bs, 8, 8, 1024)
    upsample(512, 4), # (bs, 16, 16, 1024)
    upsample(256, 4), # (bs, 32, 32, 512)
    upsample(128, 4), # (bs, 64, 64, 256)
    upsample(64, 4), # (bs, 128, 128, 128)
  ]

  initializer = tf.random_normal_initializer(0., 0.02)
  last = tf.keras.layers.Conv2DTranspose(OUTPUT_CHANNELS, 4,
                                         strides=2,
                                         padding='same',
                                         kernel_initializer=initializer,
                                         activation='tanh') # (bs, 256, 256, 3)

  x = inputs

  # Downsampling through the model
  skips = []
  for down in down_stack:
    x = down(x)
    skips.append(x)

  skips = reversed(skips[:-1])

  # Upsampling and establishing the skip connections
  for up, skip in zip(up_stack, skips):
    x = up(x)
    x = tf.keras.layers.Concatenate()([x, skip])

  x = last(x)

  return tf.keras.Model(inputs=inputs, outputs=x)
 
 generator = Generator()
tf.keras.utils.plot_model(generator, show_shapes=True, dpi=64)
 

png

 gen_output = generator(inp[tf.newaxis,...], training=False)
plt.imshow(gen_output[0,...])
 
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

<matplotlib.image.AxesImage at 0x7f9adf885940>

png

  • Perte du générateur
    • Il s'agit d'une perte d'entropie croisée sigmoïde des images générées et d'un tableau de celles-ci .
    • Le papier comprend également une perte L1 qui est MAE (erreur absolue moyenne) entre l'image générée et l'image cible.
    • Cela permet à l'image générée de devenir structurellement similaire à l'image cible.
    • La formule pour calculer la perte totale du générateur = gan_loss + LAMBDA * l1_loss, où LAMBDA = 100. Cette valeur a été décidée par les auteurs de l' article .

La procédure de formation pour le générateur est illustrée ci-dessous:

 LAMBDA = 100
 
 def generator_loss(disc_generated_output, gen_output, target):
  gan_loss = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)

  # mean absolute error
  l1_loss = tf.reduce_mean(tf.abs(target - gen_output))

  total_gen_loss = gan_loss + (LAMBDA * l1_loss)

  return total_gen_loss, gan_loss, l1_loss
 

Image de mise à jour du générateur

Construire le discriminateur

  • Le discriminateur est un PatchGAN.
  • Chaque bloc dans le discriminateur est (Conv -> BatchNorm -> Leaky ReLU)
  • La forme de la sortie après la dernière couche est (batch_size, 30, 30, 1)
  • Chaque patch 30x30 de la sortie classe une partie 70x70 de l'image d'entrée (une telle architecture est appelée PatchGAN).
  • Le discriminateur reçoit 2 entrées.
    • Image d'entrée et image cible, qu'elle doit classer comme réelle.
    • Image d'entrée et image générée (sortie du générateur), qu'elle doit classer comme fausse.
    • Nous concaténons ces 2 entrées dans le code ( tf.concat([inp, tar], axis=-1) )
 def Discriminator():
  initializer = tf.random_normal_initializer(0., 0.02)

  inp = tf.keras.layers.Input(shape=[256, 256, 3], name='input_image')
  tar = tf.keras.layers.Input(shape=[256, 256, 3], name='target_image')

  x = tf.keras.layers.concatenate([inp, tar]) # (bs, 256, 256, channels*2)

  down1 = downsample(64, 4, False)(x) # (bs, 128, 128, 64)
  down2 = downsample(128, 4)(down1) # (bs, 64, 64, 128)
  down3 = downsample(256, 4)(down2) # (bs, 32, 32, 256)

  zero_pad1 = tf.keras.layers.ZeroPadding2D()(down3) # (bs, 34, 34, 256)
  conv = tf.keras.layers.Conv2D(512, 4, strides=1,
                                kernel_initializer=initializer,
                                use_bias=False)(zero_pad1) # (bs, 31, 31, 512)

  batchnorm1 = tf.keras.layers.BatchNormalization()(conv)

  leaky_relu = tf.keras.layers.LeakyReLU()(batchnorm1)

  zero_pad2 = tf.keras.layers.ZeroPadding2D()(leaky_relu) # (bs, 33, 33, 512)

  last = tf.keras.layers.Conv2D(1, 4, strides=1,
                                kernel_initializer=initializer)(zero_pad2) # (bs, 30, 30, 1)

  return tf.keras.Model(inputs=[inp, tar], outputs=last)
 
 discriminator = Discriminator()
tf.keras.utils.plot_model(discriminator, show_shapes=True, dpi=64)
 

png

 disc_out = discriminator([inp[tf.newaxis,...], gen_output], training=False)
plt.imshow(disc_out[0,...,-1], vmin=-20, vmax=20, cmap='RdBu_r')
plt.colorbar()
 
<matplotlib.colorbar.Colorbar at 0x7f9a6c17bcc0>

png

Perte du discriminateur

  • La fonction de perte du discriminateur prend 2 entrées; images réelles, images générées
  • real_loss est une perte d'entropie croisée sigmoïde des images réelles et un tableau de celles-ci (puisque ce sont les images réelles)
  • generated_loss est une perte d'entropie croisée sigmoïde des images générées et un tableau de zéros (puisque ce sont les fausses images)
  • Alors le total_loss est la somme de real_loss et du generated_loss
 loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
 
 def discriminator_loss(disc_real_output, disc_generated_output):
  real_loss = loss_object(tf.ones_like(disc_real_output), disc_real_output)

  generated_loss = loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)

  total_disc_loss = real_loss + generated_loss

  return total_disc_loss
 

La procédure d'apprentissage du discriminateur est illustrée ci-dessous.

Pour en savoir plus sur l'architecture et les hyperparamètres, vous pouvez consulter l' article .

Image de mise à jour du discriminateur

Définir les optimiseurs et l'économiseur de points de contrôle

 generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
 
 checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)
 

Générer des images

Ecrivez une fonction pour tracer des images pendant la formation.

  • Nous transmettons les images du jeu de données de test au générateur.
  • Le générateur traduira alors l'image d'entrée en sortie.
  • La dernière étape consiste à tracer les prédictions et le tour est joué!
 def generate_images(model, test_input, tar):
  prediction = model(test_input, training=True)
  plt.figure(figsize=(15,15))

  display_list = [test_input[0], tar[0], prediction[0]]
  title = ['Input Image', 'Ground Truth', 'Predicted Image']

  for i in range(3):
    plt.subplot(1, 3, i+1)
    plt.title(title[i])
    # getting the pixel values between [0, 1] to plot it.
    plt.imshow(display_list[i] * 0.5 + 0.5)
    plt.axis('off')
  plt.show()
 
 for example_input, example_target in test_dataset.take(1):
  generate_images(generator, example_input, example_target)
 

png

Formation

  • Pour chaque exemple d'entrée, générez une sortie.
  • Le discriminateur reçoit l'image d'entrée et l'image générée comme première entrée. La deuxième entrée est input_image et target_image.
  • Ensuite, nous calculons le générateur et la perte du discriminateur.
  • Ensuite, nous calculons les gradients de perte par rapport au générateur et aux variables du discriminateur (entrées) et les appliquons à l'optimiseur.
  • Enregistrez ensuite les pertes dans TensorBoard.
 EPOCHS = 150
 
 import datetime
log_dir="logs/"

summary_writer = tf.summary.create_file_writer(
  log_dir + "fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
 
 @tf.function
def train_step(input_image, target, epoch):
  with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
    gen_output = generator(input_image, training=True)

    disc_real_output = discriminator([input_image, target], training=True)
    disc_generated_output = discriminator([input_image, gen_output], training=True)

    gen_total_loss, gen_gan_loss, gen_l1_loss = generator_loss(disc_generated_output, gen_output, target)
    disc_loss = discriminator_loss(disc_real_output, disc_generated_output)

  generator_gradients = gen_tape.gradient(gen_total_loss,
                                          generator.trainable_variables)
  discriminator_gradients = disc_tape.gradient(disc_loss,
                                               discriminator.trainable_variables)

  generator_optimizer.apply_gradients(zip(generator_gradients,
                                          generator.trainable_variables))
  discriminator_optimizer.apply_gradients(zip(discriminator_gradients,
                                              discriminator.trainable_variables))

  with summary_writer.as_default():
    tf.summary.scalar('gen_total_loss', gen_total_loss, step=epoch)
    tf.summary.scalar('gen_gan_loss', gen_gan_loss, step=epoch)
    tf.summary.scalar('gen_l1_loss', gen_l1_loss, step=epoch)
    tf.summary.scalar('disc_loss', disc_loss, step=epoch)
 

La boucle d'entraînement réelle:

  • Itère sur le nombre d'époques.
  • À chaque époque, il efface l'affichage et exécute generate_images pour montrer sa progression.
  • À chaque époque, il itère sur l'ensemble de données d'entraînement, imprimant un «.» pour chaque exemple.
  • Il enregistre un point de contrôle toutes les 20 époques.
 def fit(train_ds, epochs, test_ds):
  for epoch in range(epochs):
    start = time.time()

    display.clear_output(wait=True)

    for example_input, example_target in test_ds.take(1):
      generate_images(generator, example_input, example_target)
    print("Epoch: ", epoch)

    # Train
    for n, (input_image, target) in train_ds.enumerate():
      print('.', end='')
      if (n+1) % 100 == 0:
        print()
      train_step(input_image, target, epoch)
    print()

    # saving (checkpoint) the model every 20 epochs
    if (epoch + 1) % 20 == 0:
      checkpoint.save(file_prefix = checkpoint_prefix)

    print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
                                                        time.time()-start))
  checkpoint.save(file_prefix = checkpoint_prefix)
 

Cette boucle d'entraînement enregistre les journaux que vous pouvez facilement afficher dans TensorBoard pour surveiller la progression de l'entraînement. En travaillant localement, vous lanceriez un processus de tensorboard séparé. Dans un ordinateur portable, si vous souhaitez surveiller avec TensorBoard, il est plus simple de lancer le visualiseur avant de commencer la formation.

Pour lancer la visionneuse, collez ce qui suit dans une cellule de code:

 %load_ext tensorboard
%tensorboard --logdir {log_dir}
 

Exécutez maintenant la boucle d'entraînement:

 fit(train_dataset, EPOCHS, test_dataset)
 

png

Epoch:  125
....................................................................................................
.......................................................................................

Si vous souhaitez partager publiquement les résultats de TensorBoard, vous pouvez télécharger les journaux sur TensorBoard.dev en copiant ce qui suit dans une cellule de code.

tensorboard dev upload --logdir  {log_dir}

Vous pouvez afficher les résultats d'une précédente exécution de ce notebook sur TensorBoard.dev .

TensorBoard.dev est une expérience gérée pour l'hébergement, le suivi et le partage d'expériences ML avec tout le monde.

Il peut également être inclus en ligne à l'aide d'un <iframe> :

 display.IFrame(
    src="https://tensorboard.dev/experiment/lZ0C6FONROaUMfjYkVyJqw",
    width="100%",
    height="1000px")
 

L'interprétation des logs à partir d'un GAN est plus subtile qu'un simple modèle de classification ou de régression. Choses à rechercher:

  • Vérifiez qu'aucun des modèles n'a «gagné». Si le gen_gan_loss ou le disc_loss devient très bas, c'est un indicateur que ce modèle domine l'autre et que vous disc_loss pas avec succès le modèle combiné.
  • La valeur log(2) = 0.69 est un bon point de référence pour ces pertes, car elle indique une perplexité de 2: que le discriminateur est en moyenne également incertain sur les deux options.
  • Pour le disc_loss une valeur inférieure à 0.69 signifie que le discriminateur fait mieux que l'aléatoire, sur l'ensemble combiné d'images réelles + générées.
  • Pour le gen_gan_loss une valeur inférieure à 0.69 signifie que le générateur i fait mieux que le hasard pour déjouer le descriminateur.
  • Au fur et à mesure que l'entraînement progresse, la gen_l1_loss devrait diminuer.

Restaurez le dernier point de contrôle et testez

ls {checkpoint_dir}
checkpoint          ckpt-5.data-00000-of-00002
ckpt-1.data-00000-of-00002  ckpt-5.data-00001-of-00002
ckpt-1.data-00001-of-00002  ckpt-5.index
ckpt-1.index            ckpt-6.data-00000-of-00002
ckpt-2.data-00000-of-00002  ckpt-6.data-00001-of-00002
ckpt-2.data-00001-of-00002  ckpt-6.index
ckpt-2.index            ckpt-7.data-00000-of-00002
ckpt-3.data-00000-of-00002  ckpt-7.data-00001-of-00002
ckpt-3.data-00001-of-00002  ckpt-7.index
ckpt-3.index            ckpt-8.data-00000-of-00002
ckpt-4.data-00000-of-00002  ckpt-8.data-00001-of-00002
ckpt-4.data-00001-of-00002  ckpt-8.index
ckpt-4.index

 # restoring the latest checkpoint in checkpoint_dir
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
 
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f9a6c01e978>

Générer à l'aide d'un jeu de données de test

 # Run the trained model on a few examples from the test dataset
for inp, tar in test_dataset.take(5):
  generate_images(generator, inp, tar)
 

png

png

png

png

png