Catat tanggalnya! Google I / O mengembalikan 18-20 Mei Daftar sekarang
Halaman ini diterjemahkan oleh Cloud Translation API.
Switch to English

Transfer Gaya Artistik dengan TensorFlow Lite

Lihat di TensorFlow.org Jalankan di Google Colab Lihat sumber di GitHub Unduh buku catatan Lihat model TF Hub

Salah satu perkembangan paling menarik dalam deep learning yang akan dirilis baru-baru ini adalah transfer gaya artistik , atau kemampuan untuk membuat gambar baru, yang disebut bunga pastiche , berdasarkan dua gambar masukan: satu mewakili gaya artistik dan satu lagi mewakili konten.

Contoh transfer gaya

Dengan menggunakan teknik ini, kita dapat menghasilkan karya seni baru yang indah dalam berbagai gaya.

Contoh transfer gaya

Jika Anda baru mengenal TensorFlow Lite dan menggunakan Android, sebaiknya jelajahi aplikasi contoh berikut yang dapat membantu Anda memulai.

Contoh Android Contoh iOS

Jika Anda menggunakan platform selain Android atau iOS, atau Anda sudah terbiasa dengan TensorFlow Lite API , Anda dapat mengikuti tutorial ini untuk mempelajari cara menerapkan transfer gaya pada pasangan konten dan gambar gaya apa pun dengan TensorFlow Lite terlatih. model. Anda dapat menggunakan model tersebut untuk menambahkan transfer gaya ke aplikasi seluler Anda sendiri.

Model ini bersumber terbuka di GitHub . Anda dapat melatih ulang model dengan parameter yang berbeda (mis. Menambah bobot lapisan konten untuk membuat gambar keluaran lebih terlihat seperti gambar konten).

Pahami arsitektur model

Arsitektur Model

Model Transfer Gaya Artistik ini terdiri dari dua submodel:

  1. Model Prediksi Gaya : Jaringan saraf berbasis MobilenetV2 yang membawa gambar gaya masukan ke vektor hambatan gaya 100 dimensi.
  2. Model Transformasi Gaya : Jaringan saraf yang menerapkan vektor penghambat gaya ke gambar konten dan membuat gambar bergaya.

Jika aplikasi Anda hanya perlu mendukung sekumpulan tetap gambar gaya, Anda bisa menghitung vektor penghambat gayanya terlebih dahulu, dan mengecualikan Model Prediksi Gaya dari biner aplikasi Anda.

Mempersiapkan

Ketergantungan impor.

import tensorflow as tf
print(tf.__version__)
2.4.1
import IPython.display as display

import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['figure.figsize'] = (12,12)
mpl.rcParams['axes.grid'] = False

import numpy as np
import time
import functools

Download gambar gaya dan konten, serta model TensorFlow Lite terlatih.

content_path = tf.keras.utils.get_file('belfry.jpg','https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/belfry-2611573_1280.jpg')
style_path = tf.keras.utils.get_file('style23.jpg','https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/style23.jpg')

style_predict_path = tf.keras.utils.get_file('style_predict.tflite', 'https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/prediction/1?lite-format=tflite')
style_transform_path = tf.keras.utils.get_file('style_transform.tflite', 'https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/transfer/1?lite-format=tflite')
Downloading data from https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/belfry-2611573_1280.jpg
458752/458481 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/style23.jpg
114688/108525 [===============================] - 0s 0us/step
Downloading data from https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/prediction/1?lite-format=tflite
2834432/2828838 [==============================] - 0s 0us/step
Downloading data from https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/transfer/1?lite-format=tflite
286720/284398 [==============================] - 0s 0us/step

Pra-proses masukan

  • Gambar konten dan gambar gaya harus berupa gambar RGB dengan nilai piksel berupa angka float32 antara [0..1].
  • Ukuran gambar gaya harus (1, 256, 256, 3). Kami memangkas gambar dan mengubah ukurannya.
  • Gambar konten harus (1, 384, 384, 3). Kami memusatkan gambar dan mengubah ukurannya.
# Function to load an image from a file, and add a batch dimension.
def load_img(path_to_img):
  img = tf.io.read_file(path_to_img)
  img = tf.io.decode_image(img, channels=3)
  img = tf.image.convert_image_dtype(img, tf.float32)
  img = img[tf.newaxis, :]

  return img

# Function to pre-process by resizing an central cropping it.
def preprocess_image(image, target_dim):
  # Resize the image so that the shorter dimension becomes 256px.
  shape = tf.cast(tf.shape(image)[1:-1], tf.float32)
  short_dim = min(shape)
  scale = target_dim / short_dim
  new_shape = tf.cast(shape * scale, tf.int32)
  image = tf.image.resize(image, new_shape)

  # Central crop the image.
  image = tf.image.resize_with_crop_or_pad(image, target_dim, target_dim)

  return image

# Load the input images.
content_image = load_img(content_path)
style_image = load_img(style_path)

# Preprocess the input images.
preprocessed_content_image = preprocess_image(content_image, 384)
preprocessed_style_image = preprocess_image(style_image, 256)

print('Style Image Shape:', preprocessed_style_image.shape)
print('Content Image Shape:', preprocessed_content_image.shape)
Style Image Shape: (1, 256, 256, 3)
Content Image Shape: (1, 384, 384, 3)

Visualisasikan masukan

def imshow(image, title=None):
  if len(image.shape) > 3:
    image = tf.squeeze(image, axis=0)

  plt.imshow(image)
  if title:
    plt.title(title)

plt.subplot(1, 2, 1)
imshow(preprocessed_content_image, 'Content Image')

plt.subplot(1, 2, 2)
imshow(preprocessed_style_image, 'Style Image')

png

Jalankan transfer gaya dengan TensorFlow Lite

Prediksi gaya

# Function to run style prediction on preprocessed style image.
def run_style_predict(preprocessed_style_image):
  # Load the model.
  interpreter = tf.lite.Interpreter(model_path=style_predict_path)

  # Set model input.
  interpreter.allocate_tensors()
  input_details = interpreter.get_input_details()
  interpreter.set_tensor(input_details[0]["index"], preprocessed_style_image)

  # Calculate style bottleneck.
  interpreter.invoke()
  style_bottleneck = interpreter.tensor(
      interpreter.get_output_details()[0]["index"]
      )()

  return style_bottleneck

# Calculate style bottleneck for the preprocessed style image.
style_bottleneck = run_style_predict(preprocessed_style_image)
print('Style Bottleneck Shape:', style_bottleneck.shape)
Style Bottleneck Shape: (1, 1, 1, 100)

Transformasi gaya

# Run style transform on preprocessed style image
def run_style_transform(style_bottleneck, preprocessed_content_image):
  # Load the model.
  interpreter = tf.lite.Interpreter(model_path=style_transform_path)

  # Set model input.
  input_details = interpreter.get_input_details()
  interpreter.allocate_tensors()

  # Set model inputs.
  interpreter.set_tensor(input_details[0]["index"], preprocessed_content_image)
  interpreter.set_tensor(input_details[1]["index"], style_bottleneck)
  interpreter.invoke()

  # Transform content image.
  stylized_image = interpreter.tensor(
      interpreter.get_output_details()[0]["index"]
      )()

  return stylized_image

# Stylize the content image using the style bottleneck.
stylized_image = run_style_transform(style_bottleneck, preprocessed_content_image)

# Visualize the output.
imshow(stylized_image, 'Stylized Image')

png

Pencampuran gaya

Kita dapat memadukan gaya gambar konten ke dalam keluaran bergaya, yang pada gilirannya membuat keluaran lebih terlihat seperti gambar konten.

# Calculate style bottleneck of the content image.
style_bottleneck_content = run_style_predict(
    preprocess_image(content_image, 256)
    )
# Define content blending ratio between [0..1].
# 0.0: 0% style extracts from content image.
# 1.0: 100% style extracted from content image.
content_blending_ratio = 0.5

# Blend the style bottleneck of style image and content image
style_bottleneck_blended = content_blending_ratio * style_bottleneck_content \

                           + (1 - content_blending_ratio) * style_bottleneck

# Stylize the content image using the style bottleneck.
stylized_image_blended = run_style_transform(style_bottleneck_blended,
                                             preprocessed_content_image)

# Visualize the output.
imshow(stylized_image_blended, 'Blended Stylized Image')

png

Tolok Ukur Kinerja

Angka tolok ukur kinerja dibuat dengan alat yang dijelaskan di sini .

Nama model Ukuran model Alat NNAPI CPU GPU
Model prediksi gaya (int8) 2.8 Mb Pixel 3 (Android 10) 142 md 14 md
Pixel 4 (Android 10) 5,2 md 6,7 md
iPhone XS (iOS 12.4.1) 10,7 md
Model transformasi gaya (int8) 0,2 Mb Pixel 3 (Android 10) 540 md
Pixel 4 (Android 10) 405 md
iPhone XS (iOS 12.4.1) 251 md
Model prediksi gaya (float16) 4.7 Mb Pixel 3 (Android 10) 86 md 28 md 9,1 md
Pixel 4 (Android 10) 32 md 12 md 10 md
Model transfer gaya (float16) 0.4 Mb Pixel 3 (Android 10) 1095 md 545 md 42 md
Pixel 4 (Android 10) 603 md 377 md 42 md

* 4 utas digunakan.
** 2 utas di iPhone untuk kinerja terbaik.