Trasferimento di stile artistico con TensorFlow Lite

Visualizza su TensorFlow.org Esegui in Google Colab Visualizza la fonte su GitHub Scarica taccuino Vedi il modello del mozzo TF

Uno dei più interessanti sviluppi in profondità di apprendimento di uscire di recente è il trasferimento artistico stile , o la possibilità di creare una nuova immagine, conosciuto come un pastiche , sulla base di due immagini in ingresso: uno che rappresenta lo stile artistico e uno che rappresenta il contenuto.

Esempio di trasferimento di stile

Usando questa tecnica, possiamo generare bellissime nuove opere d'arte in una gamma di stili.

Esempio di trasferimento di stile

Se non conosci TensorFlow Lite e stai lavorando con Android, ti consigliamo di esplorare le seguenti applicazioni di esempio che possono aiutarti a iniziare.

Esempio Android esempio iOS

Se si utilizza una piattaforma diversa da Android o iOS, o si ha già familiarità con i Lite API tensorflow , è possibile seguire questo tutorial per imparare ad applicare il trasferimento stile su una qualsiasi coppia di stile di immagine contenuti e con un pre-addestrato tensorflow Lite modello. Puoi utilizzare il modello per aggiungere il trasferimento di stile alle tue applicazioni mobili.

Il modello è open-source su GitHub . È possibile riqualificare il modello con parametri diversi (ad esempio, aumentare i pesi dei livelli di contenuto per rendere l'immagine di output più simile all'immagine di contenuto).

Comprendere l'architettura del modello

Architettura del modello

Questo modello Artistic Style Transfer è costituito da due sottomodelli:

  1. Stile Prediciton Modello: A-MobilenetV2 basa rete neurale che prende un'immagine stile di input ad un 100-dimensione stile collo di bottiglia vettore.
  2. Stile Transform Modello: Una rete neurale che prende applicare un vettore stile collo di bottiglia a un'immagine contenuti e crea un'immagine stilizzata.

Se la tua app deve supportare solo un set fisso di immagini di stile, puoi calcolare in anticipo i relativi vettori del collo di bottiglia dello stile ed escludere il modello di previsione dello stile dal file binario della tua app.

Impostare

Importa le dipendenze.

import tensorflow as tf
print(tf.__version__)
2.6.0
import IPython.display as display

import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['figure.figsize'] = (12,12)
mpl.rcParams['axes.grid'] = False

import numpy as np
import time
import functools

Scarica il contenuto e le immagini di stile e i modelli TensorFlow Lite pre-addestrati.

content_path = tf.keras.utils.get_file('belfry.jpg','https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/belfry-2611573_1280.jpg')
style_path = tf.keras.utils.get_file('style23.jpg','https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/style23.jpg')

style_predict_path = tf.keras.utils.get_file('style_predict.tflite', 'https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/prediction/1?lite-format=tflite')
style_transform_path = tf.keras.utils.get_file('style_transform.tflite', 'https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/transfer/1?lite-format=tflite')
Downloading data from https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/belfry-2611573_1280.jpg
458752/458481 [==============================] - 0s 0us/step
466944/458481 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/style23.jpg
114688/108525 [===============================] - 0s 0us/step
122880/108525 [=================================] - 0s 0us/step
Downloading data from https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/prediction/1?lite-format=tflite
2834432/2828838 [==============================] - 0s 0us/step
2842624/2828838 [==============================] - 0s 0us/step
Downloading data from https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/transfer/1?lite-format=tflite
286720/284398 [==============================] - 0s 0us/step
294912/284398 [===============================] - 0s 0us/step

Pre-elaborare gli input

  • L'immagine del contenuto e l'immagine dello stile devono essere immagini RGB con valori di pixel che sono numeri float32 compresi tra [0..1].
  • La dimensione dell'immagine dello stile deve essere (1, 256, 256, 3). Ritagliamo centralmente l'immagine e la ridimensioniamo.
  • L'immagine del contenuto deve essere (1, 384, 384, 3). Ritagliamo centralmente l'immagine e la ridimensioniamo.
# Function to load an image from a file, and add a batch dimension.
def load_img(path_to_img):
  img = tf.io.read_file(path_to_img)
  img = tf.io.decode_image(img, channels=3)
  img = tf.image.convert_image_dtype(img, tf.float32)
  img = img[tf.newaxis, :]

  return img

# Function to pre-process by resizing an central cropping it.
def preprocess_image(image, target_dim):
  # Resize the image so that the shorter dimension becomes 256px.
  shape = tf.cast(tf.shape(image)[1:-1], tf.float32)
  short_dim = min(shape)
  scale = target_dim / short_dim
  new_shape = tf.cast(shape * scale, tf.int32)
  image = tf.image.resize(image, new_shape)

  # Central crop the image.
  image = tf.image.resize_with_crop_or_pad(image, target_dim, target_dim)

  return image

# Load the input images.
content_image = load_img(content_path)
style_image = load_img(style_path)

# Preprocess the input images.
preprocessed_content_image = preprocess_image(content_image, 384)
preprocessed_style_image = preprocess_image(style_image, 256)

print('Style Image Shape:', preprocessed_style_image.shape)
print('Content Image Shape:', preprocessed_content_image.shape)
Style Image Shape: (1, 256, 256, 3)
Content Image Shape: (1, 384, 384, 3)

Visualizza gli ingressi

def imshow(image, title=None):
  if len(image.shape) > 3:
    image = tf.squeeze(image, axis=0)

  plt.imshow(image)
  if title:
    plt.title(title)

plt.subplot(1, 2, 1)
imshow(preprocessed_content_image, 'Content Image')

plt.subplot(1, 2, 2)
imshow(preprocessed_style_image, 'Style Image')

png

Esegui il trasferimento dello stile con TensorFlow Lite

Previsione dello stile

# Function to run style prediction on preprocessed style image.
def run_style_predict(preprocessed_style_image):
  # Load the model.
  interpreter = tf.lite.Interpreter(model_path=style_predict_path)

  # Set model input.
  interpreter.allocate_tensors()
  input_details = interpreter.get_input_details()
  interpreter.set_tensor(input_details[0]["index"], preprocessed_style_image)

  # Calculate style bottleneck.
  interpreter.invoke()
  style_bottleneck = interpreter.tensor(
      interpreter.get_output_details()[0]["index"]
      )()

  return style_bottleneck

# Calculate style bottleneck for the preprocessed style image.
style_bottleneck = run_style_predict(preprocessed_style_image)
print('Style Bottleneck Shape:', style_bottleneck.shape)
Style Bottleneck Shape: (1, 1, 1, 100)

Trasformazione dello stile

# Run style transform on preprocessed style image
def run_style_transform(style_bottleneck, preprocessed_content_image):
  # Load the model.
  interpreter = tf.lite.Interpreter(model_path=style_transform_path)

  # Set model input.
  input_details = interpreter.get_input_details()
  interpreter.allocate_tensors()

  # Set model inputs.
  interpreter.set_tensor(input_details[0]["index"], preprocessed_content_image)
  interpreter.set_tensor(input_details[1]["index"], style_bottleneck)
  interpreter.invoke()

  # Transform content image.
  stylized_image = interpreter.tensor(
      interpreter.get_output_details()[0]["index"]
      )()

  return stylized_image

# Stylize the content image using the style bottleneck.
stylized_image = run_style_transform(style_bottleneck, preprocessed_content_image)

# Visualize the output.
imshow(stylized_image, 'Stylized Image')

png

Miscelazione di stile

Possiamo fondere lo stile dell'immagine del contenuto nell'output stilizzato, che a sua volta rende l'output più simile all'immagine del contenuto.

# Calculate style bottleneck of the content image.
style_bottleneck_content = run_style_predict(
    preprocess_image(content_image, 256)
    )
# Define content blending ratio between [0..1].
# 0.0: 0% style extracts from content image.
# 1.0: 100% style extracted from content image.
content_blending_ratio = 0.5

# Blend the style bottleneck of style image and content image
style_bottleneck_blended = content_blending_ratio * style_bottleneck_content \
                           + (1 - content_blending_ratio) * style_bottleneck

# Stylize the content image using the style bottleneck.
stylized_image_blended = run_style_transform(style_bottleneck_blended,
                                             preprocessed_content_image)

# Visualize the output.
imshow(stylized_image_blended, 'Blended Stylized Image')

png

Benchmark delle prestazioni

I numeri di riferimento delle prestazioni sono generati con lo strumento qui descritto .

Nome del modello Dimensioni del modello Dispositivo NNAPI processore GPU
Modello di previsione dello stile (int8) 2,8 Mb Pixel 3 (Android 10) 142 ms 14 ms
Pixel 4 (Android 10) 5.2ms 6,7 ms
iPhone XS (iOS 12.4.1) 10,7 ms
Modello di trasformazione dello stile (int8) 0.2 Mb Pixel 3 (Android 10) 540 ms
Pixel 4 (Android 10) 405 ms
iPhone XS (iOS 12.4.1) 251 ms
Modello di previsione dello stile (float16) 4.7 Mb Pixel 3 (Android 10) 86 ms 28 ms 9.1ms
Pixel 4 (Android 10) 32 ms 12 ms 10ms
Modello di trasferimento dello stile (float16) 0,4 Mb Pixel 3 (Android 10) 1095 ms 545 ms 42 ms
Pixel 4 (Android 10) 603 ms 377 ms 42 ms

* 4 fili utilizzati.
** 2 thread su iPhone per le migliori prestazioni.