Voir sur TensorFlow.org | Exécuter dans Google Colab | Voir la source sur GitHub | Télécharger le cahier | Voir le modèle TF Hub |
Aperçu
La tâche consistant à récupérer une image haute résolution (HR) à partir de son homologue basse résolution est communément appelée Single Image Super Resolution (SISR).
Le modèle utilisé ici est ESRGAN ( ESRGAN: Super-résolution améliorée génératives accusatoires Networks ). Et nous allons utiliser TensorFlow Lite pour exécuter l'inférence sur le modèle pré-entraîné.
Le modèle TFLite est converti à partir de cette mise en œuvre hébergé sur TF Hub. Notez que le modèle que nous avons converti suréchantillonne une image basse résolution 50x50 en une image haute résolution 200x200 (facteur d'échelle=4). Si vous souhaitez une taille d'entrée ou un facteur d'échelle différent, vous devez reconvertir ou réentraîner le modèle d'origine.
Installer
Installons d'abord les bibliothèques requises.
pip install matplotlib tensorflow tensorflow-hub
Importer des dépendances.
import tensorflow as tf
import tensorflow_hub as hub
import matplotlib.pyplot as plt
print(tf.__version__)
2.7.0
Téléchargez et convertissez le modèle ESRGAN
model = hub.load("https://tfhub.dev/captain-pool/esrgan-tf2/1")
concrete_func = model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
@tf.function(input_signature=[tf.TensorSpec(shape=[1, 50, 50, 3], dtype=tf.float32)])
def f(input):
return concrete_func(input);
converter = tf.lite.TFLiteConverter.from_concrete_functions([f.get_concrete_function()], model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()
# Save the TF Lite model.
with tf.io.gfile.GFile('ESRGAN.tflite', 'wb') as f:
f.write(tflite_model)
esrgan_model_path = './ESRGAN.tflite'
WARNING:absl:Found untraced functions such as restored_function_body, restored_function_body, restored_function_body, restored_function_body, restored_function_body while saving (showing 5 of 335). These functions will not be directly callable after loading. INFO:tensorflow:Assets written to: /tmp/tmpinlbbz0t/assets INFO:tensorflow:Assets written to: /tmp/tmpinlbbz0t/assets 2021-11-16 12:15:19.621471: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:363] Ignored output_format. 2021-11-16 12:15:19.621517: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:366] Ignored drop_control_dependency. WARNING:absl:Buffer deduplication procedure will be skipped when flatbuffer library is not properly loaded
Téléchargez une image test (tête d'insecte).
test_img_path = tf.keras.utils.get_file('lr.jpg', 'https://raw.githubusercontent.com/tensorflow/examples/master/lite/examples/super_resolution/android/app/src/main/assets/lr-1.jpg')
Downloading data from https://raw.githubusercontent.com/tensorflow/examples/master/lite/examples/super_resolution/android/app/src/main/assets/lr-1.jpg 16384/6432 [============================================================================] - 0s 0us/step
Générez une image en super résolution à l'aide de TensorFlow Lite
lr = tf.io.read_file(test_img_path)
lr = tf.image.decode_jpeg(lr)
lr = tf.expand_dims(lr, axis=0)
lr = tf.cast(lr, tf.float32)
# Load TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path=esrgan_model_path)
interpreter.allocate_tensors()
# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
# Run the model
interpreter.set_tensor(input_details[0]['index'], lr)
interpreter.invoke()
# Extract the output and postprocess it
output_data = interpreter.get_tensor(output_details[0]['index'])
sr = tf.squeeze(output_data, axis=0)
sr = tf.clip_by_value(sr, 0, 255)
sr = tf.round(sr)
sr = tf.cast(sr, tf.uint8)
Visualisez le résultat
lr = tf.cast(tf.squeeze(lr, axis=0), tf.uint8)
plt.figure(figsize = (1, 1))
plt.title('LR')
plt.imshow(lr.numpy());
plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.title(f'ESRGAN (x4)')
plt.imshow(sr.numpy());
bicubic = tf.image.resize(lr, [200, 200], tf.image.ResizeMethod.BICUBIC)
bicubic = tf.cast(bicubic, tf.uint8)
plt.subplot(1, 2, 2)
plt.title('Bicubic')
plt.imshow(bicubic.numpy());
Références de performance
Les numéros de référence de performance sont générés avec l'outil décrit ici .
Nom du modèle | Taille du modèle | Appareil | CPU | GPU |
---|---|---|---|---|
super résolution (ESRGAN) | 4,8 Mo | Pixel 3 | 586,8 ms* | 128,6 ms |
Pixel 4 | 385,1 ms* | 130.3ms |
* 4 fils utilisés