Künstlerische Stilübertragung mit TensorFlow Lite

Auf TensorFlow.org ansehen In Google Colab ausführen Quelle auf GitHub anzeigen Notizbuch herunterladen Siehe TF Hub-Modell

Eine der interessantesten Entwicklungen in der tiefen Lernen ist vor kurzem zu kommen künstlerischen Stil Transfer oder die Möglichkeit , ein neues Bild zu erstellen, als bekannt Pastiche , basierend auf zwei Eingabebildern: eine , die den künstlerischen Stil und eine den Inhalt darstellt.

Beispiel für eine Stilübertragung

Mit dieser Technik können wir wunderschöne neue Kunstwerke in einer Reihe von Stilen erstellen.

Beispiel für eine Stilübertragung

Wenn Sie neu bei TensorFlow Lite sind und mit Android arbeiten, empfehlen wir Ihnen, sich die folgenden Beispielanwendungen anzusehen, die Ihnen den Einstieg erleichtern können.

Android Beispiel iOS Beispiel

Wenn Sie eine andere Plattform als Android oder iOS verwenden, oder Sie sind bereits vertraut mit der TensorFlow Lite - APIs können Sie dieses Tutorial folgen zu lernen , wie man mit einem vortrainiert TensorFlow Lite Stil Übertragung auf jedes Paar von Inhalt und Stil Bild anwenden Modell. Sie können das Modell verwenden, um Stilübertragungen zu Ihren eigenen mobilen Anwendungen hinzuzufügen.

Das Modell ist als Open-Source auf GitHub . Sie können das Modell mit verschiedenen Parametern neu trainieren (z. B. die Gewichtung der Inhaltsebenen erhöhen, damit das Ausgabebild dem Inhaltsbild ähnlicher wird).

Verstehen Sie die Modellarchitektur

Modellarchitektur

Dieses Artistic Style Transfer-Modell besteht aus zwei Untermodellen:

  1. Style Prediciton Model: A MobilenetV2 basierte neuronale Netzwerk , das ein Eingangsbild Stil zu einem 100-Bemaßungsstils Engpaß Vektor nimmt.
  2. Style - Transform Modell: Ein neuronales Netzwerk , das einen Stil Engpass Vektor zu einem Inhalts Bild nimmt anwenden und erzeugt ein stilisiertes Bild.

Wenn Ihre App nur einen festen Satz von Stilbildern unterstützen muss, können Sie deren Stilengpassvektoren im Voraus berechnen und das Stilvorhersagemodell aus der Binärdatei Ihrer App ausschließen.

Installieren

Abhängigkeiten importieren.

import tensorflow as tf
print(tf.__version__)
2.6.0
import IPython.display as display

import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['figure.figsize'] = (12,12)
mpl.rcParams['axes.grid'] = False

import numpy as np
import time
import functools

Laden Sie die Inhalts- und Stilbilder sowie die vortrainierten TensorFlow Lite-Modelle herunter.

content_path = tf.keras.utils.get_file('belfry.jpg','https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/belfry-2611573_1280.jpg')
style_path = tf.keras.utils.get_file('style23.jpg','https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/style23.jpg')

style_predict_path = tf.keras.utils.get_file('style_predict.tflite', 'https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/prediction/1?lite-format=tflite')
style_transform_path = tf.keras.utils.get_file('style_transform.tflite', 'https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/transfer/1?lite-format=tflite')
Downloading data from https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/belfry-2611573_1280.jpg
458752/458481 [==============================] - 0s 0us/step
466944/458481 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/style23.jpg
114688/108525 [===============================] - 0s 0us/step
122880/108525 [=================================] - 0s 0us/step
Downloading data from https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/prediction/1?lite-format=tflite
2834432/2828838 [==============================] - 0s 0us/step
2842624/2828838 [==============================] - 0s 0us/step
Downloading data from https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/transfer/1?lite-format=tflite
286720/284398 [==============================] - 0s 0us/step
294912/284398 [===============================] - 0s 0us/step

Die Eingaben vorverarbeiten

  • Das Inhaltsbild und das Stilbild müssen RGB-Bilder sein, deren Pixelwerte float32-Zahlen zwischen [0..1] sind.
  • Die Stilbildgröße muss (1, 256, 256, 3) sein. Wir beschneiden das Bild zentral und ändern die Größe.
  • Das Inhaltsbild muss (1, 384, 384, 3) sein. Wir beschneiden das Bild zentral und ändern die Größe.
# Function to load an image from a file, and add a batch dimension.
def load_img(path_to_img):
  img = tf.io.read_file(path_to_img)
  img = tf.io.decode_image(img, channels=3)
  img = tf.image.convert_image_dtype(img, tf.float32)
  img = img[tf.newaxis, :]

  return img

# Function to pre-process by resizing an central cropping it.
def preprocess_image(image, target_dim):
  # Resize the image so that the shorter dimension becomes 256px.
  shape = tf.cast(tf.shape(image)[1:-1], tf.float32)
  short_dim = min(shape)
  scale = target_dim / short_dim
  new_shape = tf.cast(shape * scale, tf.int32)
  image = tf.image.resize(image, new_shape)

  # Central crop the image.
  image = tf.image.resize_with_crop_or_pad(image, target_dim, target_dim)

  return image

# Load the input images.
content_image = load_img(content_path)
style_image = load_img(style_path)

# Preprocess the input images.
preprocessed_content_image = preprocess_image(content_image, 384)
preprocessed_style_image = preprocess_image(style_image, 256)

print('Style Image Shape:', preprocessed_style_image.shape)
print('Content Image Shape:', preprocessed_content_image.shape)
Style Image Shape: (1, 256, 256, 3)
Content Image Shape: (1, 384, 384, 3)
2021-08-12 11:20:05.704558: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-08-12 11:20:05.712948: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-08-12 11:20:05.713842: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-08-12 11:20:05.715823: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2021-08-12 11:20:05.716431: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-08-12 11:20:05.717356: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-08-12 11:20:05.718291: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-08-12 11:20:06.305180: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-08-12 11:20:06.306111: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-08-12 11:20:06.306955: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-08-12 11:20:06.307815: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1510] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 14648 MB memory:  -> device: 0, name: Tesla V100-SXM2-16GB, pci bus id: 0000:00:05.0, compute capability: 7.0

Visualisieren Sie die Eingaben

def imshow(image, title=None):
  if len(image.shape) > 3:
    image = tf.squeeze(image, axis=0)

  plt.imshow(image)
  if title:
    plt.title(title)

plt.subplot(1, 2, 1)
imshow(preprocessed_content_image, 'Content Image')

plt.subplot(1, 2, 2)
imshow(preprocessed_style_image, 'Style Image')

png

Führen Sie die Stilübertragung mit TensorFlow Lite durch

Stilvorhersage

# Function to run style prediction on preprocessed style image.
def run_style_predict(preprocessed_style_image):
  # Load the model.
  interpreter = tf.lite.Interpreter(model_path=style_predict_path)

  # Set model input.
  interpreter.allocate_tensors()
  input_details = interpreter.get_input_details()
  interpreter.set_tensor(input_details[0]["index"], preprocessed_style_image)

  # Calculate style bottleneck.
  interpreter.invoke()
  style_bottleneck = interpreter.tensor(
      interpreter.get_output_details()[0]["index"]
      )()

  return style_bottleneck

# Calculate style bottleneck for the preprocessed style image.
style_bottleneck = run_style_predict(preprocessed_style_image)
print('Style Bottleneck Shape:', style_bottleneck.shape)
Style Bottleneck Shape: (1, 1, 1, 100)

Stiltransformation

# Run style transform on preprocessed style image
def run_style_transform(style_bottleneck, preprocessed_content_image):
  # Load the model.
  interpreter = tf.lite.Interpreter(model_path=style_transform_path)

  # Set model input.
  input_details = interpreter.get_input_details()
  interpreter.allocate_tensors()

  # Set model inputs.
  interpreter.set_tensor(input_details[0]["index"], preprocessed_content_image)
  interpreter.set_tensor(input_details[1]["index"], style_bottleneck)
  interpreter.invoke()

  # Transform content image.
  stylized_image = interpreter.tensor(
      interpreter.get_output_details()[0]["index"]
      )()

  return stylized_image

# Stylize the content image using the style bottleneck.
stylized_image = run_style_transform(style_bottleneck, preprocessed_content_image)

# Visualize the output.
imshow(stylized_image, 'Stylized Image')

png

Stilmischung

Wir können den Stil des Inhaltsbilds in die stilisierte Ausgabe einfügen, wodurch die Ausgabe wiederum dem Inhaltsbild ähnelt.

# Calculate style bottleneck of the content image.
style_bottleneck_content = run_style_predict(
    preprocess_image(content_image, 256)
    )
# Define content blending ratio between [0..1].
# 0.0: 0% style extracts from content image.
# 1.0: 100% style extracted from content image.
content_blending_ratio = 0.5

# Blend the style bottleneck of style image and content image
style_bottleneck_blended = content_blending_ratio * style_bottleneck_content \

                           + (1 - content_blending_ratio) * style_bottleneck

# Stylize the content image using the style bottleneck.
stylized_image_blended = run_style_transform(style_bottleneck_blended,
                                             preprocessed_content_image)

# Visualize the output.
imshow(stylized_image_blended, 'Blended Stylized Image')

png

Leistungsbenchmarks

Performance - Benchmark - Nummern werden mit dem Werkzeug erzeugt hier beschrieben .

Modellname Modellgröße Gerät NNAPI Zentralprozessor GPU
Modell zur Stilvorhersage (int8) 2,8 MB Pixel 3 (Android 10) 142 ms 14ms
Pixel 4 (Android 10) 5,2 ms 6,7 ms
iPhone XS (iOS 12.4.1) 10,7 ms
Stiltransformationsmodell (int8) 0,2 MB Pixel 3 (Android 10) 540ms
Pixel 4 (Android 10) 405ms
iPhone XS (iOS 12.4.1) 251ms
Modell zur Stilvorhersage (float16) 4,7 MB Pixel 3 (Android 10) 86ms 28ms 9,1 ms
Pixel 4 (Android 10) 32ms 12ms 10ms
Stilübertragungsmodell (float16) 0,4 MB Pixel 3 (Android 10) 1095ms 545ms 42ms
Pixel 4 (Android 10) 603 ms 377ms 42ms

* 4 Threads verwendet.
** 2 Threads auf dem iPhone für die beste Leistung.