Truyền phong cách nghệ thuật với TensorFlow Lite

Xem trên TensorFlow.org Chạy trong Google Colab Xem nguồn trên GitHub Tải xuống sổ ghi chép Xem mô hình TF Hub

Một trong những diễn biến thú vị nhất trong việc học sâu để đi ra thời gian gần đây là chuyển nghệ thuật phong cách , hoặc khả năng để tạo ra một hình ảnh mới, được gọi là một tác phẩm mô phỏng , dựa trên hai hình ảnh đầu vào: một đại diện cho phong cách nghệ thuật và một đại diện cho nội dung.

Ví dụ về chuyển kiểu

Sử dụng kỹ thuật này, chúng tôi có thể tạo ra các tác phẩm nghệ thuật mới tuyệt đẹp theo nhiều phong cách.

Ví dụ về chuyển kiểu

Nếu bạn chưa quen với TensorFlow Lite và đang làm việc với Android, chúng tôi khuyên bạn nên khám phá các ứng dụng mẫu sau có thể giúp bạn bắt đầu.

Ví dụ Android iOS dụ

Nếu bạn đang sử dụng một nền tảng khác hơn so với Android hay iOS, hoặc bạn đã quen thuộc với các Lite API TensorFlow , bạn có thể làm theo hướng dẫn này để tìm hiểu làm thế nào để áp dụng chuyển phong cách trên bất kỳ cặp nội dung và hình ảnh phong cách với một pre-đào tạo TensorFlow Lite mô hình. Bạn có thể sử dụng mô hình để thêm chuyển kiểu vào các ứng dụng di động của riêng mình.

Mô hình này là mã nguồn mở trên GitHub . Bạn có thể đào tạo lại mô hình với các tham số khác nhau (ví dụ: tăng trọng số của các lớp nội dung để làm cho hình ảnh đầu ra giống hình ảnh nội dung hơn).

Hiểu kiến ​​trúc mô hình

Kiến trúc mô hình

Mô hình Chuyển giao phong cách nghệ thuật này bao gồm hai mô hình con:

  1. Phong cách Prediciton Model: A MobilenetV2 dựa trên mạng lưới thần kinh mà phải mất một hình ảnh phong cách đầu vào cho một 100-chiều vector phong cách cổ chai.
  2. Phong cách chuyển đổi mô hình: Một mạng lưới thần kinh mà có áp dụng một vector phong cách cổ chai để một hình ảnh nội dung và tạo ra một hình ảnh cách điệu.

Nếu ứng dụng của bạn chỉ cần hỗ trợ một tập hợp hình ảnh kiểu cố định, bạn có thể tính toán trước các vectơ kiểu cổ chai của chúng và loại trừ Mô hình dự đoán kiểu khỏi tệp nhị phân của ứng dụng.

Thành lập

Nhập phụ thuộc.

import tensorflow as tf
print(tf.__version__)
2.6.0
import IPython.display as display

import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['figure.figsize'] = (12,12)
mpl.rcParams['axes.grid'] = False

import numpy as np
import time
import functools

Tải xuống nội dung và hình ảnh kiểu cũng như các mô hình TensorFlow Lite được đào tạo trước.

content_path = tf.keras.utils.get_file('belfry.jpg','https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/belfry-2611573_1280.jpg')
style_path = tf.keras.utils.get_file('style23.jpg','https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/style23.jpg')

style_predict_path = tf.keras.utils.get_file('style_predict.tflite', 'https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/prediction/1?lite-format=tflite')
style_transform_path = tf.keras.utils.get_file('style_transform.tflite', 'https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/transfer/1?lite-format=tflite')
Downloading data from https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/belfry-2611573_1280.jpg
458752/458481 [==============================] - 0s 0us/step
466944/458481 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/style23.jpg
114688/108525 [===============================] - 0s 0us/step
122880/108525 [=================================] - 0s 0us/step
Downloading data from https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/prediction/1?lite-format=tflite
2834432/2828838 [==============================] - 0s 0us/step
2842624/2828838 [==============================] - 0s 0us/step
Downloading data from https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/transfer/1?lite-format=tflite
286720/284398 [==============================] - 0s 0us/step
294912/284398 [===============================] - 0s 0us/step

Xử lý trước các đầu vào

  • Hình ảnh nội dung và hình ảnh kiểu phải là hình ảnh RGB với giá trị pixel là số float32 trong khoảng [0..1].
  • Kích thước hình ảnh kiểu phải là (1, 256, 256, 3). Chúng tôi cắt hình ảnh trung tâm và thay đổi kích thước.
  • Hình ảnh nội dung phải là (1, 384, 384, 3). Chúng tôi cắt hình ảnh trung tâm và thay đổi kích thước.
# Function to load an image from a file, and add a batch dimension.
def load_img(path_to_img):
  img = tf.io.read_file(path_to_img)
  img = tf.io.decode_image(img, channels=3)
  img = tf.image.convert_image_dtype(img, tf.float32)
  img = img[tf.newaxis, :]

  return img

# Function to pre-process by resizing an central cropping it.
def preprocess_image(image, target_dim):
  # Resize the image so that the shorter dimension becomes 256px.
  shape = tf.cast(tf.shape(image)[1:-1], tf.float32)
  short_dim = min(shape)
  scale = target_dim / short_dim
  new_shape = tf.cast(shape * scale, tf.int32)
  image = tf.image.resize(image, new_shape)

  # Central crop the image.
  image = tf.image.resize_with_crop_or_pad(image, target_dim, target_dim)

  return image

# Load the input images.
content_image = load_img(content_path)
style_image = load_img(style_path)

# Preprocess the input images.
preprocessed_content_image = preprocess_image(content_image, 384)
preprocessed_style_image = preprocess_image(style_image, 256)

print('Style Image Shape:', preprocessed_style_image.shape)
print('Content Image Shape:', preprocessed_content_image.shape)
Style Image Shape: (1, 256, 256, 3)
Content Image Shape: (1, 384, 384, 3)

Hình dung các đầu vào

def imshow(image, title=None):
  if len(image.shape) > 3:
    image = tf.squeeze(image, axis=0)

  plt.imshow(image)
  if title:
    plt.title(title)

plt.subplot(1, 2, 1)
imshow(preprocessed_content_image, 'Content Image')

plt.subplot(1, 2, 2)
imshow(preprocessed_style_image, 'Style Image')

png

Chạy chuyển kiểu với TensorFlow Lite

Dự đoán phong cách

# Function to run style prediction on preprocessed style image.
def run_style_predict(preprocessed_style_image):
  # Load the model.
  interpreter = tf.lite.Interpreter(model_path=style_predict_path)

  # Set model input.
  interpreter.allocate_tensors()
  input_details = interpreter.get_input_details()
  interpreter.set_tensor(input_details[0]["index"], preprocessed_style_image)

  # Calculate style bottleneck.
  interpreter.invoke()
  style_bottleneck = interpreter.tensor(
      interpreter.get_output_details()[0]["index"]
      )()

  return style_bottleneck

# Calculate style bottleneck for the preprocessed style image.
style_bottleneck = run_style_predict(preprocessed_style_image)
print('Style Bottleneck Shape:', style_bottleneck.shape)
Style Bottleneck Shape: (1, 1, 1, 100)

Chuyển đổi phong cách

# Run style transform on preprocessed style image
def run_style_transform(style_bottleneck, preprocessed_content_image):
  # Load the model.
  interpreter = tf.lite.Interpreter(model_path=style_transform_path)

  # Set model input.
  input_details = interpreter.get_input_details()
  interpreter.allocate_tensors()

  # Set model inputs.
  interpreter.set_tensor(input_details[0]["index"], preprocessed_content_image)
  interpreter.set_tensor(input_details[1]["index"], style_bottleneck)
  interpreter.invoke()

  # Transform content image.
  stylized_image = interpreter.tensor(
      interpreter.get_output_details()[0]["index"]
      )()

  return stylized_image

# Stylize the content image using the style bottleneck.
stylized_image = run_style_transform(style_bottleneck, preprocessed_content_image)

# Visualize the output.
imshow(stylized_image, 'Stylized Image')

png

Pha trộn phong cách

Chúng ta có thể trộn kiểu của hình ảnh nội dung vào đầu ra cách điệu, do đó làm cho đầu ra trông giống hình ảnh nội dung hơn.

# Calculate style bottleneck of the content image.
style_bottleneck_content = run_style_predict(
    preprocess_image(content_image, 256)
    )
# Define content blending ratio between [0..1].
# 0.0: 0% style extracts from content image.
# 1.0: 100% style extracted from content image.
content_blending_ratio = 0.5

# Blend the style bottleneck of style image and content image
style_bottleneck_blended = content_blending_ratio * style_bottleneck_content \
                           + (1 - content_blending_ratio) * style_bottleneck

# Stylize the content image using the style bottleneck.
stylized_image_blended = run_style_transform(style_bottleneck_blended,
                                             preprocessed_content_image)

# Visualize the output.
imshow(stylized_image_blended, 'Blended Stylized Image')

png

Điểm chuẩn hiệu suất

Số benchmark hiệu năng được tạo ra với công cụ mô tả ở đây .

Tên mô hình Kích thước mô hình Thiết bị NNAPI CPU GPU
Mô hình dự đoán kiểu (int8) 2,8 Mb Pixel 3 (Android 10) 142ms 14ms
Pixel 4 (Android 10) 5,2ms 6,7 mili giây
iPhone XS (iOS 12.4.1) 10,7 mili giây
Mô hình biến đổi kiểu (int8) 0,2 Mb Pixel 3 (Android 10) 540ms
Pixel 4 (Android 10) 405ms
iPhone XS (iOS 12.4.1) 251ms
Mô hình dự đoán kiểu (float16) 4,7 Mb Pixel 3 (Android 10) 86ms 28ms 9.1ms
Pixel 4 (Android 10) 32ms 12ms 10ms
Mô hình chuyển kiểu (float16) 0,4 Mb Pixel 3 (Android 10) 1095ms 545ms 42ms
Pixel 4 (Android 10) 603ms 377ms 42ms

* 4 chủ đề được sử dụng.
** 2 luồng trên iPhone để có hiệu suất tốt nhất.