העברת סגנון אמנותי עם TensorFlow Lite

הצג באתר TensorFlow.org הפעל בגוגל קולאב צפה במקור ב-GitHub הורד מחברת ראה דגם TF Hub

אחת ההתפתחויות המלהיבות ביותר בלמידה עמוקה לצאת הוא לאחרונה העברת סגנון אמנותית , או את היכולת ליצור תמונה חדשה, המכונית פסטיש , המבוססת על שתי תמונות קלט: אחד המייצג את הסגנון האמנותי ואחד המייצגת את התוכן.

דוגמה להעברת סגנון

באמצעות טכניקה זו, נוכל ליצור יצירות אמנות חדשות ויפות במגוון סגנונות.

דוגמה להעברת סגנון

אם אתה חדש ב-TensorFlow Lite ועובד עם אנדרואיד, אנו ממליצים לחקור את האפליקציות לדוגמה הבאות שיכולות לעזור לך להתחיל.

אנדרואיד דוגמה דוגמה iOS

אם אתה משתמש בפלטפורמה אחרת מאשר אנדרואיד או iOS, או שאתה כבר מכיר את APIs לייט TensorFlow , אתה יכול לעקוב במדריך זה כדי ללמוד כיצד ליישם העברת בסגנון על כל זוג תוכן ותמונה בסטייל עם לייט מראש מאומן TensorFlow דֶגֶם. אתה יכול להשתמש במודל כדי להוסיף העברת סגנון ליישומים ניידים משלך.

המודל הוא קוד מקור פתוח על GitHub . אתה יכול לאמן מחדש את המודל עם פרמטרים שונים (למשל להגדיל את משקל שכבות התוכן כדי לגרום לתמונת הפלט להיראות יותר כמו תמונת התוכן).

הבן את ארכיטקטורת המודל

אדריכלות מודל

מודל העברת סגנון אמנותי זה מורכב משני דגמי משנה:

  1. סגנון Prediciton דגם: A MobilenetV2 מבוסס רשת עצבית שלוקחת תמונת סגנון קלט כדי וקטור צוואר בקבוק בסגנון 100-ממד.
  2. סגנון Transform דגם: רשת עצבית שלוקחת להחיל וקטור צוואר בקבוק בסגנון לתמונת תוכן ויוצרת תמונה מסוגננת.

אם האפליקציה שלך צריכה לתמוך רק בסט קבוע של תמונות סגנון, תוכל לחשב את וקטורי צוואר הבקבוק של הסגנון שלהן מראש, ולא לכלול את מודל חיזוי הסגנון מהקובץ הבינארי של האפליקציה שלך.

להכין

תלות בייבוא.

import tensorflow as tf
print(tf.__version__)
2.6.0
import IPython.display as display

import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['figure.figsize'] = (12,12)
mpl.rcParams['axes.grid'] = False

import numpy as np
import time
import functools

הורד את תמונות התוכן והסגנון, ואת דגמי TensorFlow Lite שהוכשרו מראש.

content_path = tf.keras.utils.get_file('belfry.jpg','https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/belfry-2611573_1280.jpg')
style_path = tf.keras.utils.get_file('style23.jpg','https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/style23.jpg')

style_predict_path = tf.keras.utils.get_file('style_predict.tflite', 'https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/prediction/1?lite-format=tflite')
style_transform_path = tf.keras.utils.get_file('style_transform.tflite', 'https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/transfer/1?lite-format=tflite')
Downloading data from https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/belfry-2611573_1280.jpg
458752/458481 [==============================] - 0s 0us/step
466944/458481 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/style23.jpg
114688/108525 [===============================] - 0s 0us/step
122880/108525 [=================================] - 0s 0us/step
Downloading data from https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/prediction/1?lite-format=tflite
2834432/2828838 [==============================] - 0s 0us/step
2842624/2828838 [==============================] - 0s 0us/step
Downloading data from https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/transfer/1?lite-format=tflite
286720/284398 [==============================] - 0s 0us/step
294912/284398 [===============================] - 0s 0us/step

עבדו מראש את התשומות

  • תמונת התוכן ותמונת הסגנון חייבות להיות תמונות RGB כאשר ערכי הפיקסלים הם מספרי float32 בין [0..1].
  • גודל תמונת הסגנון חייב להיות (1, 256, 256, 3). אנו חותכים את התמונה במרכז ומשנים את גודלה.
  • תמונת התוכן חייבת להיות (1, 384, 384, 3). אנו חותכים את התמונה במרכז ומשנים את גודלה.
# Function to load an image from a file, and add a batch dimension.
def load_img(path_to_img):
  img = tf.io.read_file(path_to_img)
  img = tf.io.decode_image(img, channels=3)
  img = tf.image.convert_image_dtype(img, tf.float32)
  img = img[tf.newaxis, :]

  return img

# Function to pre-process by resizing an central cropping it.
def preprocess_image(image, target_dim):
  # Resize the image so that the shorter dimension becomes 256px.
  shape = tf.cast(tf.shape(image)[1:-1], tf.float32)
  short_dim = min(shape)
  scale = target_dim / short_dim
  new_shape = tf.cast(shape * scale, tf.int32)
  image = tf.image.resize(image, new_shape)

  # Central crop the image.
  image = tf.image.resize_with_crop_or_pad(image, target_dim, target_dim)

  return image

# Load the input images.
content_image = load_img(content_path)
style_image = load_img(style_path)

# Preprocess the input images.
preprocessed_content_image = preprocess_image(content_image, 384)
preprocessed_style_image = preprocess_image(style_image, 256)

print('Style Image Shape:', preprocessed_style_image.shape)
print('Content Image Shape:', preprocessed_content_image.shape)
Style Image Shape: (1, 256, 256, 3)
Content Image Shape: (1, 384, 384, 3)

דמיינו את התשומות

def imshow(image, title=None):
  if len(image.shape) > 3:
    image = tf.squeeze(image, axis=0)

  plt.imshow(image)
  if title:
    plt.title(title)

plt.subplot(1, 2, 1)
imshow(preprocessed_content_image, 'Content Image')

plt.subplot(1, 2, 2)
imshow(preprocessed_style_image, 'Style Image')

png

הפעל העברת סגנון עם TensorFlow Lite

חיזוי סגנון

# Function to run style prediction on preprocessed style image.
def run_style_predict(preprocessed_style_image):
  # Load the model.
  interpreter = tf.lite.Interpreter(model_path=style_predict_path)

  # Set model input.
  interpreter.allocate_tensors()
  input_details = interpreter.get_input_details()
  interpreter.set_tensor(input_details[0]["index"], preprocessed_style_image)

  # Calculate style bottleneck.
  interpreter.invoke()
  style_bottleneck = interpreter.tensor(
      interpreter.get_output_details()[0]["index"]
      )()

  return style_bottleneck

# Calculate style bottleneck for the preprocessed style image.
style_bottleneck = run_style_predict(preprocessed_style_image)
print('Style Bottleneck Shape:', style_bottleneck.shape)
Style Bottleneck Shape: (1, 1, 1, 100)

שינוי סגנון

# Run style transform on preprocessed style image
def run_style_transform(style_bottleneck, preprocessed_content_image):
  # Load the model.
  interpreter = tf.lite.Interpreter(model_path=style_transform_path)

  # Set model input.
  input_details = interpreter.get_input_details()
  interpreter.allocate_tensors()

  # Set model inputs.
  interpreter.set_tensor(input_details[0]["index"], preprocessed_content_image)
  interpreter.set_tensor(input_details[1]["index"], style_bottleneck)
  interpreter.invoke()

  # Transform content image.
  stylized_image = interpreter.tensor(
      interpreter.get_output_details()[0]["index"]
      )()

  return stylized_image

# Stylize the content image using the style bottleneck.
stylized_image = run_style_transform(style_bottleneck, preprocessed_content_image)

# Visualize the output.
imshow(stylized_image, 'Stylized Image')

png

מיזוג סגנון

אנחנו יכולים למזג את סגנון תמונת התוכן לתוך הפלט המסוגנן, מה שבתורו גורם לפלט להיראות יותר כמו תמונת התוכן.

# Calculate style bottleneck of the content image.
style_bottleneck_content = run_style_predict(
    preprocess_image(content_image, 256)
    )
# Define content blending ratio between [0..1].
# 0.0: 0% style extracts from content image.
# 1.0: 100% style extracted from content image.
content_blending_ratio = 0.5

# Blend the style bottleneck of style image and content image
style_bottleneck_blended = content_blending_ratio * style_bottleneck_content \
                           + (1 - content_blending_ratio) * style_bottleneck

# Stylize the content image using the style bottleneck.
stylized_image_blended = run_style_transform(style_bottleneck_blended,
                                             preprocessed_content_image)

# Visualize the output.
imshow(stylized_image_blended, 'Blended Stylized Image')

png

מדדי ביצועים

מספרי benchmark ביצועים נוצרים עם הכלי המתואר כאן .

שם המודל גודל הדגם התקן NNAPI מעבד GPU
מודל חיזוי סגנון (int8) 2.8 מגה-ביט Pixel 3 (אנדרואיד 10) 142ms 14 אלפיות השנייה
Pixel 4 (אנדרואיד 10) 5.2 אלפיות השנייה 6.7 אלפיות השנייה
iPhone XS (iOS 12.4.1) 10.7 אלפיות השנייה
מודל שינוי סגנון (int8) 0.2 מגה-ביט Pixel 3 (אנדרואיד 10) 540ms
Pixel 4 (אנדרואיד 10) 405ms
iPhone XS (iOS 12.4.1) 251ms
מודל חיזוי סגנון (float16) 4.7 מגה-ביט Pixel 3 (אנדרואיד 10) 86ms 28ms 9.1 אלפיות השנייה
Pixel 4 (אנדרואיד 10) 32ms 12ms 10 אלפיות השנייה
דגם העברת סגנון (float16) 0.4 מגה-ביט Pixel 3 (אנדרואיד 10) 1095ms 545ms 42ms
Pixel 4 (אנדרואיד 10) 603ms 377ms 42ms

* 4 חוטים בשימוש.
** 2 שרשורים באייפון לביצועים הטובים ביותר.