질문이있다? TensorFlow 포럼 방문 포럼 에서 커뮤니티와 연결

TensorFlow Lite를 사용한 초 고해상도

TensorFlow.org에서보기 Google Colab에서 실행 GitHub에서 소스보기 노트북 다운로드 TF Hub 모델보기

개요

저해상도 대응 물에서 고해상도 (HR) 이미지를 복구하는 작업을 일반적으로 단일 이미지 초 고해상도 (SISR)라고합니다.

여기에 사용 된 모델은 ESRGAN ( ESRGAN : Enhanced Super-Resolution Generative Adversarial Networks )입니다. 그리고 TensorFlow Lite를 사용하여 사전 훈련 된 모델에 대한 추론을 실행할 것입니다.

TFLite 모델은 TF Hub에서 호스팅되는이 구현 에서 변환됩니다. 변환 한 모델은 50x50 저해상도 이미지를 200x200 고해상도 이미지로 업 샘플링합니다 (배율 계수 = 4). 다른 입력 크기 또는 축척 비율을 원하는 경우 원래 모델을 다시 변환하거나 다시 훈련해야합니다.

설정

먼저 필요한 라이브러리를 설치하겠습니다.

pip install -q matplotlib tensorflow tensorflow-hub

종속성을 가져옵니다.

import tensorflow as tf
import tensorflow_hub as hub
import matplotlib.pyplot as plt
print(tf.__version__)
2.5.0

ESRGAN 모델 다운로드 및 변환

model = hub.load("https://tfhub.dev/captain-pool/esrgan-tf2/1")
concrete_func = model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
concrete_func.inputs[0].set_shape([1, 50, 50, 3])
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

# Save the TF Lite model.
with tf.io.gfile.GFile('ESRGAN.tflite', 'wb') as f:
  f.write(tflite_model)

esrgan_model_path = './ESRGAN.tflite'

테스트 이미지 (곤충 머리)를 다운로드합니다.

test_img_path = tf.keras.utils.get_file('lr.jpg', 'https://raw.githubusercontent.com/tensorflow/examples/master/lite/examples/super_resolution/android/app/src/main/assets/lr-1.jpg')
Downloading data from https://raw.githubusercontent.com/tensorflow/examples/master/lite/examples/super_resolution/android/app/src/main/assets/lr-1.jpg
8192/6432 [======================================] - 0s 0us/step

TensorFlow Lite를 사용하여 초 고해상도 이미지 생성

lr = tf.io.read_file(test_img_path)
lr = tf.image.decode_jpeg(lr)
lr = tf.expand_dims(lr, axis=0)
lr = tf.cast(lr, tf.float32)

# Load TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path=esrgan_model_path)
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Run the model
interpreter.set_tensor(input_details[0]['index'], lr)
interpreter.invoke()

# Extract the output and postprocess it
output_data = interpreter.get_tensor(output_details[0]['index'])
sr = tf.squeeze(output_data, axis=0)
sr = tf.clip_by_value(sr, 0, 255)
sr = tf.round(sr)
sr = tf.cast(sr, tf.uint8)

결과 시각화

lr = tf.cast(tf.squeeze(lr, axis=0), tf.uint8)
plt.figure(figsize = (1, 1))
plt.title('LR')
plt.imshow(lr.numpy());

plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)        
plt.title(f'ESRGAN (x4)')
plt.imshow(sr.numpy());

bicubic = tf.image.resize(lr, [200, 200], tf.image.ResizeMethod.BICUBIC)
bicubic = tf.cast(bicubic, tf.uint8)
plt.subplot(1, 2, 2)   
plt.title('Bicubic')
plt.imshow(bicubic.numpy());
<matplotlib.image.AxesImage at 0x7fb46bd35210>

png

png

성능 벤치 마크

여기에 설명 된 도구를 사용하여 성능 벤치 마크 수치가 생성 됩니다 .

모델명 모델 사이즈 장치 CPU GPU
초 고해상도 (ESRGAN) 4.8Mb Pixel 3 586.8ms * 128.6ms
Pixel 4 385.1ms * 130.3ms

* 사용 된 스레드 4 개