TensorFlow Lite の超解像技術

TensorFlow.org で表示 Google Colab で実行 GitHubでソースを表示 ノートブックをダウンロード TF Hub モデルを参照

概要

一般的に、低解像度の画像から高解像度 (HR) の画像を回復する作業は、Single Image Super Resolution (SISR) と呼ばれます。

ここで使用されるモデルは ESRGAN (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks) です。TensorFlow Lite を使用して、あらかじめトレーニングされたモデルで推論を実行します。

TFLite モデルは、TF ハブでホスティングされているこの実装から変換されます。変換されたモデルは、50 x 50 の低解像度画像を 200 x 200 (倍率 4 倍) の高解像度画像にアップサンプリングします。別の入力サイズまたは倍率を使用する場合は、元のモデルを再変換するか、再トレーニングする必要があります。

MNIST モデルをビルドする

まず、必要なライブラリをインストールします。

pip install matplotlib tensorflow tensorflow-hub

依存関係をインポートします。

import tensorflow as tf
import tensorflow_hub as hub
import matplotlib.pyplot as plt
print(tf.__version__)
2024-01-11 17:49:54.063947: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-01-11 17:49:54.063990: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-01-11 17:49:54.065444: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2.15.0

ESRGAN モデルをダウンロードして変換します。

model = hub.load("https://tfhub.dev/captain-pool/esrgan-tf2/1")
concrete_func = model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]

@tf.function(input_signature=[tf.TensorSpec(shape=[1, 50, 50, 3], dtype=tf.float32)])
def f(input):
  return concrete_func(input);

converter = tf.lite.TFLiteConverter.from_concrete_functions([f.get_concrete_function()], model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

# Save the TF Lite model.
with tf.io.gfile.GFile('ESRGAN.tflite', 'wb') as f:
  f.write(tflite_model)

esrgan_model_path = './ESRGAN.tflite'
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpo82w2j1z/assets
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpo82w2j1z/assets
2024-01-11 17:50:16.024291: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:378] Ignored output_format.
2024-01-11 17:50:16.024331: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:381] Ignored drop_control_dependency.
Summary on the non-converted ops:
---------------------------------

 * Accepted dialects: tfl, builtin, func
 * Non-Converted Ops: 176, Total Ops 845, % non-converted = 20.83 %
 * 176 ARITH ops

- arith.constant:  176 occurrences  (f32: 174, i32: 2)



  (f32: 45)
  (f32: 132)
  (f32: 169)
  (f32: 2)
  (f32: 136)
  (f32: 11)
  (uq_8: 169)
  (f32: 2)

テスト画像 (昆虫の頭) をダウンロードします。

test_img_path = tf.keras.utils.get_file('lr.jpg', 'https://raw.githubusercontent.com/tensorflow/examples/master/lite/examples/super_resolution/android/app/src/main/assets/lr-1.jpg')
Downloading data from https://raw.githubusercontent.com/tensorflow/examples/master/lite/examples/super_resolution/android/app/src/main/assets/lr-1.jpg
6432/6432 [==============================] - 0s 0us/step

TensorFlow Lite で超解像度画像を生成する

lr = tf.io.read_file(test_img_path)
lr = tf.image.decode_jpeg(lr)
lr = tf.expand_dims(lr, axis=0)
lr = tf.cast(lr, tf.float32)

# Load TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path=esrgan_model_path)
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Run the model
interpreter.set_tensor(input_details[0]['index'], lr)
interpreter.invoke()

# Extract the output and postprocess it
output_data = interpreter.get_tensor(output_details[0]['index'])
sr = tf.squeeze(output_data, axis=0)
sr = tf.clip_by_value(sr, 0, 255)
sr = tf.round(sr)
sr = tf.cast(sr, tf.uint8)
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.

結果を視覚化する

lr = tf.cast(tf.squeeze(lr, axis=0), tf.uint8)
plt.figure(figsize = (1, 1))
plt.title('LR')
plt.imshow(lr.numpy());

plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)        
plt.title(f'ESRGAN (x4)')
plt.imshow(sr.numpy());

bicubic = tf.image.resize(lr, [200, 200], tf.image.ResizeMethod.BICUBIC)
bicubic = tf.cast(bicubic, tf.uint8)
plt.subplot(1, 2, 2)   
plt.title('Bicubic')
plt.imshow(bicubic.numpy());

png

png

パフォーマンスベンチマーク

パフォーマンスベンチマークの数値は、ここで説明するツールで生成されます。

モデル名 モデルサイズ デバイス CPU GPU
super resolution (ESRGAN) 4.8 Mb Pixel 3 586.8ms* 128.6ms
Pixel 4 385.1ms* 130.3ms

*{nbsp}4 スレッド使用