Se usó la API de Cloud Translation para traducir esta página.
Switch to English

Pix2Pix

Ver en TensorFlow.org Ejecutar en Google Colab Ver código fuente en GitHub Descargar cuaderno

Este cuaderno muestra la traducción de imagen a imagen utilizando GAN condicionales, como se describe en Traducción de imagen a imagen con redes adversas condicionales . Con esta técnica, podemos colorear fotos en blanco y negro, convertir mapas de Google a Google Earth, etc. Aquí, convertimos fachadas de edificios en edificios reales.

Por ejemplo, utilizaremos la base de datos de fachadas CMP , proveída por el Centro de Percepción de Máquinas en la Universidad Técnica Checa en Praga . Para mantener nuestro ejemplo breve, utilizaremos una copia preprocesada de este conjunto de datos, creada por los autores del artículo anterior.

Cada época toma alrededor de 15 segundos en una sola GPU V100.

A continuación se muestra el resultado generado después de entrenar el modelo durante 200 épocas.

salida de muestra_1salida de muestra_2

Importar TensorFlow y otras bibliotecas

 import tensorflow as tf

import os
import time

from matplotlib import pyplot as plt
from IPython import display
 
pip install -q -U tensorboard
ERROR: tensorflow 2.2.0 has requirement tensorboard<2.3.0,>=2.2.0, but you'll have tensorboard 2.3.0 which is incompatible.

Cargue el conjunto de datos

Puede descargar este conjunto de datos y conjuntos de datos similares desde aquí . Como se menciona en el documento , aplicamos fluctuaciones y reflejos aleatorios al conjunto de datos de entrenamiento.

  • En jittering aleatorio, la imagen se redimensiona a 286 x 286 y luego se recorta aleatoriamente a 256 x 256
  • En la duplicación aleatoria, la imagen se voltea aleatoriamente horizontalmente, es decir, de izquierda a derecha.
 _URL = 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/facades.tar.gz'

path_to_zip = tf.keras.utils.get_file('facades.tar.gz',
                                      origin=_URL,
                                      extract=True)

PATH = os.path.join(os.path.dirname(path_to_zip), 'facades/')
 
Downloading data from https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/facades.tar.gz
30171136/30168306 [==============================] - 2s 0us/step

 BUFFER_SIZE = 400
BATCH_SIZE = 1
IMG_WIDTH = 256
IMG_HEIGHT = 256
 
 def load(image_file):
  image = tf.io.read_file(image_file)
  image = tf.image.decode_jpeg(image)

  w = tf.shape(image)[1]

  w = w // 2
  real_image = image[:, :w, :]
  input_image = image[:, w:, :]

  input_image = tf.cast(input_image, tf.float32)
  real_image = tf.cast(real_image, tf.float32)

  return input_image, real_image
 
 inp, re = load(PATH+'train/100.jpg')
# casting to int for matplotlib to show the image
plt.figure()
plt.imshow(inp/255.0)
plt.figure()
plt.imshow(re/255.0)
 
<matplotlib.image.AxesImage at 0x7f9b575231d0>

png

png

 def resize(input_image, real_image, height, width):
  input_image = tf.image.resize(input_image, [height, width],
                                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
  real_image = tf.image.resize(real_image, [height, width],
                               method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

  return input_image, real_image
 
 def random_crop(input_image, real_image):
  stacked_image = tf.stack([input_image, real_image], axis=0)
  cropped_image = tf.image.random_crop(
      stacked_image, size=[2, IMG_HEIGHT, IMG_WIDTH, 3])

  return cropped_image[0], cropped_image[1]
 
 # normalizing the images to [-1, 1]

def normalize(input_image, real_image):
  input_image = (input_image / 127.5) - 1
  real_image = (real_image / 127.5) - 1

  return input_image, real_image
 
 @tf.function()
def random_jitter(input_image, real_image):
  # resizing to 286 x 286 x 3
  input_image, real_image = resize(input_image, real_image, 286, 286)

  # randomly cropping to 256 x 256 x 3
  input_image, real_image = random_crop(input_image, real_image)

  if tf.random.uniform(()) > 0.5:
    # random mirroring
    input_image = tf.image.flip_left_right(input_image)
    real_image = tf.image.flip_left_right(real_image)

  return input_image, real_image
 

Como puede ver en las imágenes a continuación, están pasando por fluctuaciones aleatorias Las fluctuaciones aleatorias como se describe en el documento es para

  1. Cambiar el tamaño de una imagen a mayor altura y ancho
  2. Recortar aleatoriamente al tamaño objetivo
  3. Voltear aleatoriamente la imagen horizontalmente
 plt.figure(figsize=(6, 6))
for i in range(4):
  rj_inp, rj_re = random_jitter(inp, re)
  plt.subplot(2, 2, i+1)
  plt.imshow(rj_inp/255.0)
  plt.axis('off')
plt.show()
 

png

 def load_image_train(image_file):
  input_image, real_image = load(image_file)
  input_image, real_image = random_jitter(input_image, real_image)
  input_image, real_image = normalize(input_image, real_image)

  return input_image, real_image
 
 def load_image_test(image_file):
  input_image, real_image = load(image_file)
  input_image, real_image = resize(input_image, real_image,
                                   IMG_HEIGHT, IMG_WIDTH)
  input_image, real_image = normalize(input_image, real_image)

  return input_image, real_image
 

Tubería de entrada

 train_dataset = tf.data.Dataset.list_files(PATH+'train/*.jpg')
train_dataset = train_dataset.map(load_image_train,
                                  num_parallel_calls=tf.data.experimental.AUTOTUNE)
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
train_dataset = train_dataset.batch(BATCH_SIZE)
 
 test_dataset = tf.data.Dataset.list_files(PATH+'test/*.jpg')
test_dataset = test_dataset.map(load_image_test)
test_dataset = test_dataset.batch(BATCH_SIZE)
 

Construye el generador

  • La arquitectura del generador es una U-Net modificada.
  • Cada bloque en el codificador es (Conv -> Batchnorm -> Leaky ReLU)
  • Cada bloque en el decodificador es (Transposed Conv -> Batchnorm -> Dropout (aplicado a los primeros 3 bloques) -> ReLU)
  • Hay conexiones de omisión entre el codificador y el decodificador (como en U-Net).
 OUTPUT_CHANNELS = 3
 
 def downsample(filters, size, apply_batchnorm=True):
  initializer = tf.random_normal_initializer(0., 0.02)

  result = tf.keras.Sequential()
  result.add(
      tf.keras.layers.Conv2D(filters, size, strides=2, padding='same',
                             kernel_initializer=initializer, use_bias=False))

  if apply_batchnorm:
    result.add(tf.keras.layers.BatchNormalization())

  result.add(tf.keras.layers.LeakyReLU())

  return result
 
 down_model = downsample(3, 4)
down_result = down_model(tf.expand_dims(inp, 0))
print (down_result.shape)
 
(1, 128, 128, 3)

 def upsample(filters, size, apply_dropout=False):
  initializer = tf.random_normal_initializer(0., 0.02)

  result = tf.keras.Sequential()
  result.add(
    tf.keras.layers.Conv2DTranspose(filters, size, strides=2,
                                    padding='same',
                                    kernel_initializer=initializer,
                                    use_bias=False))

  result.add(tf.keras.layers.BatchNormalization())

  if apply_dropout:
      result.add(tf.keras.layers.Dropout(0.5))

  result.add(tf.keras.layers.ReLU())

  return result
 
 up_model = upsample(3, 4)
up_result = up_model(down_result)
print (up_result.shape)
 
(1, 256, 256, 3)

 def Generator():
  inputs = tf.keras.layers.Input(shape=[256,256,3])

  down_stack = [
    downsample(64, 4, apply_batchnorm=False), # (bs, 128, 128, 64)
    downsample(128, 4), # (bs, 64, 64, 128)
    downsample(256, 4), # (bs, 32, 32, 256)
    downsample(512, 4), # (bs, 16, 16, 512)
    downsample(512, 4), # (bs, 8, 8, 512)
    downsample(512, 4), # (bs, 4, 4, 512)
    downsample(512, 4), # (bs, 2, 2, 512)
    downsample(512, 4), # (bs, 1, 1, 512)
  ]

  up_stack = [
    upsample(512, 4, apply_dropout=True), # (bs, 2, 2, 1024)
    upsample(512, 4, apply_dropout=True), # (bs, 4, 4, 1024)
    upsample(512, 4, apply_dropout=True), # (bs, 8, 8, 1024)
    upsample(512, 4), # (bs, 16, 16, 1024)
    upsample(256, 4), # (bs, 32, 32, 512)
    upsample(128, 4), # (bs, 64, 64, 256)
    upsample(64, 4), # (bs, 128, 128, 128)
  ]

  initializer = tf.random_normal_initializer(0., 0.02)
  last = tf.keras.layers.Conv2DTranspose(OUTPUT_CHANNELS, 4,
                                         strides=2,
                                         padding='same',
                                         kernel_initializer=initializer,
                                         activation='tanh') # (bs, 256, 256, 3)

  x = inputs

  # Downsampling through the model
  skips = []
  for down in down_stack:
    x = down(x)
    skips.append(x)

  skips = reversed(skips[:-1])

  # Upsampling and establishing the skip connections
  for up, skip in zip(up_stack, skips):
    x = up(x)
    x = tf.keras.layers.Concatenate()([x, skip])

  x = last(x)

  return tf.keras.Model(inputs=inputs, outputs=x)
 
 generator = Generator()
tf.keras.utils.plot_model(generator, show_shapes=True, dpi=64)
 

png

 gen_output = generator(inp[tf.newaxis,...], training=False)
plt.imshow(gen_output[0,...])
 
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

<matplotlib.image.AxesImage at 0x7f9adf885940>

png

  • Pérdida del generador
    • Es una pérdida de entropía cruzada sigmoidea de las imágenes generadas y una serie de imágenes.
    • El documento también incluye la pérdida de L1 que es MAE (error absoluto medio) entre la imagen generada y la imagen objetivo.
    • Esto permite que la imagen generada se vuelva estructuralmente similar a la imagen objetivo.
    • La fórmula para calcular la pérdida total del generador = gan_loss + LAMBDA * l1_loss, donde LAMBDA = 100. Este valor fue decidido por los autores del artículo .

El procedimiento de entrenamiento para el generador se muestra a continuación:

 LAMBDA = 100
 
 def generator_loss(disc_generated_output, gen_output, target):
  gan_loss = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)

  # mean absolute error
  l1_loss = tf.reduce_mean(tf.abs(target - gen_output))

  total_gen_loss = gan_loss + (LAMBDA * l1_loss)

  return total_gen_loss, gan_loss, l1_loss
 

Imagen de actualización del generador

Construye el discriminador

  • El discriminador es un PatchGAN.
  • Cada bloque en el discriminador es (Conv -> BatchNorm -> Leaky ReLU)
  • La forma de la salida después de la última capa es (batch_size, 30, 30, 1)
  • Cada parche de 30x30 de la salida clasifica una porción de 70x70 de la imagen de entrada (dicha arquitectura se llama PatchGAN).
  • El discriminador recibe 2 entradas.
    • Imagen de entrada y la imagen de destino, que debe clasificar como real.
    • Imagen de entrada y la imagen generada (salida del generador), que debe clasificar como falsa.
    • Concatenamos estas 2 entradas juntas en el código ( tf.concat([inp, tar], axis=-1) )
 def Discriminator():
  initializer = tf.random_normal_initializer(0., 0.02)

  inp = tf.keras.layers.Input(shape=[256, 256, 3], name='input_image')
  tar = tf.keras.layers.Input(shape=[256, 256, 3], name='target_image')

  x = tf.keras.layers.concatenate([inp, tar]) # (bs, 256, 256, channels*2)

  down1 = downsample(64, 4, False)(x) # (bs, 128, 128, 64)
  down2 = downsample(128, 4)(down1) # (bs, 64, 64, 128)
  down3 = downsample(256, 4)(down2) # (bs, 32, 32, 256)

  zero_pad1 = tf.keras.layers.ZeroPadding2D()(down3) # (bs, 34, 34, 256)
  conv = tf.keras.layers.Conv2D(512, 4, strides=1,
                                kernel_initializer=initializer,
                                use_bias=False)(zero_pad1) # (bs, 31, 31, 512)

  batchnorm1 = tf.keras.layers.BatchNormalization()(conv)

  leaky_relu = tf.keras.layers.LeakyReLU()(batchnorm1)

  zero_pad2 = tf.keras.layers.ZeroPadding2D()(leaky_relu) # (bs, 33, 33, 512)

  last = tf.keras.layers.Conv2D(1, 4, strides=1,
                                kernel_initializer=initializer)(zero_pad2) # (bs, 30, 30, 1)

  return tf.keras.Model(inputs=[inp, tar], outputs=last)
 
 discriminator = Discriminator()
tf.keras.utils.plot_model(discriminator, show_shapes=True, dpi=64)
 

png

 disc_out = discriminator([inp[tf.newaxis,...], gen_output], training=False)
plt.imshow(disc_out[0,...,-1], vmin=-20, vmax=20, cmap='RdBu_r')
plt.colorbar()
 
<matplotlib.colorbar.Colorbar at 0x7f9a6c17bcc0>

png

Pérdida discriminadora

  • La función de pérdida discriminadora toma 2 entradas; imágenes reales, imágenes generadas
  • real_loss es una pérdida de entropía cruzada sigmoidea de las imágenes reales y una serie de imágenes (ya que estas son las imágenes reales)
  • generate_loss es una pérdida de entropía cruzada sigmoidea de las imágenes generadas y una matriz de ceros (ya que estas son las imágenes falsas)
  • Entonces el total_loss es la suma de real_loss y el generate_loss
 loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
 
 def discriminator_loss(disc_real_output, disc_generated_output):
  real_loss = loss_object(tf.ones_like(disc_real_output), disc_real_output)

  generated_loss = loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)

  total_disc_loss = real_loss + generated_loss

  return total_disc_loss
 

El procedimiento de entrenamiento para el discriminador se muestra a continuación.

Para obtener más información sobre la arquitectura y los hiperparámetros, puede consultar el documento .

Imagen de actualización del discriminador

Definir los optimizadores y el punto de control

 generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
 
 checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)
 

Generar imágenes

Escribe una función para trazar algunas imágenes durante el entrenamiento.

  • Pasamos imágenes del conjunto de datos de prueba al generador.
  • El generador luego traducirá la imagen de entrada a la salida.
  • El último paso es trazar las predicciones y ¡listo!
 def generate_images(model, test_input, tar):
  prediction = model(test_input, training=True)
  plt.figure(figsize=(15,15))

  display_list = [test_input[0], tar[0], prediction[0]]
  title = ['Input Image', 'Ground Truth', 'Predicted Image']

  for i in range(3):
    plt.subplot(1, 3, i+1)
    plt.title(title[i])
    # getting the pixel values between [0, 1] to plot it.
    plt.imshow(display_list[i] * 0.5 + 0.5)
    plt.axis('off')
  plt.show()
 
 for example_input, example_target in test_dataset.take(1):
  generate_images(generator, example_input, example_target)
 

png

Formación

  • Para cada ejemplo, la entrada genera una salida.
  • El discriminador recibe input_image y la imagen generada como la primera entrada. La segunda entrada es input_image y target_image.
  • A continuación, calculamos el generador y la pérdida discriminadora.
  • Luego, calculamos los gradientes de pérdida con respecto a las variables (entradas) del generador y del discriminador y las aplicamos al optimizador.
  • Luego registre las pérdidas en TensorBoard.
 EPOCHS = 150
 
 import datetime
log_dir="logs/"

summary_writer = tf.summary.create_file_writer(
  log_dir + "fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
 
 @tf.function
def train_step(input_image, target, epoch):
  with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
    gen_output = generator(input_image, training=True)

    disc_real_output = discriminator([input_image, target], training=True)
    disc_generated_output = discriminator([input_image, gen_output], training=True)

    gen_total_loss, gen_gan_loss, gen_l1_loss = generator_loss(disc_generated_output, gen_output, target)
    disc_loss = discriminator_loss(disc_real_output, disc_generated_output)

  generator_gradients = gen_tape.gradient(gen_total_loss,
                                          generator.trainable_variables)
  discriminator_gradients = disc_tape.gradient(disc_loss,
                                               discriminator.trainable_variables)

  generator_optimizer.apply_gradients(zip(generator_gradients,
                                          generator.trainable_variables))
  discriminator_optimizer.apply_gradients(zip(discriminator_gradients,
                                              discriminator.trainable_variables))

  with summary_writer.as_default():
    tf.summary.scalar('gen_total_loss', gen_total_loss, step=epoch)
    tf.summary.scalar('gen_gan_loss', gen_gan_loss, step=epoch)
    tf.summary.scalar('gen_l1_loss', gen_l1_loss, step=epoch)
    tf.summary.scalar('disc_loss', disc_loss, step=epoch)
 

El circuito de entrenamiento real:

  • Itera sobre el número de épocas.
  • En cada época, borra la pantalla y ejecuta generate_images para mostrar su progreso.
  • En cada época, itera sobre el conjunto de datos de entrenamiento, imprimiendo un '.' para cada ejemplo
  • Se ahorra un punto de control cada 20 épocas.
 def fit(train_ds, epochs, test_ds):
  for epoch in range(epochs):
    start = time.time()

    display.clear_output(wait=True)

    for example_input, example_target in test_ds.take(1):
      generate_images(generator, example_input, example_target)
    print("Epoch: ", epoch)

    # Train
    for n, (input_image, target) in train_ds.enumerate():
      print('.', end='')
      if (n+1) % 100 == 0:
        print()
      train_step(input_image, target, epoch)
    print()

    # saving (checkpoint) the model every 20 epochs
    if (epoch + 1) % 20 == 0:
      checkpoint.save(file_prefix = checkpoint_prefix)

    print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
                                                        time.time()-start))
  checkpoint.save(file_prefix = checkpoint_prefix)
 

Este ciclo de entrenamiento guarda registros que puede ver fácilmente en TensorBoard para monitorear el progreso del entrenamiento. Trabajando localmente lanzarías un proceso separado de tensorboard. En un cuaderno, si desea monitorear con TensorBoard, es más fácil iniciar el visor antes de comenzar el entrenamiento.

Para iniciar el visor, pegue lo siguiente en una celda de código:

 %load_ext tensorboard
%tensorboard --logdir {log_dir}
 

Ahora ejecuta el ciclo de entrenamiento:

 fit(train_dataset, EPOCHS, test_dataset)
 

png

Epoch:  125
....................................................................................................
.......................................................................................

Si desea compartir los resultados de TensorBoard públicamente , puede cargar los registros en TensorBoard.dev copiando lo siguiente en una celda de código.

tensorboard dev upload --logdir  {log_dir}

Puede ver los resultados de una ejecución anterior de este cuaderno en TensorBoard.dev .

TensorBoard.dev es una experiencia administrada para alojar, rastrear y compartir experimentos de ML con todos.

También se puede incluir en línea usando un <iframe> :

 display.IFrame(
    src="https://tensorboard.dev/experiment/lZ0C6FONROaUMfjYkVyJqw",
    width="100%",
    height="1000px")
 

Interpretar los registros de una GAN es más sutil que una simple clasificación o modelo de regresión. Cosas a buscar ::

  • Verifique que ninguno de los modelos haya "ganado". Si gen_gan_loss o disc_loss muy bajos, es un indicador de que este modelo está dominando al otro, y no está entrenando con éxito el modelo combinado.
  • El valor log(2) = 0.69 es un buen punto de referencia para estas pérdidas, ya que indica una perplejidad de 2: que el discriminador es en promedio igualmente incierto acerca de las dos opciones.
  • Para disc_loss un valor inferior a 0.69 significa que el discriminador funciona mejor que al azar, en el conjunto combinado de imágenes reales + generadas.
  • Para gen_gan_loss un valor inferior a 0.69 significa que el generador funciona mejor que aleatoriamente para eliminar el descriminador.
  • A medida que avanza el gen_l1_loss , gen_l1_loss debería disminuir.

Restaurar el último punto de control y prueba

ls {checkpoint_dir}
checkpoint          ckpt-5.data-00000-of-00002
ckpt-1.data-00000-of-00002  ckpt-5.data-00001-of-00002
ckpt-1.data-00001-of-00002  ckpt-5.index
ckpt-1.index            ckpt-6.data-00000-of-00002
ckpt-2.data-00000-of-00002  ckpt-6.data-00001-of-00002
ckpt-2.data-00001-of-00002  ckpt-6.index
ckpt-2.index            ckpt-7.data-00000-of-00002
ckpt-3.data-00000-of-00002  ckpt-7.data-00001-of-00002
ckpt-3.data-00001-of-00002  ckpt-7.index
ckpt-3.index            ckpt-8.data-00000-of-00002
ckpt-4.data-00000-of-00002  ckpt-8.data-00001-of-00002
ckpt-4.data-00001-of-00002  ckpt-8.index
ckpt-4.index

 # restoring the latest checkpoint in checkpoint_dir
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
 
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f9a6c01e978>

Generar utilizando el conjunto de datos de prueba

 # Run the trained model on a few examples from the test dataset
for inp, tar in test_dataset.take(5):
  generate_images(generator, inp, tar)
 

png

png

png

png

png