TensorFlow Lite ile Sanatsal Stil Aktarımı

TensorFlow.org'da görüntüleyin Google Colab'da çalıştırın Kaynağı GitHub'da görüntüleyin Not defterini indir TF Hub modeline bakın

Geçenlerde çıkıp derin öğrenmede en heyecan verici gelişmelerden biri olan sanatsal stil transferi veya olarak bilinen yeni bir görüntü, oluşturma yeteneği pastiş içeriğini temsil eden bir sanatsal tarzını temsil eden ve biri: İki giriş görüntülerine dayanarak,.

Stil aktarımı örneği

Bu tekniği kullanarak, çeşitli tarzlarda güzel yeni sanat eserleri üretebiliriz.

Stil aktarımı örneği

TensorFlow Lite'ta yeniyseniz ve Android ile çalışıyorsanız, başlamanıza yardımcı olabilecek aşağıdaki örnek uygulamaları keşfetmenizi öneririz.

Android örneği iOS örneği

Android veya iOS dışında bir platformu kullanıyorsa veya zaten aşina değilseniz TensorFlow Lite API'ler , bir ön-eğitimli TensorFlow Lite ile içerik ve imgenin çifti tarzı transferini nasıl uygulanacağını öğrenmek için bu öğretici takip edebilirsiniz modeli. Modeli kendi mobil uygulamalarınıza stil aktarımı eklemek için kullanabilirsiniz.

Model açık kaynaklı üzerindedir GitHub'dan . Modeli farklı parametrelerle yeniden eğitebilirsiniz (örneğin, çıktı görüntüsünün içerik görüntüsüne daha çok benzemesi için içerik katmanlarının ağırlıklarını artırın).

Model mimarisini anlayın

Model Mimarisi

Bu Sanatsal Tarz Aktarımı modeli iki alt modelden oluşur:

  1. Stil Prediciton Model: A 100 ölçü stili darboğaz vektöre bir giriş tarzı görüntüsünü alır sinir ağı MobilenetV2 tabanlı.
  2. Stil Model Dönüşümü: Bir içerik görüntüye bir stil darboğaz vektör uygulamak alır ve stilize bir görüntü oluşturur Bir sinir ağı.

Uygulamanızın yalnızca sabit bir stil görüntüleri kümesini desteklemesi gerekiyorsa, stil darboğaz vektörlerini önceden hesaplayabilir ve Stil Tahmin Modelini uygulamanızın ikili programından hariç tutabilirsiniz.

Kurmak

Bağımlılıkları içe aktarın.

import tensorflow as tf
print(tf.__version__)
2.6.0
import IPython.display as display

import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['figure.figsize'] = (12,12)
mpl.rcParams['axes.grid'] = False

import numpy as np
import time
import functools

İçerik ve stil resimlerini ve önceden eğitilmiş TensorFlow Lite modellerini indirin.

content_path = tf.keras.utils.get_file('belfry.jpg','https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/belfry-2611573_1280.jpg')
style_path = tf.keras.utils.get_file('style23.jpg','https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/style23.jpg')

style_predict_path = tf.keras.utils.get_file('style_predict.tflite', 'https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/prediction/1?lite-format=tflite')
style_transform_path = tf.keras.utils.get_file('style_transform.tflite', 'https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/transfer/1?lite-format=tflite')
Downloading data from https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/belfry-2611573_1280.jpg
458752/458481 [==============================] - 0s 0us/step
466944/458481 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/style23.jpg
114688/108525 [===============================] - 0s 0us/step
122880/108525 [=================================] - 0s 0us/step
Downloading data from https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/prediction/1?lite-format=tflite
2834432/2828838 [==============================] - 0s 0us/step
2842624/2828838 [==============================] - 0s 0us/step
Downloading data from https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/transfer/1?lite-format=tflite
286720/284398 [==============================] - 0s 0us/step
294912/284398 [===============================] - 0s 0us/step

Girişleri ön işleme

  • İçerik görüntüsü ve stil görüntüsü, piksel değerleri [0..1] arasında float32 sayıları olan RGB görüntüleri olmalıdır.
  • Stil görüntü boyutu (1, 256, 256, 3) olmalıdır. Resmi ortalayarak kırpıyoruz ve yeniden boyutlandırıyoruz.
  • İçerik resmi (1, 384, 384, 3) olmalıdır. Resmi ortalayarak kırpıyoruz ve yeniden boyutlandırıyoruz.
# Function to load an image from a file, and add a batch dimension.
def load_img(path_to_img):
  img = tf.io.read_file(path_to_img)
  img = tf.io.decode_image(img, channels=3)
  img = tf.image.convert_image_dtype(img, tf.float32)
  img = img[tf.newaxis, :]

  return img

# Function to pre-process by resizing an central cropping it.
def preprocess_image(image, target_dim):
  # Resize the image so that the shorter dimension becomes 256px.
  shape = tf.cast(tf.shape(image)[1:-1], tf.float32)
  short_dim = min(shape)
  scale = target_dim / short_dim
  new_shape = tf.cast(shape * scale, tf.int32)
  image = tf.image.resize(image, new_shape)

  # Central crop the image.
  image = tf.image.resize_with_crop_or_pad(image, target_dim, target_dim)

  return image

# Load the input images.
content_image = load_img(content_path)
style_image = load_img(style_path)

# Preprocess the input images.
preprocessed_content_image = preprocess_image(content_image, 384)
preprocessed_style_image = preprocess_image(style_image, 256)

print('Style Image Shape:', preprocessed_style_image.shape)
print('Content Image Shape:', preprocessed_content_image.shape)
Style Image Shape: (1, 256, 256, 3)
Content Image Shape: (1, 384, 384, 3)

Girişleri görselleştirin

def imshow(image, title=None):
  if len(image.shape) > 3:
    image = tf.squeeze(image, axis=0)

  plt.imshow(image)
  if title:
    plt.title(title)

plt.subplot(1, 2, 1)
imshow(preprocessed_content_image, 'Content Image')

plt.subplot(1, 2, 2)
imshow(preprocessed_style_image, 'Style Image')

png

TensorFlow Lite ile stil aktarımını çalıştırın

Stil tahmini

# Function to run style prediction on preprocessed style image.
def run_style_predict(preprocessed_style_image):
  # Load the model.
  interpreter = tf.lite.Interpreter(model_path=style_predict_path)

  # Set model input.
  interpreter.allocate_tensors()
  input_details = interpreter.get_input_details()
  interpreter.set_tensor(input_details[0]["index"], preprocessed_style_image)

  # Calculate style bottleneck.
  interpreter.invoke()
  style_bottleneck = interpreter.tensor(
      interpreter.get_output_details()[0]["index"]
      )()

  return style_bottleneck

# Calculate style bottleneck for the preprocessed style image.
style_bottleneck = run_style_predict(preprocessed_style_image)
print('Style Bottleneck Shape:', style_bottleneck.shape)
Style Bottleneck Shape: (1, 1, 1, 100)

stil dönüşümü

# Run style transform on preprocessed style image
def run_style_transform(style_bottleneck, preprocessed_content_image):
  # Load the model.
  interpreter = tf.lite.Interpreter(model_path=style_transform_path)

  # Set model input.
  input_details = interpreter.get_input_details()
  interpreter.allocate_tensors()

  # Set model inputs.
  interpreter.set_tensor(input_details[0]["index"], preprocessed_content_image)
  interpreter.set_tensor(input_details[1]["index"], style_bottleneck)
  interpreter.invoke()

  # Transform content image.
  stylized_image = interpreter.tensor(
      interpreter.get_output_details()[0]["index"]
      )()

  return stylized_image

# Stylize the content image using the style bottleneck.
stylized_image = run_style_transform(style_bottleneck, preprocessed_content_image)

# Visualize the output.
imshow(stylized_image, 'Stylized Image')

png

Stil karıştırma

İçerik görüntüsünün stilini stilize edilmiş çıktıyla harmanlayabiliriz, bu da çıktının içerik görüntüsüne daha çok benzemesini sağlar.

# Calculate style bottleneck of the content image.
style_bottleneck_content = run_style_predict(
    preprocess_image(content_image, 256)
    )
# Define content blending ratio between [0..1].
# 0.0: 0% style extracts from content image.
# 1.0: 100% style extracted from content image.
content_blending_ratio = 0.5

# Blend the style bottleneck of style image and content image
style_bottleneck_blended = content_blending_ratio * style_bottleneck_content \
                           + (1 - content_blending_ratio) * style_bottleneck

# Stylize the content image using the style bottleneck.
stylized_image_blended = run_style_transform(style_bottleneck_blended,
                                             preprocessed_content_image)

# Visualize the output.
imshow(stylized_image_blended, 'Blended Stylized Image')

png

Performans Kıyaslamaları

Performans kriter numaraları aracıyla oluşturulan Burada anlatılan .

Model adı Model boyutu Cihaz NNAPI İşlemci GPU
Stil tahmin modeli (int8) 2,8 Mb Piksel 3 (Android 10) 142 ms 14ms
Piksel 4 (Android 10) 5.2 ms 6.7ms
iPhone XS (iOS 12.4.1) 10.7ms
Stil dönüştürme modeli (int8) 0,2 Mb Piksel 3 (Android 10) 540ms
Piksel 4 (Android 10) 405ms
iPhone XS (iOS 12.4.1) 251 ms
Stil tahmin modeli (float16) 4.7 Mb Piksel 3 (Android 10) 86ms 28ms 9.1 ms
Piksel 4 (Android 10) 32ms 12ms 10ms
Stil aktarım modeli (float16) 0,4 Mb Piksel 3 (Android 10) 1095ms 545ms 42ms
Piksel 4 (Android 10) 603 ms 377ms 42ms

* 4 iplik kullanılmıştır.
** En iyi performans için iPhone'da 2 konu.