このページは Cloud Translation API によって翻訳されました。
Switch to English

TensorFlow Liteによる芸術的なスタイルの転送

TensorFlow.orgで表示 Google Colabで実行 GitHubでソースを表示する ノートブックをダウンロード

最近登場するディープラーニングで最もエキサイティングな開発の1つは、 芸術的なスタイルの転送 、または2つの入力画像に基づいて、 パスティッシュと呼ばれる新しい画像を作成する機能です。1つは芸術的なスタイルを表し、もう1つはコンテンツを表します。

スタイル転送の例

この手法を使用して、さまざまなスタイルで美しい新しいアートワークを生成できます。

スタイル転送の例

TensorFlow Liteを初めて使用し、Androidを使用している場合は、開始に役立つ次のサンプルアプリケーションを探すことをお勧めします。

Androidの例 iOSの例

AndroidまたはiOS以外のプラットフォームを使用している場合、またはTensorFlow Lite APIに既に精通している場合は、このチュートリアルに従って、事前トレーニング済みのTensorFlow Liteを使用してコンテンツとスタイル画像のペアにスタイル転送を適用する方法を学ぶことができますモデル。モデルを使用して、独自のモバイルアプリケーションにスタイル転送を追加できます。

モデルはGitHubでオープンソース化されています 。さまざまなパラメーターでモデルを再トレーニングできます(たとえば、コンテンツレイヤーの重みを増やして、出力画像をコンテンツ画像のように見せます)。

モデルアーキテクチャを理解する

モデルアーキテクチャ

このArtistic Style Transferモデルは、2つのサブモデルで構成されています。

  1. スタイル予測モデル :入力スタイル画像を100次元スタイルのボトルネックベクトルに変換するMobilenetV2ベースのニューラルネットワーク。
  2. スタイル変換モデル :スタイルのボトルネックベクトルをコンテンツ画像に適用し、様式化された画像を作成するニューラルネットワーク。

アプリでスタイルイメージの固定セットのみをサポートする必要がある場合は、事前にスタイルボトルネックベクトルを計算し、アプリのバイナリからスタイル予測モデルを除外できます。

セットアップ

依存関係をインポートします。

import tensorflow as tf
print(tf.__version__)
2.3.0

import IPython.display as display

import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['figure.figsize'] = (12,12)
mpl.rcParams['axes.grid'] = False

import numpy as np
import time
import functools

コンテンツとスタイルの画像、事前トレーニング済みのTensorFlow Liteモデルをダウンロードします。

content_path = tf.keras.utils.get_file('belfry.jpg','https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/belfry-2611573_1280.jpg')
style_path = tf.keras.utils.get_file('style23.jpg','https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/style23.jpg')

style_predict_path = tf.keras.utils.get_file('style_predict.tflite', 'https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/prediction/1?lite-format=tflite')
style_transform_path = tf.keras.utils.get_file('style_transform.tflite', 'https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/transfer/1?lite-format=tflite')
Downloading data from https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/belfry-2611573_1280.jpg
458752/458481 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/khanhlvg-public.appspot.com/arbitrary-style-transfer/style23.jpg
114688/108525 [===============================] - 0s 0us/step
Downloading data from https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/prediction/1?lite-format=tflite
2834432/2828838 [==============================] - 0s 0us/step
Downloading data from https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/transfer/1?lite-format=tflite
286720/284398 [==============================] - 0s 0us/step

入力を前処理する

  • コンテンツ画像とスタイル画像は、ピクセル値が[0..1]の間のfloat32の数値であるRGB画像である必要があります。
  • スタイル画像のサイズは(1、256、256、3)でなければなりません。画像を中央でトリミングしてサイズを変更します。
  • コンテンツ画像は(1、384、384、3)でなければなりません。画像を中央でトリミングしてサイズを変更します。
# Function to load an image from a file, and add a batch dimension.
def load_img(path_to_img):
  img = tf.io.read_file(path_to_img)
  img = tf.io.decode_image(img, channels=3)
  img = tf.image.convert_image_dtype(img, tf.float32)
  img = img[tf.newaxis, :]

  return img

# Function to pre-process by resizing an central cropping it.
def preprocess_image(image, target_dim):
  # Resize the image so that the shorter dimension becomes 256px.
  shape = tf.cast(tf.shape(image)[1:-1], tf.float32)
  short_dim = min(shape)
  scale = target_dim / short_dim
  new_shape = tf.cast(shape * scale, tf.int32)
  image = tf.image.resize(image, new_shape)

  # Central crop the image.
  image = tf.image.resize_with_crop_or_pad(image, target_dim, target_dim)

  return image

# Load the input images.
content_image = load_img(content_path)
style_image = load_img(style_path)

# Preprocess the input images.
preprocessed_content_image = preprocess_image(content_image, 384)
preprocessed_style_image = preprocess_image(style_image, 256)

print('Style Image Shape:', preprocessed_style_image.shape)
print('Content Image Shape:', preprocessed_content_image.shape)
Style Image Shape: (1, 256, 256, 3)
Content Image Shape: (1, 384, 384, 3)

入力を視覚化する

def imshow(image, title=None):
  if len(image.shape) > 3:
    image = tf.squeeze(image, axis=0)

  plt.imshow(image)
  if title:
    plt.title(title)

plt.subplot(1, 2, 1)
imshow(preprocessed_content_image, 'Content Image')

plt.subplot(1, 2, 2)
imshow(preprocessed_style_image, 'Style Image')

png

TensorFlow Liteでスタイル転送を実行する

スタイル予測

# Function to run style prediction on preprocessed style image.
def run_style_predict(preprocessed_style_image):
  # Load the model.
  interpreter = tf.lite.Interpreter(model_path=style_predict_path)

  # Set model input.
  interpreter.allocate_tensors()
  input_details = interpreter.get_input_details()
  interpreter.set_tensor(input_details[0]["index"], preprocessed_style_image)

  # Calculate style bottleneck.
  interpreter.invoke()
  style_bottleneck = interpreter.tensor(
      interpreter.get_output_details()[0]["index"]
      )()

  return style_bottleneck

# Calculate style bottleneck for the preprocessed style image.
style_bottleneck = run_style_predict(preprocessed_style_image)
print('Style Bottleneck Shape:', style_bottleneck.shape)
Style Bottleneck Shape: (1, 1, 1, 100)

スタイル変換

# Run style transform on preprocessed style image
def run_style_transform(style_bottleneck, preprocessed_content_image):
  # Load the model.
  interpreter = tf.lite.Interpreter(model_path=style_transform_path)

  # Set model input.
  input_details = interpreter.get_input_details()
  interpreter.allocate_tensors()

  # Set model inputs.
  interpreter.set_tensor(input_details[0]["index"], preprocessed_content_image)
  interpreter.set_tensor(input_details[1]["index"], style_bottleneck)
  interpreter.invoke()

  # Transform content image.
  stylized_image = interpreter.tensor(
      interpreter.get_output_details()[0]["index"]
      )()

  return stylized_image

# Stylize the content image using the style bottleneck.
stylized_image = run_style_transform(style_bottleneck, preprocessed_content_image)

# Visualize the output.
imshow(stylized_image, 'Stylized Image')

png

スタイルブレンディング

コンテンツ画像のスタイルを様式化された出力にブレンドして、出力をコンテンツ画像のように見せることができます。

# Calculate style bottleneck of the content image.
style_bottleneck_content = run_style_predict(
    preprocess_image(content_image, 256)
    )
# Define content blending ratio between [0..1].
# 0.0: 0% style extracts from content image.
# 1.0: 100% style extracted from content image.
content_blending_ratio = 0.5 

# Blend the style bottleneck of style image and content image
style_bottleneck_blended = content_blending_ratio * style_bottleneck_content \

                           + (1 - content_blending_ratio) * style_bottleneck

# Stylize the content image using the style bottleneck.
stylized_image_blended = run_style_transform(style_bottleneck_blended,
                                             preprocessed_content_image)

# Visualize the output.
imshow(stylized_image_blended, 'Blended Stylized Image')

png

パフォーマンスベンチマーク

ここで説明するツールを使用して、パフォーマンスベンチマークの数値が生成されます

モデル名モデルサイズ端末 NNAPI CPU GPU
スタイル予測モデル(int8) 2.8 Mb Pixel 3(Android 10) 142ミリ秒 14ms
Pixel 4(Android 10) 5.2ms 6.7ms
iPhone XS(iOS 12.4.1) 10.7ms
スタイル変換モデル(int8) 0.2 Mb Pixel 3(Android 10) 540ms
Pixel 4(Android 10) 405ms
iPhone XS(iOS 12.4.1) 251ms
スタイル予測モデル(float16) 4.7 Mb Pixel 3(Android 10) 86ms 28ms 9.1ms
Pixel 4(Android 10) 32ms 12ms 10ms
スタイル転送モデル(float16) 0.4 Mb Pixel 3(Android 10) 1095ms 545ms 42ms
Pixel 4(Android 10) 603ミリ秒 377ms 42ms

* 4スレッドが使用されます。
** iPhoneでは2スレッドで最高のパフォーマンスを発揮します。