Độ phân giải siêu cao với TensorFlow Lite

Sử dụng bộ sưu tập để sắp xếp ngăn nắp các trang Lưu và phân loại nội dung dựa trên lựa chọn ưu tiên của bạn.

Xem trên TensorFlow.org Chạy trong Google Colab Xem nguồn trên GitHub Tải xuống sổ ghi chép Xem mô hình TF Hub

Tổng quat

Nhiệm vụ khôi phục hình ảnh có độ phân giải cao (HR) từ đối tác có độ phân giải thấp thường được gọi là Ảnh siêu phân giải đơn (SISR).

Mô hình sử dụng ở đây là ESRGAN ( ESRGAN: Enhanced Super-Resolution Generative gây tranh cãi Networks ). Và chúng tôi sẽ sử dụng TensorFlow Lite để chạy suy luận trên mô hình được đào tạo trước.

Mô hình TFLite được chuyển đổi từ này thực hiện được lưu trữ trên TF Hub. Lưu ý rằng mô hình chúng tôi đã chuyển đổi ví dụ hình ảnh có độ phân giải thấp 50x50 thành hình ảnh có độ phân giải cao 200x200 (hệ số tỷ lệ = 4). Nếu bạn muốn kích thước đầu vào hoặc hệ số tỷ lệ khác, bạn cần chuyển đổi lại hoặc đào tạo lại mô hình ban đầu.

Thành lập

Trước tiên, hãy cài đặt các thư viện bắt buộc.

pip install matplotlib tensorflow tensorflow-hub

Nhập phụ thuộc.

import tensorflow as tf
import tensorflow_hub as hub
import matplotlib.pyplot as plt
print(tf.__version__)
2.7.0

Tải xuống và chuyển đổi mô hình ESRGAN

model = hub.load("https://tfhub.dev/captain-pool/esrgan-tf2/1")
concrete_func = model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]

@tf.function(input_signature=[tf.TensorSpec(shape=[1, 50, 50, 3], dtype=tf.float32)])
def f(input):
  return concrete_func(input);

converter = tf.lite.TFLiteConverter.from_concrete_functions([f.get_concrete_function()], model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

# Save the TF Lite model.
with tf.io.gfile.GFile('ESRGAN.tflite', 'wb') as f:
  f.write(tflite_model)

esrgan_model_path = './ESRGAN.tflite'
WARNING:absl:Found untraced functions such as restored_function_body, restored_function_body, restored_function_body, restored_function_body, restored_function_body while saving (showing 5 of 335). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/tmpinlbbz0t/assets
INFO:tensorflow:Assets written to: /tmp/tmpinlbbz0t/assets
2021-11-16 12:15:19.621471: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:363] Ignored output_format.
2021-11-16 12:15:19.621517: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:366] Ignored drop_control_dependency.
WARNING:absl:Buffer deduplication procedure will be skipped when flatbuffer library is not properly loaded

Tải xuống hình ảnh thử nghiệm (đầu côn trùng).

test_img_path = tf.keras.utils.get_file('lr.jpg', 'https://raw.githubusercontent.com/tensorflow/examples/master/lite/examples/super_resolution/android/app/src/main/assets/lr-1.jpg')
Downloading data from https://raw.githubusercontent.com/tensorflow/examples/master/lite/examples/super_resolution/android/app/src/main/assets/lr-1.jpg
16384/6432 [============================================================================] - 0s 0us/step

Tạo hình ảnh có độ phân giải siêu cao bằng TensorFlow Lite

lr = tf.io.read_file(test_img_path)
lr = tf.image.decode_jpeg(lr)
lr = tf.expand_dims(lr, axis=0)
lr = tf.cast(lr, tf.float32)

# Load TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path=esrgan_model_path)
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Run the model
interpreter.set_tensor(input_details[0]['index'], lr)
interpreter.invoke()

# Extract the output and postprocess it
output_data = interpreter.get_tensor(output_details[0]['index'])
sr = tf.squeeze(output_data, axis=0)
sr = tf.clip_by_value(sr, 0, 255)
sr = tf.round(sr)
sr = tf.cast(sr, tf.uint8)

Hình dung kết quả

lr = tf.cast(tf.squeeze(lr, axis=0), tf.uint8)
plt.figure(figsize = (1, 1))
plt.title('LR')
plt.imshow(lr.numpy());

plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)        
plt.title(f'ESRGAN (x4)')
plt.imshow(sr.numpy());

bicubic = tf.image.resize(lr, [200, 200], tf.image.ResizeMethod.BICUBIC)
bicubic = tf.cast(bicubic, tf.uint8)
plt.subplot(1, 2, 2)   
plt.title('Bicubic')
plt.imshow(bicubic.numpy());

png

png

Điểm chuẩn hiệu suất

Số benchmark hiệu năng được tạo ra với công cụ mô tả ở đây .

Tên Model Kích thước mô hình Thiết bị CPU GPU
siêu phân giải (ESRGAN) 4,8 Mb Pixel 3 586,8ms * 128,6ms
Pixel 4 385.1ms * 130.3ms

* 4 đề sử dụng