הצג באתר TensorFlow.org | הפעל בגוגל קולאב | צפה במקור ב-GitHub | הורד מחברת | ראה דגם TF Hub |
סקירה כללית
המשימה של שחזור תמונה ברזולוציה גבוהה (HR) ממקבילה ברזולוציה נמוכה מכונה בדרך כלל רזולוציית סופר של תמונה אחת (SISR).
המודל המשמש כאן הוא ESRGAN ( ESRGAN: רשתות תשובה: Generative Enhanced Super-Resolution ). ואנחנו הולכים להשתמש ב-TensorFlow Lite כדי להסיק מסקנות על המודל שהוכשר מראש.
מודל TFLite מומר מכך יישום מתארח Hub TF. שימו לב שהדגם שהמרנו מעלה תמונה ברזולוציה נמוכה של 50x50 לתמונה ברזולוציה גבוהה של 200x200 (גורם קנה מידה=4). אם אתה רוצה גודל קלט או גורם קנה מידה שונה, עליך להמיר מחדש או לאמן מחדש את הדגם המקורי.
להכין
בוא נתקין תחילה ספריות נדרשות.
pip install matplotlib tensorflow tensorflow-hub
תלות בייבוא.
import tensorflow as tf
import tensorflow_hub as hub
import matplotlib.pyplot as plt
print(tf.__version__)
2.7.0
הורד והמר את מודל ESRGAN
model = hub.load("https://tfhub.dev/captain-pool/esrgan-tf2/1")
concrete_func = model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
@tf.function(input_signature=[tf.TensorSpec(shape=[1, 50, 50, 3], dtype=tf.float32)])
def f(input):
return concrete_func(input);
converter = tf.lite.TFLiteConverter.from_concrete_functions([f.get_concrete_function()], model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()
# Save the TF Lite model.
with tf.io.gfile.GFile('ESRGAN.tflite', 'wb') as f:
f.write(tflite_model)
esrgan_model_path = './ESRGAN.tflite'
WARNING:absl:Found untraced functions such as restored_function_body, restored_function_body, restored_function_body, restored_function_body, restored_function_body while saving (showing 5 of 335). These functions will not be directly callable after loading. INFO:tensorflow:Assets written to: /tmp/tmpinlbbz0t/assets INFO:tensorflow:Assets written to: /tmp/tmpinlbbz0t/assets 2021-11-16 12:15:19.621471: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:363] Ignored output_format. 2021-11-16 12:15:19.621517: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:366] Ignored drop_control_dependency. WARNING:absl:Buffer deduplication procedure will be skipped when flatbuffer library is not properly loaded
הורד תמונת בדיקה (ראש חרק).
test_img_path = tf.keras.utils.get_file('lr.jpg', 'https://raw.githubusercontent.com/tensorflow/examples/master/lite/examples/super_resolution/android/app/src/main/assets/lr-1.jpg')
Downloading data from https://raw.githubusercontent.com/tensorflow/examples/master/lite/examples/super_resolution/android/app/src/main/assets/lr-1.jpg 16384/6432 [============================================================================] - 0s 0us/step
צור תמונה ברזולוציית סופר באמצעות TensorFlow Lite
lr = tf.io.read_file(test_img_path)
lr = tf.image.decode_jpeg(lr)
lr = tf.expand_dims(lr, axis=0)
lr = tf.cast(lr, tf.float32)
# Load TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path=esrgan_model_path)
interpreter.allocate_tensors()
# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
# Run the model
interpreter.set_tensor(input_details[0]['index'], lr)
interpreter.invoke()
# Extract the output and postprocess it
output_data = interpreter.get_tensor(output_details[0]['index'])
sr = tf.squeeze(output_data, axis=0)
sr = tf.clip_by_value(sr, 0, 255)
sr = tf.round(sr)
sr = tf.cast(sr, tf.uint8)
דמיינו את התוצאה
lr = tf.cast(tf.squeeze(lr, axis=0), tf.uint8)
plt.figure(figsize = (1, 1))
plt.title('LR')
plt.imshow(lr.numpy());
plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.title(f'ESRGAN (x4)')
plt.imshow(sr.numpy());
bicubic = tf.image.resize(lr, [200, 200], tf.image.ResizeMethod.BICUBIC)
bicubic = tf.cast(bicubic, tf.uint8)
plt.subplot(1, 2, 2)
plt.title('Bicubic')
plt.imshow(bicubic.numpy());
מדדי ביצועים
מספרי benchmark ביצועים נוצרים עם הכלי המתואר כאן .
שם המודל | גודל דגם | התקן | מעבד | GPU |
---|---|---|---|---|
רזולוציית סופר (ESRGAN) | 4.8 מגה-ביט | פיקסל 3 | 586.8ms* | 128.6ms |
פיקסל 4 | 385.1ms* | 130.3ms |
* 4 אשכולות בשימוש