Sehen Sie sich Keynotes, Produktsitzungen, Workshops und mehr in Google I / O an. Siehe Wiedergabeliste

Pix2Pix

Ansicht auf TensorFlow.org In Google Colab ausführen Quelle auf GitHub anzeigen Notizbuch herunterladen

Dieses Notizbuch demonstriert die Bild-zu-Bild-Übersetzung unter Verwendung von bedingten GANs, wie unter Bild-zu-Bild-Übersetzung mit bedingten kontradiktorischen Netzwerken beschrieben . Mit dieser Technik können Sie Schwarzweißfotos einfärben, Google Maps in Google Earth konvertieren usw. Hier können Sie Gebäudefassaden in echte Gebäude konvertieren.

Sie verwenden beispielsweise die CMP-Fassadendatenbank , die vom Zentrum für Maschinenwahrnehmung der Tschechischen Technischen Universität in Prag bereitgestellt wird. Um das Beispiel kurz zu halten, verwenden Sie eine vorverarbeiteten Kopie des Datensatzes, die von den Autoren des erzeugten Papiers oben.

Jede Epoche dauert auf einer einzelnen V100-GPU etwa 15 Sekunden.

Nachfolgend finden Sie die Ausgabe, die nach dem Training des Modells für 200 Epochen generiert wurde.

Beispielausgabe_1Beispielausgabe_2

Importieren Sie TensorFlow und andere Bibliotheken

import tensorflow as tf

import os
import time

from matplotlib import pyplot as plt
from IPython import display
pip install -q -U tensorboard

Laden Sie den Datensatz

Sie können diesen Datensatz und ähnliche Datensätze hier herunterladen. Wie in dem genannten Papier gelten zufällige Flackern und Spiegelung auf den Trainingsdaten.

  • Bei zufälligem Jitter wird die Bildgröße auf 286 x 286 und dann zufällig auf 256 x 256
  • Bei der zufälligen Spiegelung wird das Bild zufällig horizontal gespiegelt, dh von links nach rechts.
_URL = 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/facades.tar.gz'

path_to_zip = tf.keras.utils.get_file('facades.tar.gz',
                                      origin=_URL,
                                      extract=True)

PATH = os.path.join(os.path.dirname(path_to_zip), 'facades/')
Downloading data from https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/facades.tar.gz
30171136/30168306 [==============================] - 10s 0us/step
BUFFER_SIZE = 400
BATCH_SIZE = 1
IMG_WIDTH = 256
IMG_HEIGHT = 256
def load(image_file):
  image = tf.io.read_file(image_file)
  image = tf.image.decode_jpeg(image)

  w = tf.shape(image)[1]

  w = w // 2
  real_image = image[:, :w, :]
  input_image = image[:, w:, :]

  input_image = tf.cast(input_image, tf.float32)
  real_image = tf.cast(real_image, tf.float32)

  return input_image, real_image
inp, re = load(PATH+'train/100.jpg')
# casting to int for matplotlib to show the image
plt.figure()
plt.imshow(inp/255.0)
plt.figure()
plt.imshow(re/255.0)
<matplotlib.image.AxesImage at 0x7f194823a7d0>

png

png

def resize(input_image, real_image, height, width):
  input_image = tf.image.resize(input_image, [height, width],
                                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
  real_image = tf.image.resize(real_image, [height, width],
                               method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

  return input_image, real_image
def random_crop(input_image, real_image):
  stacked_image = tf.stack([input_image, real_image], axis=0)
  cropped_image = tf.image.random_crop(
      stacked_image, size=[2, IMG_HEIGHT, IMG_WIDTH, 3])

  return cropped_image[0], cropped_image[1]
# normalizing the images to [-1, 1]

def normalize(input_image, real_image):
  input_image = (input_image / 127.5) - 1
  real_image = (real_image / 127.5) - 1

  return input_image, real_image
@tf.function()
def random_jitter(input_image, real_image):
  # resizing to 286 x 286 x 3
  input_image, real_image = resize(input_image, real_image, 286, 286)

  # randomly cropping to 256 x 256 x 3
  input_image, real_image = random_crop(input_image, real_image)

  if tf.random.uniform(()) > 0.5:
    # random mirroring
    input_image = tf.image.flip_left_right(input_image)
    real_image = tf.image.flip_left_right(real_image)

  return input_image, real_image

Wie Sie in den Bildern unten sehen können, durchlaufen sie zufälliges Jittering Zufälliges Jittering, wie in der Veröffentlichung beschrieben

  1. Ändern Sie die Größe eines Bildes auf eine größere Höhe und Breite
  2. Nach dem Zufallsprinzip auf die Zielgröße zuschneiden
  3. Drehen Sie das Bild nach dem Zufallsprinzip horizontal
plt.figure(figsize=(6, 6))
for i in range(4):
  rj_inp, rj_re = random_jitter(inp, re)
  plt.subplot(2, 2, i+1)
  plt.imshow(rj_inp/255.0)
  plt.axis('off')
plt.show()

png

def load_image_train(image_file):
  input_image, real_image = load(image_file)
  input_image, real_image = random_jitter(input_image, real_image)
  input_image, real_image = normalize(input_image, real_image)

  return input_image, real_image
def load_image_test(image_file):
  input_image, real_image = load(image_file)
  input_image, real_image = resize(input_image, real_image,
                                   IMG_HEIGHT, IMG_WIDTH)
  input_image, real_image = normalize(input_image, real_image)

  return input_image, real_image

Eingabe-Pipeline

train_dataset = tf.data.Dataset.list_files(PATH+'train/*.jpg')
train_dataset = train_dataset.map(load_image_train,
                                  num_parallel_calls=tf.data.AUTOTUNE)
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
train_dataset = train_dataset.batch(BATCH_SIZE)
test_dataset = tf.data.Dataset.list_files(PATH+'test/*.jpg')
test_dataset = test_dataset.map(load_image_test)
test_dataset = test_dataset.batch(BATCH_SIZE)

Erstellen Sie den Generator

  • Die Architektur des Generators ist ein modifiziertes U-Net.
  • Jeder Block im Encoder ist (Conv -> Batchnorm -> Leaky ReLU)
  • Jeder Block im Decoder ist (Transposed Conv -> Batchnorm -> Dropout (angewendet auf die ersten 3 Blöcke) -> ReLU)
  • Es gibt Sprungverbindungen zwischen dem Codierer und dem Decodierer (wie in U-Net).
OUTPUT_CHANNELS = 3
def downsample(filters, size, apply_batchnorm=True):
  initializer = tf.random_normal_initializer(0., 0.02)

  result = tf.keras.Sequential()
  result.add(
      tf.keras.layers.Conv2D(filters, size, strides=2, padding='same',
                             kernel_initializer=initializer, use_bias=False))

  if apply_batchnorm:
    result.add(tf.keras.layers.BatchNormalization())

  result.add(tf.keras.layers.LeakyReLU())

  return result
down_model = downsample(3, 4)
down_result = down_model(tf.expand_dims(inp, 0))
print (down_result.shape)
(1, 128, 128, 3)
def upsample(filters, size, apply_dropout=False):
  initializer = tf.random_normal_initializer(0., 0.02)

  result = tf.keras.Sequential()
  result.add(
    tf.keras.layers.Conv2DTranspose(filters, size, strides=2,
                                    padding='same',
                                    kernel_initializer=initializer,
                                    use_bias=False))

  result.add(tf.keras.layers.BatchNormalization())

  if apply_dropout:
      result.add(tf.keras.layers.Dropout(0.5))

  result.add(tf.keras.layers.ReLU())

  return result
up_model = upsample(3, 4)
up_result = up_model(down_result)
print (up_result.shape)
(1, 256, 256, 3)
def Generator():
  inputs = tf.keras.layers.Input(shape=[256, 256, 3])

  down_stack = [
    downsample(64, 4, apply_batchnorm=False),  # (bs, 128, 128, 64)
    downsample(128, 4),  # (bs, 64, 64, 128)
    downsample(256, 4),  # (bs, 32, 32, 256)
    downsample(512, 4),  # (bs, 16, 16, 512)
    downsample(512, 4),  # (bs, 8, 8, 512)
    downsample(512, 4),  # (bs, 4, 4, 512)
    downsample(512, 4),  # (bs, 2, 2, 512)
    downsample(512, 4),  # (bs, 1, 1, 512)
  ]

  up_stack = [
    upsample(512, 4, apply_dropout=True),  # (bs, 2, 2, 1024)
    upsample(512, 4, apply_dropout=True),  # (bs, 4, 4, 1024)
    upsample(512, 4, apply_dropout=True),  # (bs, 8, 8, 1024)
    upsample(512, 4),  # (bs, 16, 16, 1024)
    upsample(256, 4),  # (bs, 32, 32, 512)
    upsample(128, 4),  # (bs, 64, 64, 256)
    upsample(64, 4),  # (bs, 128, 128, 128)
  ]

  initializer = tf.random_normal_initializer(0., 0.02)
  last = tf.keras.layers.Conv2DTranspose(OUTPUT_CHANNELS, 4,
                                         strides=2,
                                         padding='same',
                                         kernel_initializer=initializer,
                                         activation='tanh')  # (bs, 256, 256, 3)

  x = inputs

  # Downsampling through the model
  skips = []
  for down in down_stack:
    x = down(x)
    skips.append(x)

  skips = reversed(skips[:-1])

  # Upsampling and establishing the skip connections
  for up, skip in zip(up_stack, skips):
    x = up(x)
    x = tf.keras.layers.Concatenate()([x, skip])

  x = last(x)

  return tf.keras.Model(inputs=inputs, outputs=x)
generator = Generator()
tf.keras.utils.plot_model(generator, show_shapes=True, dpi=64)

png

gen_output = generator(inp[tf.newaxis, ...], training=False)
plt.imshow(gen_output[0, ...])
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
<matplotlib.image.AxesImage at 0x7f189c184e10>

png

  • Generatorverlust
    • Es ist ein Sigmoid-Kreuzentropieverlust der erzeugten Bilder und einer Reihe von Bildern.
    • Das Papier enthält auch einen L1-Verlust, der MAE (mittlerer absoluter Fehler) zwischen dem erzeugten Bild und dem Zielbild ist.
    • Dadurch kann das erzeugte Bild dem Zielbild strukturell ähnlich werden.
    • Die Formel zur Berechnung des Gesamtgeneratorverlusts = gan_loss + LAMBDA * l1_loss, wobei LAMBDA = 100. Dieser Wert wurde von den Autoren des Papiers festgelegt .

Das Trainingsverfahren für den Generator ist unten dargestellt:

LAMBDA = 100
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
def generator_loss(disc_generated_output, gen_output, target):
  gan_loss = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)

  # mean absolute error
  l1_loss = tf.reduce_mean(tf.abs(target - gen_output))

  total_gen_loss = gan_loss + (LAMBDA * l1_loss)

  return total_gen_loss, gan_loss, l1_loss

Generator Update Image

Bauen Sie den Diskriminator

  • Der Diskriminator ist ein PatchGAN.
  • Jeder Block im Diskriminator ist (Conv -> BatchNorm -> Leaky ReLU)
  • Die Form der Ausgabe nach der letzten Ebene lautet (batch_size, 30, 30, 1).
  • Jeder 30x30-Patch der Ausgabe klassifiziert einen 70x70-Teil des Eingabebilds (eine solche Architektur wird als PatchGAN bezeichnet).
  • Der Diskriminator erhält 2 Eingänge.
    • Eingabebild und Zielbild, das als real klassifiziert werden soll.
    • Eingabebild und generiertes Bild (Ausgabe des Generators), das als Fälschung eingestuft werden soll.
    • Verketten Sie diese beiden Eingaben im Code ( tf.concat([inp, tar], axis=-1) ).
def Discriminator():
  initializer = tf.random_normal_initializer(0., 0.02)

  inp = tf.keras.layers.Input(shape=[256, 256, 3], name='input_image')
  tar = tf.keras.layers.Input(shape=[256, 256, 3], name='target_image')

  x = tf.keras.layers.concatenate([inp, tar])  # (bs, 256, 256, channels*2)

  down1 = downsample(64, 4, False)(x)  # (bs, 128, 128, 64)
  down2 = downsample(128, 4)(down1)  # (bs, 64, 64, 128)
  down3 = downsample(256, 4)(down2)  # (bs, 32, 32, 256)

  zero_pad1 = tf.keras.layers.ZeroPadding2D()(down3)  # (bs, 34, 34, 256)
  conv = tf.keras.layers.Conv2D(512, 4, strides=1,
                                kernel_initializer=initializer,
                                use_bias=False)(zero_pad1)  # (bs, 31, 31, 512)

  batchnorm1 = tf.keras.layers.BatchNormalization()(conv)

  leaky_relu = tf.keras.layers.LeakyReLU()(batchnorm1)

  zero_pad2 = tf.keras.layers.ZeroPadding2D()(leaky_relu)  # (bs, 33, 33, 512)

  last = tf.keras.layers.Conv2D(1, 4, strides=1,
                                kernel_initializer=initializer)(zero_pad2)  # (bs, 30, 30, 1)

  return tf.keras.Model(inputs=[inp, tar], outputs=last)
discriminator = Discriminator()
tf.keras.utils.plot_model(discriminator, show_shapes=True, dpi=64)

png

disc_out = discriminator([inp[tf.newaxis, ...], gen_output], training=False)
plt.imshow(disc_out[0, ..., -1], vmin=-20, vmax=20, cmap='RdBu_r')
plt.colorbar()
<matplotlib.colorbar.Colorbar at 0x7f183c3cd550>

png

Diskriminatorverlust

  • Die Diskriminatorverlustfunktion benötigt 2 Eingänge; echte Bilder, erzeugte Bilder
  • real_loss ist ein Sigmoid-Kreuzentropieverlust der realen Bilder und einer Reihe von solchen (da dies die realen Bilder sind)
  • generate_loss ist ein Sigmoid-Kreuzentropieverlust der erzeugten Bilder und ein Array von Nullen (da dies die gefälschten Bilder sind).
  • Dann ist der Totalverlust die Summe aus Realverlust und dem generierten Verlust
def discriminator_loss(disc_real_output, disc_generated_output):
  real_loss = loss_object(tf.ones_like(disc_real_output), disc_real_output)

  generated_loss = loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)

  total_disc_loss = real_loss + generated_loss

  return total_disc_loss

Das Trainingsverfahren für den Diskriminator ist unten gezeigt.

Um mehr über die Architektur und die Hyperparameter zu erfahren, lesen Sie das Dokument .

Diskriminator-Update-Image

Definieren Sie die Optimierer und den Checkpoint-Saver

generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

Bilder generieren

Schreiben Sie eine Funktion, um einige Bilder während des Trainings zu zeichnen.

  • Übergeben Sie Bilder aus dem Testdatensatz an den Generator.
  • Der Generator übersetzt dann das Eingabebild in die Ausgabe.
  • Der letzte Schritt besteht darin, die Vorhersagen und die Voila zu zeichnen!
def generate_images(model, test_input, tar):
  prediction = model(test_input, training=True)
  plt.figure(figsize=(15, 15))

  display_list = [test_input[0], tar[0], prediction[0]]
  title = ['Input Image', 'Ground Truth', 'Predicted Image']

  for i in range(3):
    plt.subplot(1, 3, i+1)
    plt.title(title[i])
    # getting the pixel values between [0, 1] to plot it.
    plt.imshow(display_list[i] * 0.5 + 0.5)
    plt.axis('off')
  plt.show()
for example_input, example_target in test_dataset.take(1):
  generate_images(generator, example_input, example_target)

png

Ausbildung

  • Für jede Beispieleingabe wird eine Ausgabe generiert.
  • Der Diskriminator empfängt das Eingabebild und das erzeugte Bild als erste Eingabe. Die zweite Eingabe ist das Eingabebild und das Zielbild.
  • Berechnen Sie als nächstes den Generator- und den Diskriminatorverlust.
  • Berechnen Sie dann die Verlustgradienten sowohl für die Generator- als auch für die Diskriminatorvariablen (Eingänge) und wenden Sie diese auf den Optimierer an.
  • Protokollieren Sie dann die Verluste in TensorBoard.
EPOCHS = 150
import datetime
log_dir="logs/"

summary_writer = tf.summary.create_file_writer(
  log_dir + "fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
@tf.function
def train_step(input_image, target, epoch):
  with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
    gen_output = generator(input_image, training=True)

    disc_real_output = discriminator([input_image, target], training=True)
    disc_generated_output = discriminator([input_image, gen_output], training=True)

    gen_total_loss, gen_gan_loss, gen_l1_loss = generator_loss(disc_generated_output, gen_output, target)
    disc_loss = discriminator_loss(disc_real_output, disc_generated_output)

  generator_gradients = gen_tape.gradient(gen_total_loss,
                                          generator.trainable_variables)
  discriminator_gradients = disc_tape.gradient(disc_loss,
                                               discriminator.trainable_variables)

  generator_optimizer.apply_gradients(zip(generator_gradients,
                                          generator.trainable_variables))
  discriminator_optimizer.apply_gradients(zip(discriminator_gradients,
                                              discriminator.trainable_variables))

  with summary_writer.as_default():
    tf.summary.scalar('gen_total_loss', gen_total_loss, step=epoch)
    tf.summary.scalar('gen_gan_loss', gen_gan_loss, step=epoch)
    tf.summary.scalar('gen_l1_loss', gen_l1_loss, step=epoch)
    tf.summary.scalar('disc_loss', disc_loss, step=epoch)

Die eigentliche Trainingsschleife:

  • Iteriert über die Anzahl der Epochen.
  • In jeder Epoche wird die Anzeige gelöscht und es werden generate_images , um den Fortschritt anzuzeigen.
  • In jeder Epoche wird der Trainingsdatensatz durchlaufen und ein '.' für jedes Beispiel.
  • Alle 20 Epochen wird ein Kontrollpunkt gespeichert.
def fit(train_ds, epochs, test_ds):
  for epoch in range(epochs):
    start = time.time()

    display.clear_output(wait=True)

    for example_input, example_target in test_ds.take(1):
      generate_images(generator, example_input, example_target)
    print("Epoch: ", epoch)

    # Train
    for n, (input_image, target) in train_ds.enumerate():
      print('.', end='')
      if (n+1) % 100 == 0:
        print()
      train_step(input_image, target, epoch)
    print()

    # saving (checkpoint) the model every 20 epochs
    if (epoch + 1) % 20 == 0:
      checkpoint.save(file_prefix=checkpoint_prefix)

    print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
                                                        time.time()-start))
  checkpoint.save(file_prefix=checkpoint_prefix)

Diese Trainingsschleife speichert Protokolle, die Sie einfach in TensorBoard anzeigen können, um den Trainingsfortschritt zu überwachen. Wenn Sie lokal arbeiten, starten Sie einen separaten Tensorboard-Prozess. Wenn Sie in einem Notebook mit TensorBoard überwachen möchten, ist es am einfachsten, den Viewer vor Beginn des Trainings zu starten.

Um den Viewer zu starten, fügen Sie Folgendes in eine Codezelle ein:

%load_ext tensorboard
%tensorboard --logdir {log_dir}

Führen Sie nun die Trainingsschleife aus:

fit(train_dataset, EPOCHS, test_dataset)

png

Epoch:  149
....................................................................................................
....................................................................................................
....................................................................................................
....................................................................................................

Time taken for epoch 150 is 15.598696947097778 sec

Wenn Sie die TensorBoard-Ergebnisse öffentlich teilen möchten, können Sie die Protokolle auf TensorBoard.dev hochladen, indem Sie Folgendes in eine Codezelle kopieren.

tensorboard dev upload --logdir  {log_dir}

Sie können die Ergebnisse einer früheren Ausführung dieses Notizbuchs auf TensorBoard.dev anzeigen .

TensorBoard.dev ist eine verwaltete Erfahrung zum Hosten, Verfolgen und Teilen von ML-Experimenten mit allen.

Es kann auch inline mit einem <iframe> :

display.IFrame(
    src="https://tensorboard.dev/experiment/lZ0C6FONROaUMfjYkVyJqw",
    width="100%",
    height="1000px")

Das Interpretieren der Protokolle aus einer GAN ist subtiler als ein einfaches Klassifizierungs- oder Regressionsmodell. Dinge zu suchen ::

  • Überprüfen Sie, ob keines der Modelle "gewonnen" hat. Wenn entweder der gen_gan_loss oder der disc_loss sehr niedrig wird, ist dies ein Indikator dafür, dass dieses Modell das andere dominiert und Sie das kombinierte Modell nicht erfolgreich trainieren.
  • Der Wert log(2) = 0.69 ist ein guter Bezugspunkt für diese Verluste, da er eine Ratlosigkeit von 2 anzeigt: Dass der Diskriminator im Durchschnitt gleichermaßen unsicher über die beiden Optionen ist.
  • Für den disc_loss ein Wert unter 0.69 dass der Diskriminator bei der kombinierten Menge von real + generierten Bildern besser als zufällig disc_loss .
  • Für den gen_gan_loss ein Wert unter 0.69 dass der Generator den Diskriminator besser als zufällig täuscht.
  • Mit fortschreitendem gen_l1_loss sollte der gen_l1_loss sinken.

Stellen Sie den neuesten Prüfpunkt wieder her und testen Sie ihn

ls {checkpoint_dir}
checkpoint          ckpt-5.data-00000-of-00001
ckpt-1.data-00000-of-00001  ckpt-5.index
ckpt-1.index            ckpt-6.data-00000-of-00001
ckpt-2.data-00000-of-00001  ckpt-6.index
ckpt-2.index            ckpt-7.data-00000-of-00001
ckpt-3.data-00000-of-00001  ckpt-7.index
ckpt-3.index            ckpt-8.data-00000-of-00001
ckpt-4.data-00000-of-00001  ckpt-8.index
ckpt-4.index
# restoring the latest checkpoint in checkpoint_dir
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f134de35a90>

Generieren Sie mit dem Testdatensatz

# Run the trained model on a few examples from the test dataset
for inp, tar in test_dataset.take(5):
  generate_images(generator, inp, tar)

png

png

png

png

png