Merken Sie den Termin vor! Google I / O kehrt vom 18. bis 20. Mai zurück Registrieren Sie sich jetzt
Diese Seite wurde von der Cloud Translation API übersetzt.
Switch to English

Super Auflösung mit TensorFlow Lite

Ansicht auf TensorFlow.org In Google Colab ausführen Quelle auf GitHub anzeigen Notizbuch herunterladen Siehe TF Hub-Modell

Überblick

Die Aufgabe, ein Bild mit hoher Auflösung (HR) von seinem Gegenstück mit niedriger Auflösung wiederherzustellen, wird üblicherweise als Einzelbild-Superauflösung (SISR) bezeichnet.

Das hier verwendete Modell ist ESRGAN ( ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks ). Und wir werden TensorFlow Lite verwenden, um Rückschlüsse auf das vorab trainierte Modell zu ziehen.

Das TFLite-Modell wird aus dieser auf TF Hub gehosteten Implementierung konvertiert. Beachten Sie, dass das Modell, das wir konvertiert haben, ein Bild mit niedriger Auflösung von 50 x 50 in ein Bild mit hoher Auflösung von 200 x 200 konvertiert (Skalierungsfaktor = 4). Wenn Sie eine andere Eingabegröße oder einen anderen Skalierungsfaktor wünschen, müssen Sie das ursprüngliche Modell neu konvertieren oder neu trainieren.

Einrichten

Lassen Sie uns zuerst die erforderlichen Bibliotheken installieren.

pip install -q matplotlib tensorflow tensorflow-hub

Abhängigkeiten importieren.

import tensorflow as tf
import tensorflow_hub as hub
import matplotlib.pyplot as plt
print(tf.__version__)
2.4.1

Laden Sie das ESRGAN-Modell herunter und konvertieren Sie es

model = hub.load("https://tfhub.dev/captain-pool/esrgan-tf2/1")
concrete_func = model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
concrete_func.inputs[0].set_shape([1, 50, 50, 3])
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

# Save the TF Lite model.
with tf.io.gfile.GFile('ESRGAN.tflite', 'wb') as f:
  f.write(tflite_model)

esrgan_model_path = './ESRGAN.tflite'

Laden Sie ein Testbild (Insektenkopf) herunter.

test_img_path = tf.keras.utils.get_file('lr.jpg', 'https://raw.githubusercontent.com/tensorflow/examples/master/lite/examples/super_resolution/android/app/src/main/assets/lr-1.jpg')
Downloading data from https://raw.githubusercontent.com/tensorflow/examples/master/lite/examples/super_resolution/android/app/src/main/assets/lr-1.jpg
8192/6432 [======================================] - 0s 0us/step

Erstellen Sie mit TensorFlow Lite ein Bild mit hoher Auflösung

lr = tf.io.read_file(test_img_path)
lr = tf.image.decode_jpeg(lr)
lr = tf.expand_dims(lr, axis=0)
lr = tf.cast(lr, tf.float32)

# Load TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path=esrgan_model_path)
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Run the model
interpreter.set_tensor(input_details[0]['index'], lr)
interpreter.invoke()

# Extract the output and postprocess it
output_data = interpreter.get_tensor(output_details[0]['index'])
sr = tf.squeeze(output_data, axis=0)
sr = tf.clip_by_value(sr, 0, 255)
sr = tf.round(sr)
sr = tf.cast(sr, tf.uint8)

Visualisieren Sie das Ergebnis

lr = tf.cast(tf.squeeze(lr, axis=0), tf.uint8)
plt.figure(figsize = (1, 1))
plt.title('LR')
plt.imshow(lr.numpy());

plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)        
plt.title(f'ESRGAN (x4)')
plt.imshow(sr.numpy());

bicubic = tf.image.resize(lr, [200, 200], tf.image.ResizeMethod.BICUBIC)
bicubic = tf.cast(bicubic, tf.uint8)
plt.subplot(1, 2, 2)   
plt.title('Bicubic')
plt.imshow(bicubic.numpy());
<matplotlib.image.AxesImage at 0x7f99dad41588>

png

png

Leistungsbenchmarks

Leistungsbenchmarkzahlen werden mit dem hier beschriebenen Tool generiert.

Modellname Modellgröße Gerät Zentralprozessor GPU
Superauflösung (ESRGAN) 4,8 Mb Pixel 3 586,8 ms * 128,6 ms
Pixel 4 385,1 ms * 130,3 ms

* 4 Threads verwendet