Tarihi kaydet! Google I / O 18-20 Mayıs'ta geri dönüyor Şimdi kaydolun
Bu sayfa, Cloud Translation API ile çevrilmiştir.
Switch to English

TensorFlow Lite ile Sanatsal Stil Transferi

TensorFlow.org'da görüntüleyin Google Colab'da çalıştırın Kaynağı GitHub'da görüntüleyin Defteri indirin TF Hub modeline bakın

Derin öğrenmede son zamanlarda ortaya çıkan en heyecan verici gelişmelerden biri, sanatsal stil aktarımı veya pastiş olarak bilinen, biri sanatsal stili temsil eden diğeri içeriği temsil eden iki girdi resmine dayalı yeni bir imaj yaratma yeteneğidir.

Stil aktarımı örneği

Bu tekniği kullanarak çeşitli tarzlarda güzel yeni sanat eserleri üretebiliriz.

Stil aktarımı örneği

TensorFlow Lite'ta yeniyseniz ve Android ile çalışıyorsanız, başlamanıza yardımcı olabilecek aşağıdaki örnek uygulamaları incelemenizi öneririz.

Android örneği iOS örneği

Android veya iOS dışında bir platform kullanıyorsanız veya TensorFlow Lite API'lerine aşina iseniz, önceden eğitilmiş bir TensorFlow Lite ile herhangi bir içerik çiftine ve stil görüntüsüne stil aktarımını nasıl uygulayacağınızı öğrenmek için bu öğreticiyi takip edebilirsiniz. model. Modeli kendi mobil uygulamalarınıza stil aktarımı eklemek için kullanabilirsiniz.

Model, GitHub'da açık kaynaklı. Modeli farklı parametrelerle yeniden eğitebilirsiniz (örneğin, çıktı görüntüsünün içerik görüntüsüne daha çok benzemesi için içerik katmanlarının ağırlıklarını artırın).

Model mimarisini anlayın

Model Mimarisi

Bu Sanatsal Stil Transferi modeli iki alt modelden oluşur:

  1. Stil Öngörme Modeli : Giriş stili görüntüyü 100 boyutlu bir stil darboğaz vektörüne götüren MobilenetV2 tabanlı bir sinir ağı.
  2. Stil Dönüşümü Modeli : Bir içerik görüntüsüne stil darboğaz vektörü uygulayan ve stilize edilmiş bir görüntü oluşturan bir sinir ağı.

Uygulamanızın yalnızca sabit bir stil görüntülerini desteklemesi gerekiyorsa, bunların stil darboğaz vektörlerini önceden hesaplayabilir ve Stil Tahmin Modelini uygulamanızın ikili dosyasından hariç tutabilirsiniz.

Kurulum

Bağımlılıkları içe aktarın.

import tensorflow as tf
print(tf.__version__)
2.4.1
import IPython.display as display

import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['figure.figsize'] = (12,12)
mpl.rcParams['axes.grid'] = False

import numpy as np
import time
import functools

İçerik ve stil görsellerini ve önceden eğitilmiş TensorFlow Lite modellerini indirin.

content_path = tf.keras.utils.get_file('belfry.jpg','https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/belfry-2611573_1280.jpg')
style_path = tf.keras.utils.get_file('style23.jpg','https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/style23.jpg')

style_predict_path = tf.keras.utils.get_file('style_predict.tflite', 'https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/prediction/1?lite-format=tflite')
style_transform_path = tf.keras.utils.get_file('style_transform.tflite', 'https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/transfer/1?lite-format=tflite')
Downloading data from https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/belfry-2611573_1280.jpg
458752/458481 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/style23.jpg
114688/108525 [===============================] - 0s 0us/step
Downloading data from https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/prediction/1?lite-format=tflite
2834432/2828838 [==============================] - 0s 0us/step
Downloading data from https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/transfer/1?lite-format=tflite
286720/284398 [==============================] - 0s 0us/step

Girişleri önceden işleyin

  • İçerik resmi ve stil resmi, piksel değerleri [0..1] arasında float32 sayı olan RGB resimler olmalıdır.
  • Stil resmi boyutu (1, 256, 256, 3) olmalıdır. Görüntüyü merkezden kırpıp yeniden boyutlandırıyoruz.
  • İçerik resmi (1, 384, 384, 3) olmalıdır. Resmi merkezden kırpıp yeniden boyutlandırıyoruz.
# Function to load an image from a file, and add a batch dimension.
def load_img(path_to_img):
  img = tf.io.read_file(path_to_img)
  img = tf.io.decode_image(img, channels=3)
  img = tf.image.convert_image_dtype(img, tf.float32)
  img = img[tf.newaxis, :]

  return img

# Function to pre-process by resizing an central cropping it.
def preprocess_image(image, target_dim):
  # Resize the image so that the shorter dimension becomes 256px.
  shape = tf.cast(tf.shape(image)[1:-1], tf.float32)
  short_dim = min(shape)
  scale = target_dim / short_dim
  new_shape = tf.cast(shape * scale, tf.int32)
  image = tf.image.resize(image, new_shape)

  # Central crop the image.
  image = tf.image.resize_with_crop_or_pad(image, target_dim, target_dim)

  return image

# Load the input images.
content_image = load_img(content_path)
style_image = load_img(style_path)

# Preprocess the input images.
preprocessed_content_image = preprocess_image(content_image, 384)
preprocessed_style_image = preprocess_image(style_image, 256)

print('Style Image Shape:', preprocessed_style_image.shape)
print('Content Image Shape:', preprocessed_content_image.shape)
Style Image Shape: (1, 256, 256, 3)
Content Image Shape: (1, 384, 384, 3)

Girişleri görselleştirin

def imshow(image, title=None):
  if len(image.shape) > 3:
    image = tf.squeeze(image, axis=0)

  plt.imshow(image)
  if title:
    plt.title(title)

plt.subplot(1, 2, 1)
imshow(preprocessed_content_image, 'Content Image')

plt.subplot(1, 2, 2)
imshow(preprocessed_style_image, 'Style Image')

png

TensorFlow Lite ile stil aktarımı çalıştırın

Stil tahmini

# Function to run style prediction on preprocessed style image.
def run_style_predict(preprocessed_style_image):
  # Load the model.
  interpreter = tf.lite.Interpreter(model_path=style_predict_path)

  # Set model input.
  interpreter.allocate_tensors()
  input_details = interpreter.get_input_details()
  interpreter.set_tensor(input_details[0]["index"], preprocessed_style_image)

  # Calculate style bottleneck.
  interpreter.invoke()
  style_bottleneck = interpreter.tensor(
      interpreter.get_output_details()[0]["index"]
      )()

  return style_bottleneck

# Calculate style bottleneck for the preprocessed style image.
style_bottleneck = run_style_predict(preprocessed_style_image)
print('Style Bottleneck Shape:', style_bottleneck.shape)
Style Bottleneck Shape: (1, 1, 1, 100)

Stil dönüşümü

# Run style transform on preprocessed style image
def run_style_transform(style_bottleneck, preprocessed_content_image):
  # Load the model.
  interpreter = tf.lite.Interpreter(model_path=style_transform_path)

  # Set model input.
  input_details = interpreter.get_input_details()
  interpreter.allocate_tensors()

  # Set model inputs.
  interpreter.set_tensor(input_details[0]["index"], preprocessed_content_image)
  interpreter.set_tensor(input_details[1]["index"], style_bottleneck)
  interpreter.invoke()

  # Transform content image.
  stylized_image = interpreter.tensor(
      interpreter.get_output_details()[0]["index"]
      )()

  return stylized_image

# Stylize the content image using the style bottleneck.
stylized_image = run_style_transform(style_bottleneck, preprocessed_content_image)

# Visualize the output.
imshow(stylized_image, 'Stylized Image')

png

Stil karıştırma

İçerik resminin stilini stilize çıktıyla harmanlayarak çıktının daha çok içerik resmine benzemesini sağlayabiliriz.

# Calculate style bottleneck of the content image.
style_bottleneck_content = run_style_predict(
    preprocess_image(content_image, 256)
    )
# Define content blending ratio between [0..1].
# 0.0: 0% style extracts from content image.
# 1.0: 100% style extracted from content image.
content_blending_ratio = 0.5

# Blend the style bottleneck of style image and content image
style_bottleneck_blended = content_blending_ratio * style_bottleneck_content \

                           + (1 - content_blending_ratio) * style_bottleneck

# Stylize the content image using the style bottleneck.
stylized_image_blended = run_style_transform(style_bottleneck_blended,
                                             preprocessed_content_image)

# Visualize the output.
imshow(stylized_image_blended, 'Blended Stylized Image')

png

Performans Karşılaştırmaları

Performans karşılaştırma numaraları, burada açıklanan araçla oluşturulur.

Model adı Model boyutu cihaz NNAPI İşlemci GPU
Stil tahmin modeli (int8) 2.8 Mb Pixel 3 (Android 10) 142 ms 14 ms
Pixel 4 (Android 10) 5,2 ms 6,7 ms
iPhone XS (iOS 12.4.1) 10,7 ms
Stil dönüşümü modeli (int8) 0.2 Mb Pixel 3 (Android 10) 540 ms
Pixel 4 (Android 10) 405 ms
iPhone XS (iOS 12.4.1) 251 ms
Stil tahmin modeli (float16) 4.7 Mb Pixel 3 (Android 10) 86 ms 28 ms 9,1 ms
Pixel 4 (Android 10) 32 ms 12 ms 10 ms
Stil aktarım modeli (float16) 0.4 Mb Pixel 3 (Android 10) 1095 ms 545 ms 42 ms
Pixel 4 (Android 10) 603 ms 377 ms 42 ms

* 4 iplik kullanıldı.
** En iyi performans için iPhone'da 2 iş parçacığı.