ความละเอียดสูงสุดด้วย TensorFlow Lite

ดูบน TensorFlow.org ทำงานใน Google Colab ดูแหล่งที่มาบน GitHub ดาวน์โหลดโน๊ตบุ๊ค ดูรุ่น TF Hub

ภาพรวม

งานในการกู้คืนภาพที่มีความละเอียดสูง (HR) จากภาพที่มีความละเอียดต่ำมักเรียกว่า Single Image Super Resolution (SISR)

รูปแบบที่ใช้ที่นี่เป็น ESRGAN ( ESRGAN: ปรับปรุง Super-Resolution กำเนิดเครือข่ายขัดแย้ง ) และเราจะใช้ TensorFlow Lite เพื่อทำการอนุมานบนโมเดลที่ฝึกไว้ล่วงหน้า

รุ่น TFLite ถูกแปลงจากนี้ การดำเนินการ โฮสต์บน TF Hub โปรดทราบว่าโมเดลที่เราแปลงจะอัพตัวอย่างรูปภาพความละเอียดต่ำ 50x50 เป็นรูปภาพความละเอียดสูง 200x200 (ตัวคูณมาตราส่วน=4) หากคุณต้องการขนาดอินพุตหรือตัวคูณมาตราส่วนที่แตกต่างกัน คุณต้องแปลงหรือฝึกโมเดลดั้งเดิมอีกครั้ง

ติดตั้ง

มาติดตั้งไลบรารีที่จำเป็นก่อน

pip install matplotlib tensorflow tensorflow-hub

นำเข้าการอ้างอิง

import tensorflow as tf
import tensorflow_hub as hub
import matplotlib.pyplot as plt
print(tf.__version__)
2021-07-23 11:17:08.751392: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0
2.5.0

ดาวน์โหลดและแปลงโมเดล ESRGAN

model = hub.load("https://tfhub.dev/captain-pool/esrgan-tf2/1")
concrete_func = model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
concrete_func.inputs[0].set_shape([1, 50, 50, 3])
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

# Save the TF Lite model.
with tf.io.gfile.GFile('ESRGAN.tflite', 'wb') as f:
  f.write(tflite_model)

esrgan_model_path = './ESRGAN.tflite'
2021-07-23 11:17:15.750580: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcuda.so.1
2021-07-23 11:17:15.754548: E tensorflow/stream_executor/cuda/cuda_driver.cc:328] failed call to cuInit: CUDA_ERROR_SYSTEM_DRIVER_MISMATCH: system has unsupported display driver / cuda driver combination
2021-07-23 11:17:15.754584: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:169] retrieving CUDA diagnostic information for host: kokoro-gcp-ubuntu-prod-1315497834
2021-07-23 11:17:15.754592: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:176] hostname: kokoro-gcp-ubuntu-prod-1315497834
2021-07-23 11:17:15.754687: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:200] libcuda reported version is: 470.57.2
2021-07-23 11:17:15.754712: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:204] kernel reported version is: 465.27.0
2021-07-23 11:17:15.754718: E tensorflow/stream_executor/cuda/cuda_diagnostics.cc:313] kernel version 465.27.0 does not match DSO version 470.57.2 -- cannot find working devices in this configuration
2021-07-23 11:17:15.755072: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2021-07-23 11:17:19.601044: I tensorflow/core/grappler/devices.cc:69] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 0
2021-07-23 11:17:19.601273: I tensorflow/core/grappler/clusters/single_machine.cc:357] Starting new session
2021-07-23 11:17:19.602137: I tensorflow/core/platform/profile_utils/cpu_utils.cc:114] CPU Frequency: 2000175000 Hz
2021-07-23 11:17:19.690642: I tensorflow/core/grappler/optimizers/meta_optimizer.cc:1144] Optimization results for grappler item: graph_to_optimize
  function_optimizer: Graph size after: 1953 nodes (1608), 3017 edges (2672), time = 49.882ms.
  function_optimizer: function_optimizer did nothing. time = 1.073ms.

2021-07-23 11:17:21.581037: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:345] Ignored output_format.
2021-07-23 11:17:21.581097: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:348] Ignored drop_control_dependency.
2021-07-23 11:17:21.668047: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:210] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2021-07-23 11:17:21.929411: I tensorflow/lite/tools/optimize/quantize_weights.cc:225] Skipping quantization of tensor model/rrdb_net/conv2d_8/Conv2D;StatefulPartitionedCall/model/rrdb_net/conv2d_8/Conv2D because it has fewer than 1024 elements (864).
2021-07-23 11:17:21.929552: I tensorflow/lite/tools/optimize/quantize_weights.cc:225] Skipping quantization of tensor model/rrdb_net/conv2d_176/Conv2D;StatefulPartitionedCall/model/rrdb_net/conv2d_176/Conv2D because it has fewer than 1024 elements (864).

ดาวน์โหลดภาพทดสอบ (หัวแมลง)

test_img_path = tf.keras.utils.get_file('lr.jpg', 'https://raw.githubusercontent.com/tensorflow/examples/master/lite/examples/super_resolution/android/app/src/main/assets/lr-1.jpg')
Downloading data from https://raw.githubusercontent.com/tensorflow/examples/master/lite/examples/super_resolution/android/app/src/main/assets/lr-1.jpg
8192/6432 [======================================] - 0s 0us/step

สร้างภาพความละเอียดสูงโดยใช้ TensorFlow Lite

lr = tf.io.read_file(test_img_path)
lr = tf.image.decode_jpeg(lr)
lr = tf.expand_dims(lr, axis=0)
lr = tf.cast(lr, tf.float32)

# Load TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path=esrgan_model_path)
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Run the model
interpreter.set_tensor(input_details[0]['index'], lr)
interpreter.invoke()

# Extract the output and postprocess it
output_data = interpreter.get_tensor(output_details[0]['index'])
sr = tf.squeeze(output_data, axis=0)
sr = tf.clip_by_value(sr, 0, 255)
sr = tf.round(sr)
sr = tf.cast(sr, tf.uint8)

เห็นภาพผลลัพธ์

lr = tf.cast(tf.squeeze(lr, axis=0), tf.uint8)
plt.figure(figsize = (1, 1))
plt.title('LR')
plt.imshow(lr.numpy());

plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)        
plt.title(f'ESRGAN (x4)')
plt.imshow(sr.numpy());

bicubic = tf.image.resize(lr, [200, 200], tf.image.ResizeMethod.BICUBIC)
bicubic = tf.cast(bicubic, tf.uint8)
plt.subplot(1, 2, 2)   
plt.title('Bicubic')
plt.imshow(bicubic.numpy());

png

png

เกณฑ์มาตรฐานประสิทธิภาพ

หมายเลขมาตรฐานประสิทธิภาพได้รับการสร้างขึ้นด้วยเครื่องมือ อธิบายไว้ที่นี่

ชื่อรุ่น ขนาดรุ่น อุปกรณ์ ซีพียู GPU
ความละเอียดสูงสุด (ESRGAN) 4.8 Mb Pixel 3 586.8ms* 128.6ms
Pixel 4 385.1ms* 130.3ms

* 4 หัวข้อใช้