การถ่ายโอนรูปแบบศิลปะด้วย TensorFlow Lite

ดูใน TensorFlow.org เรียกใช้ใน Google Colab ดูแหล่งที่มาบน GitHub ดาวน์โหลดสมุดบันทึก ดูโมเดล TF Hub

พัฒนาการที่น่าตื่นเต้นที่สุดอย่างหนึ่งในการเรียนรู้เชิงลึกที่จะออกมาเมื่อเร็ว ๆ นี้คือ การถ่ายทอดสไตล์ศิลปะ หรือความสามารถในการสร้างภาพใหม่ที่เรียกว่า Pastiche โดยใช้ภาพอินพุต 2 ภาพ: ภาพหนึ่งแสดงถึงรูปแบบศิลปะและอีกภาพหนึ่งแสดงถึงเนื้อหา

ตัวอย่างการถ่ายโอนสไตล์

ด้วยเทคนิคนี้เราสามารถสร้างงานศิลปะใหม่ ๆ ที่สวยงามได้ในหลากหลายสไตล์

ตัวอย่างการถ่ายโอนสไตล์

หากคุณยังใหม่กับ TensorFlow Lite และกำลังทำงานกับ Android เราขอแนะนำให้สำรวจแอปพลิเคชันตัวอย่างต่อไปนี้ที่สามารถช่วยคุณในการเริ่มต้นได้

ตัวอย่าง Android ตัวอย่าง iOS

หากคุณใช้แพลตฟอร์มอื่นที่ไม่ใช่ Android หรือ iOS หรือคุณคุ้นเคยกับ TensorFlow Lite API อยู่ แล้วคุณสามารถทำตามบทช่วยสอนนี้เพื่อเรียนรู้วิธีใช้การถ่ายโอนสไตล์กับคู่ของเนื้อหาและภาพสไตล์ด้วย TensorFlow Lite ที่ผ่านการฝึกอบรมมาก่อน แบบ. คุณสามารถใช้โมเดลเพื่อเพิ่มการถ่ายโอนสไตล์ไปยังแอปพลิเคชันมือถือของคุณเอง

โมเดลนี้เป็นแบบโอเพนซอร์สบน GitHub คุณสามารถฝึกโมเดลใหม่ด้วยพารามิเตอร์ที่แตกต่างกัน (เช่นเพิ่มน้ำหนักของเลเยอร์เนื้อหาเพื่อให้รูปภาพที่ส่งออกดูเหมือนกับรูปภาพเนื้อหามากขึ้น)

ทำความเข้าใจเกี่ยวกับสถาปัตยกรรมของโมเดล

สถาปัตยกรรมแบบจำลอง

โมเดลการถ่ายโอนสไตล์ศิลปะนี้ประกอบด้วยโมเดลย่อยสองแบบ:

  1. รูปแบบ Prediciton สไตล์ : เครือข่ายประสาทเทียมที่ใช้ MobilenetV2 ซึ่งนำภาพสไตล์อินพุตไปเป็นเวกเตอร์คอขวดสไตล์ 100 มิติ
  2. รูปแบบการแปลงสไตล์ : เครือข่ายประสาทที่ใช้เวกเตอร์คอขวดสไตล์กับรูปภาพเนื้อหาและสร้างภาพที่มีสไตล์

หากแอปของคุณต้องการเพียงการสนับสนุนชุดรูปภาพสไตล์คงที่คุณสามารถคำนวณเวกเตอร์คอขวดสไตล์ของพวกเขาล่วงหน้าและแยกรูปแบบการคาดเดาสไตล์ออกจากไบนารีของแอปของคุณได้

ติดตั้ง

นำเข้าการอ้างอิง

import tensorflow as tf
print(tf.__version__)
2.5.0
import IPython.display as display

import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['figure.figsize'] = (12,12)
mpl.rcParams['axes.grid'] = False

import numpy as np
import time
import functools

ดาวน์โหลดเนื้อหาและรูปแบบรูปภาพและโมเดล TensorFlow Lite ที่ได้รับการฝึกฝนมาก่อน

content_path = tf.keras.utils.get_file('belfry.jpg','https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/belfry-2611573_1280.jpg')
style_path = tf.keras.utils.get_file('style23.jpg','https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/style23.jpg')

style_predict_path = tf.keras.utils.get_file('style_predict.tflite', 'https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/prediction/1?lite-format=tflite')
style_transform_path = tf.keras.utils.get_file('style_transform.tflite', 'https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/transfer/1?lite-format=tflite')
Downloading data from https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/belfry-2611573_1280.jpg
458752/458481 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/style23.jpg
114688/108525 [===============================] - 0s 0us/step
Downloading data from https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/prediction/1?lite-format=tflite
2834432/2828838 [==============================] - 0s 0us/step
Downloading data from https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/transfer/1?lite-format=tflite
286720/284398 [==============================] - 0s 0us/step

ประมวลผลอินพุตล่วงหน้า

  • รูปภาพเนื้อหาและสไตล์อิมเมจต้องเป็นภาพ RGB ที่มีค่าพิกเซลเป็นตัวเลข float32 ระหว่าง [0..1]
  • ขนาดรูปภาพสไตล์ต้องเป็น (1, 256, 256, 3) เราครอบตัดรูปภาพเป็นศูนย์กลางและปรับขนาด
  • รูปภาพของเนื้อหาต้องเป็น (1, 384, 384, 3) เราครอบตัดรูปภาพเป็นศูนย์กลางและปรับขนาด
# Function to load an image from a file, and add a batch dimension.
def load_img(path_to_img):
  img = tf.io.read_file(path_to_img)
  img = tf.io.decode_image(img, channels=3)
  img = tf.image.convert_image_dtype(img, tf.float32)
  img = img[tf.newaxis, :]

  return img

# Function to pre-process by resizing an central cropping it.
def preprocess_image(image, target_dim):
  # Resize the image so that the shorter dimension becomes 256px.
  shape = tf.cast(tf.shape(image)[1:-1], tf.float32)
  short_dim = min(shape)
  scale = target_dim / short_dim
  new_shape = tf.cast(shape * scale, tf.int32)
  image = tf.image.resize(image, new_shape)

  # Central crop the image.
  image = tf.image.resize_with_crop_or_pad(image, target_dim, target_dim)

  return image

# Load the input images.
content_image = load_img(content_path)
style_image = load_img(style_path)

# Preprocess the input images.
preprocessed_content_image = preprocess_image(content_image, 384)
preprocessed_style_image = preprocess_image(style_image, 256)

print('Style Image Shape:', preprocessed_style_image.shape)
print('Content Image Shape:', preprocessed_content_image.shape)
Style Image Shape: (1, 256, 256, 3)
Content Image Shape: (1, 384, 384, 3)

แสดงภาพปัจจัยการผลิต

def imshow(image, title=None):
  if len(image.shape) > 3:
    image = tf.squeeze(image, axis=0)

  plt.imshow(image)
  if title:
    plt.title(title)

plt.subplot(1, 2, 1)
imshow(preprocessed_content_image, 'Content Image')

plt.subplot(1, 2, 2)
imshow(preprocessed_style_image, 'Style Image')

png

เรียกใช้การถ่ายโอนสไตล์ด้วย TensorFlow Lite

การทำนายสไตล์

# Function to run style prediction on preprocessed style image.
def run_style_predict(preprocessed_style_image):
  # Load the model.
  interpreter = tf.lite.Interpreter(model_path=style_predict_path)

  # Set model input.
  interpreter.allocate_tensors()
  input_details = interpreter.get_input_details()
  interpreter.set_tensor(input_details[0]["index"], preprocessed_style_image)

  # Calculate style bottleneck.
  interpreter.invoke()
  style_bottleneck = interpreter.tensor(
      interpreter.get_output_details()[0]["index"]
      )()

  return style_bottleneck

# Calculate style bottleneck for the preprocessed style image.
style_bottleneck = run_style_predict(preprocessed_style_image)
print('Style Bottleneck Shape:', style_bottleneck.shape)
Style Bottleneck Shape: (1, 1, 1, 100)

เปลี่ยนสไตล์

# Run style transform on preprocessed style image
def run_style_transform(style_bottleneck, preprocessed_content_image):
  # Load the model.
  interpreter = tf.lite.Interpreter(model_path=style_transform_path)

  # Set model input.
  input_details = interpreter.get_input_details()
  interpreter.allocate_tensors()

  # Set model inputs.
  interpreter.set_tensor(input_details[0]["index"], preprocessed_content_image)
  interpreter.set_tensor(input_details[1]["index"], style_bottleneck)
  interpreter.invoke()

  # Transform content image.
  stylized_image = interpreter.tensor(
      interpreter.get_output_details()[0]["index"]
      )()

  return stylized_image

# Stylize the content image using the style bottleneck.
stylized_image = run_style_transform(style_bottleneck, preprocessed_content_image)

# Visualize the output.
imshow(stylized_image, 'Stylized Image')

png

การผสมผสานสไตล์

เราสามารถผสมผสานรูปแบบของภาพเนื้อหาลงในเอาต์พุตที่มีสไตล์ซึ่งจะทำให้ผลลัพธ์มีลักษณะเหมือนภาพเนื้อหามากขึ้น

# Calculate style bottleneck of the content image.
style_bottleneck_content = run_style_predict(
    preprocess_image(content_image, 256)
    )
# Define content blending ratio between [0..1].
# 0.0: 0% style extracts from content image.
# 1.0: 100% style extracted from content image.
content_blending_ratio = 0.5

# Blend the style bottleneck of style image and content image
style_bottleneck_blended = content_blending_ratio * style_bottleneck_content \

                           + (1 - content_blending_ratio) * style_bottleneck

# Stylize the content image using the style bottleneck.
stylized_image_blended = run_style_transform(style_bottleneck_blended,
                                             preprocessed_content_image)

# Visualize the output.
imshow(stylized_image_blended, 'Blended Stylized Image')

png

เกณฑ์มาตรฐานประสิทธิภาพ

ตัวเลขมาตรฐานประสิทธิภาพถูกสร้างขึ้นด้วยเครื่องมือที่ อธิบายไว้ที่นี่

ชื่อรุ่น ขนาดโมเดล อุปกรณ์ NNAPI ซีพียู GPU
รูปแบบการทำนายสไตล์ (int8) 2.8 ลบ พิกเซล 3 (Android 10) 142 มิลลิวินาที 14 มิลลิวินาที
พิกเซล 4 (Android 10) 5.2 มิลลิวินาที 6.7 มิลลิวินาที
iPhone XS (iOS 12.4.1) 10.7 มิลลิวินาที
รูปแบบการแปลงสไตล์ (int8) 0.2 ล้านบาท พิกเซล 3 (Android 10) 540 มิลลิวินาที
พิกเซล 4 (Android 10) 405 มิลลิวินาที
iPhone XS (iOS 12.4.1) 251 มิลลิวินาที
รูปแบบการทำนายสไตล์ (float16) 4.7 ล้านบาท พิกเซล 3 (Android 10) 86 มิลลิวินาที 28 มิลลิวินาที 9.1 มิลลิวินาที
พิกเซล 4 (Android 10) 32 มิลลิวินาที 12 มิลลิวินาที 10 มิลลิวินาที
รูปแบบการถ่ายโอนสไตล์ (float16) 0.4 เมกะไบต์ พิกเซล 3 (Android 10) 1095 มิลลิวินาที 545 มิลลิวินาที 42 มิลลิวินาที
พิกเซล 4 (Android 10) 603 มิลลิวินาที 377 มิลลิวินาที 42 มิลลิวินาที

* 4 เธรดที่ใช้
** 2 เธรดบน iPhone เพื่อประสิทธิภาพที่ดีที่สุด