Super risoluzione con TensorFlow Lite

Visualizza su TensorFlow.org Esegui in Google Colab Visualizza la fonte su GitHub Scarica taccuino Vedi il modello del mozzo TF

Panoramica

Il compito di recuperare un'immagine ad alta risoluzione (HR) dalla sua controparte a bassa risoluzione è comunemente indicato come Single Image Super Resolution (SISR).

Il modello utilizzato è ESRGAN ( ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks ). E useremo TensorFlow Lite per eseguire l'inferenza sul modello preaddestrato.

Il modello TFLite viene convertito da questa implementazione ospitato su TF Hub. Nota che il modello che abbiamo convertito sovracampiona un'immagine a bassa risoluzione 50x50 in un'immagine ad alta risoluzione 200x200 (fattore di scala=4). Se desideri una dimensione di input o un fattore di scala diverso, devi riconvertire o riaddestrare il modello originale.

Impostare

Installiamo prima le librerie richieste.

pip install matplotlib tensorflow tensorflow-hub

Importa le dipendenze.

import tensorflow as tf
import tensorflow_hub as hub
import matplotlib.pyplot as plt
print(tf.__version__)
2.7.0

Scarica e converti il ​​modello ESRGAN

model = hub.load("https://tfhub.dev/captain-pool/esrgan-tf2/1")
concrete_func = model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]

@tf.function(input_signature=[tf.TensorSpec(shape=[1, 50, 50, 3], dtype=tf.float32)])
def f(input):
  return concrete_func(input);

converter = tf.lite.TFLiteConverter.from_concrete_functions([f.get_concrete_function()], model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

# Save the TF Lite model.
with tf.io.gfile.GFile('ESRGAN.tflite', 'wb') as f:
  f.write(tflite_model)

esrgan_model_path = './ESRGAN.tflite'
WARNING:absl:Found untraced functions such as restored_function_body, restored_function_body, restored_function_body, restored_function_body, restored_function_body while saving (showing 5 of 335). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/tmpinlbbz0t/assets
INFO:tensorflow:Assets written to: /tmp/tmpinlbbz0t/assets
2021-11-16 12:15:19.621471: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:363] Ignored output_format.
2021-11-16 12:15:19.621517: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:366] Ignored drop_control_dependency.
WARNING:absl:Buffer deduplication procedure will be skipped when flatbuffer library is not properly loaded

Scarica un'immagine di prova (testa di insetto).

test_img_path = tf.keras.utils.get_file('lr.jpg', 'https://raw.githubusercontent.com/tensorflow/examples/master/lite/examples/super_resolution/android/app/src/main/assets/lr-1.jpg')
Downloading data from https://raw.githubusercontent.com/tensorflow/examples/master/lite/examples/super_resolution/android/app/src/main/assets/lr-1.jpg
16384/6432 [============================================================================] - 0s 0us/step

Genera un'immagine ad alta risoluzione utilizzando TensorFlow Lite

lr = tf.io.read_file(test_img_path)
lr = tf.image.decode_jpeg(lr)
lr = tf.expand_dims(lr, axis=0)
lr = tf.cast(lr, tf.float32)

# Load TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path=esrgan_model_path)
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Run the model
interpreter.set_tensor(input_details[0]['index'], lr)
interpreter.invoke()

# Extract the output and postprocess it
output_data = interpreter.get_tensor(output_details[0]['index'])
sr = tf.squeeze(output_data, axis=0)
sr = tf.clip_by_value(sr, 0, 255)
sr = tf.round(sr)
sr = tf.cast(sr, tf.uint8)

Visualizza il risultato

lr = tf.cast(tf.squeeze(lr, axis=0), tf.uint8)
plt.figure(figsize = (1, 1))
plt.title('LR')
plt.imshow(lr.numpy());

plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)        
plt.title(f'ESRGAN (x4)')
plt.imshow(sr.numpy());

bicubic = tf.image.resize(lr, [200, 200], tf.image.ResizeMethod.BICUBIC)
bicubic = tf.cast(bicubic, tf.uint8)
plt.subplot(1, 2, 2)   
plt.title('Bicubic')
plt.imshow(bicubic.numpy());

png

png

Benchmark delle prestazioni

I numeri di riferimento delle prestazioni sono generati con lo strumento qui descritto .

Nome del modello Dimensioni del modello Dispositivo processore GPU
super risoluzione (ESRGAN) 4.8 Mb Pixel 3 586,8 ms* 128,6 ms
Pixel 4 385,1 ms* 130.3ms

* 4 fili utilizzati