Bantuan melindungi Great Barrier Reef dengan TensorFlow pada Kaggle Bergabung Tantangan

Resolusi super dengan TensorFlow Lite

Lihat di TensorFlow.org Jalankan di Google Colab Lihat sumber di GitHub Unduh buku catatan Lihat model Hub TF

Gambaran

Tugas memulihkan gambar resolusi tinggi (HR) dari rekan resolusi rendahnya biasanya disebut sebagai Resolusi Super Gambar Tunggal (SISR).

Model yang digunakan di sini adalah ESRGAN ( ESRGAN: Peningkatan Super Resolution generatif Adversarial Networks ). Dan kita akan menggunakan TensorFlow Lite untuk menjalankan inferensi pada model yang telah dilatih sebelumnya.

Model TFLite dikonversi dari ini implementasi host di TF Hub. Perhatikan bahwa model yang kami konversi meningkatkan sampel gambar resolusi rendah 50x50 menjadi gambar resolusi tinggi 200x200 (faktor skala=4). Jika Anda menginginkan ukuran input atau faktor skala yang berbeda, Anda perlu mengonversi ulang atau melatih ulang model aslinya.

Mempersiapkan

Mari kita instal perpustakaan yang diperlukan terlebih dahulu.

pip install matplotlib tensorflow tensorflow-hub

Impor dependensi.

import tensorflow as tf
import tensorflow_hub as hub
import matplotlib.pyplot as plt
print(tf.__version__)
2.7.0

Unduh dan ubah model ESRGAN

model = hub.load("https://tfhub.dev/captain-pool/esrgan-tf2/1")
concrete_func = model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]

@tf.function(input_signature=[tf.TensorSpec(shape=[1, 50, 50, 3], dtype=tf.float32)])
def f(input):
  return concrete_func(input);

converter = tf.lite.TFLiteConverter.from_concrete_functions([f.get_concrete_function()], model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

# Save the TF Lite model.
with tf.io.gfile.GFile('ESRGAN.tflite', 'wb') as f:
  f.write(tflite_model)

esrgan_model_path = './ESRGAN.tflite'
WARNING:absl:Found untraced functions such as restored_function_body, restored_function_body, restored_function_body, restored_function_body, restored_function_body while saving (showing 5 of 335). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/tmpinlbbz0t/assets
INFO:tensorflow:Assets written to: /tmp/tmpinlbbz0t/assets
2021-11-16 12:15:19.621471: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:363] Ignored output_format.
2021-11-16 12:15:19.621517: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:366] Ignored drop_control_dependency.
WARNING:absl:Buffer deduplication procedure will be skipped when flatbuffer library is not properly loaded

Unduh gambar uji (kepala serangga).

test_img_path = tf.keras.utils.get_file('lr.jpg', 'https://raw.githubusercontent.com/tensorflow/examples/master/lite/examples/super_resolution/android/app/src/main/assets/lr-1.jpg')
Downloading data from https://raw.githubusercontent.com/tensorflow/examples/master/lite/examples/super_resolution/android/app/src/main/assets/lr-1.jpg
16384/6432 [============================================================================] - 0s 0us/step

Hasilkan gambar resolusi super menggunakan TensorFlow Lite

lr = tf.io.read_file(test_img_path)
lr = tf.image.decode_jpeg(lr)
lr = tf.expand_dims(lr, axis=0)
lr = tf.cast(lr, tf.float32)

# Load TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path=esrgan_model_path)
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Run the model
interpreter.set_tensor(input_details[0]['index'], lr)
interpreter.invoke()

# Extract the output and postprocess it
output_data = interpreter.get_tensor(output_details[0]['index'])
sr = tf.squeeze(output_data, axis=0)
sr = tf.clip_by_value(sr, 0, 255)
sr = tf.round(sr)
sr = tf.cast(sr, tf.uint8)

Visualisasikan hasilnya

lr = tf.cast(tf.squeeze(lr, axis=0), tf.uint8)
plt.figure(figsize = (1, 1))
plt.title('LR')
plt.imshow(lr.numpy());

plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)        
plt.title(f'ESRGAN (x4)')
plt.imshow(sr.numpy());

bicubic = tf.image.resize(lr, [200, 200], tf.image.ResizeMethod.BICUBIC)
bicubic = tf.cast(bicubic, tf.uint8)
plt.subplot(1, 2, 2)   
plt.title('Bicubic')
plt.imshow(bicubic.numpy());

png

png

Tolok Ukur Kinerja

Nomor tolok ukur kinerja yang dihasilkan dengan alat yang dijelaskan di sini .

Nama model Ukuran Model Perangkat CPU GPU
resolusi super (ESRGAN) 4,8 Mb Piksel 3 586.8ms* 128,6 ms
Piksel 4 385.1ms* 130.3ms

* 4 benang yang digunakan