Image Super Resolution using ESRGAN

View on TensorFlow.org | Run in Google Colab | View on GitHub | Download notebook | See TF Hub model

This colab demonstrates use of the TensorFlow Hub module for Enhanced Super Resolution Generative Adversarial Network (ESRGAN, by Xintao Wang et al.) [Paper] [Code] for image enhancing (preferably on bicubically downsampled images).

Model trained on DIV2K Dataset (on bicubically downsampled images) on image patches of size 128 x 128.

Preparing Environment

import os
import time
from PIL import Image
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
import matplotlib.pyplot as plt
os.environ["TFHUB_DOWNLOAD_PROGRESS"] = "True"
wget "" -O original.png
# Declaring Constants
IMAGE_PATH = "original.png"

Defining Helper Functions

def preprocess_image(image_path):
  """Loads an image from disk and preprocesses it to make it model ready.

  Args:
    image_path: Path to the image file.

  Returns:
    A float32 tensor with a leading batch dimension of 1, whose height and
    width have been cropped down to multiples of 4.
  """
  hr_image = tf.image.decode_image(tf.io.read_file(image_path))
  # If PNG, remove the alpha channel. The model only supports
  # images with 3 color channels.
  if hr_image.shape[-1] == 4:
    hr_image = hr_image[..., :-1]
  # Round height and width down to multiples of 4 so that a later 4x
  # downscale divides the image evenly.
  hr_size = (tf.convert_to_tensor(hr_image.shape[:-1]) // 4) * 4
  hr_image = tf.image.crop_to_bounding_box(hr_image, 0, 0, hr_size[0], hr_size[1])
  hr_image = tf.cast(hr_image, tf.float32)
  return tf.expand_dims(hr_image, 0)

def save_image(image, filename):
  """Saves an unscaled image as "<filename>.jpg".

  Args:
    image: 3D image tensor [height, width, channels], or a PIL image.
    filename: Name of the file to save, without the ".jpg" extension.
  """
  if not isinstance(image, Image.Image):
    # Tensors are clipped to the 0-255 range and converted to a PIL image.
    image = tf.clip_by_value(image, 0, 255)
    image = Image.fromarray(tf.cast(image, tf.uint8).numpy())
  image.save("%s.jpg" % filename)
  print("Saved as %s.jpg" % filename)
# IPython magic: render matplotlib figures inline in the notebook output.
%matplotlib inline
def plot_image(image, title=""):
  """Plots an image tensor on the current matplotlib axes.

  Args:
    image: 3D image tensor. [height, width, channels].
    title: Title to display in the plot.
  """
  image = np.asarray(image)
  # Clip to the displayable range before converting to 8-bit pixels.
  image = tf.clip_by_value(image, 0, 255)
  image = Image.fromarray(tf.cast(image, tf.uint8).numpy())
  plt.imshow(image)
  plt.axis("off")
  plt.title(title)

Performing Super Resolution of images loaded from path

# Load the input image; the result keeps a batch dimension for the model.
hr_image = preprocess_image(IMAGE_PATH)

# Show and store the original resolution image.
plot_image(tf.squeeze(hr_image), title="Original Image")
save_image(tf.squeeze(hr_image), filename="Original Image")

model = hub.load(SAVED_MODEL_PATH)

# Time the forward pass, including removal of the batch dimension.
start = time.time()
fake_image = tf.squeeze(model(hr_image))
elapsed = time.time() - start
print("Time Taken: %f" % elapsed)

# Show and store the super resolution output.
plot_image(tf.squeeze(fake_image), title="Super Resolution")
save_image(tf.squeeze(fake_image), filename="Super Resolution")

Evaluating Performance of the Model

# Notebook shell escape: fetch the evaluation image into test.jpg.
# NOTE(review): the download URL is empty in this copy — supply the image URL.
!wget "" -O test.jpg
# Point the following cells at the evaluation image.
IMAGE_PATH = "test.jpg"
# Defining helper functions
def downscale_image(image):
  """Scales an image down by a factor of 4 using bicubic downsampling.

  Args:
    image: 3D tensor of a preprocessed image. [height, width, channels]

  Returns:
    A float32 tensor of the downscaled image with a leading batch
    dimension of 1.

  Raises:
    ValueError: If `image` is not a 3D tensor.
  """
  if len(image.shape) != 3:
    raise ValueError("Dimension mismatch. Can work only on single image.")
  # PIL's resize takes (width, height), the reverse of the tensor layout.
  image_size = [image.shape[1], image.shape[0]]

  # PIL needs 8-bit pixel values, so clip to 0-255 and cast first.
  image = tf.squeeze(
      tf.cast(
          tf.clip_by_value(image, 0, 255), tf.uint8))

  lr_image = np.asarray(
      Image.fromarray(image.numpy())
      .resize([image_size[0] // 4, image_size[1] // 4],
              Image.BICUBIC))

  # Restore the batch dimension and float dtype the model is called with.
  lr_image = tf.expand_dims(lr_image, 0)
  lr_image = tf.cast(lr_image, tf.float32)
  return lr_image
# Build the evaluation pair: the reference image and its downscaled copy.
hr_image = preprocess_image(IMAGE_PATH)
lr_image = downscale_image(tf.squeeze(hr_image))

# Show the low resolution input.
plot_image(tf.squeeze(lr_image), title="Low Resolution")

model = hub.load(SAVED_MODEL_PATH)

# Time the forward pass, including removal of the batch dimension.
start = time.time()
fake_image = tf.squeeze(model(lr_image))
elapsed = time.time() - start
print("Time Taken: %f" % elapsed)

plot_image(tf.squeeze(fake_image), title="Super Resolution")

# PSNR of the model output against the original image, both clipped to 0-255.
clipped_output = tf.clip_by_value(fake_image, 0, 255)
clipped_reference = tf.clip_by_value(hr_image, 0, 255)
psnr = tf.image.psnr(clipped_output, clipped_reference, max_val=255)
print("PSNR Achieved: %f" % psnr)

Comparing Outputs side by side.

# Lay the three images out in one row and save the figure.
plt.rcParams['figure.figsize'] = [15, 10]
fig, axes = plt.subplots(1, 3)
fig.tight_layout()
panels = [
    (hr_image, "Original"),
    (lr_image, "x4 Bicubic"),
    (fake_image, "Super Resolution"),
]
for axis, (tensor, panel_title) in zip(axes, panels):
  # plot_image is expected to draw through pyplot, so make this panel the
  # current axes first; otherwise every image lands on the same axes.
  plt.sca(axis)
  plot_image(tf.squeeze(tensor), title=panel_title)
plt.savefig("ESRGAN_DIV2K.jpg", bbox_inches="tight")
print("PSNR: %f" % psnr)