![]() | ![]() | ![]() | ![]() | ![]() |
このチュートリアルでは、ディープラーニングを使用して、ある画像を別の画像のスタイルで作成します(ピカソやゴッホのようにペイントしたいですか?)。これはニューラルスタイル転送として知られており、この手法はアーティスティックスタイルのニューラルアルゴリズム(Gatys et al。)に概説されています。
スタイル転送の簡単なアプリケーションでは、このチェックアウトチュートリアルをよりpretrained使用方法について学ぶために任意の画像様式化モデルからTensorFlowハブまたはどのようにしてスタイルの伝達モデルを使用するTensorFlow Liteの。
ニューラルスタイル転送は、コンテンツ画像とスタイル参照画像(有名な画家によるアートワークなど)の2つの画像を取得し、それらをブレンドして、出力画像がコンテンツ画像のように見えるが「ペイント」されるようにするために使用される最適化手法です。スタイル参照画像のスタイルで。
これは、コンテンツ画像のコンテンツ統計とスタイル参照画像のスタイル統計に一致するように出力画像を最適化することによって実装されます。これらの統計は、畳み込みネットワークを使用して画像から抽出されます。
たとえば、この犬とWassilyKandinskyのComposition7の画像を見てみましょう。
黄色いラブラドール探し、エルフによるウィキメディアコモンズから。ライセンスCCBY-SA 3.0
カンディンスキーがこの犬の絵をこのスタイルだけで描くことにしたとしたら、どのように見えるでしょうか?このようなもの?
セットアップ
モジュールのインポートと構成
import os
import tensorflow as tf
# Load compressed models from tensorflow_hub
os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED'
import IPython.display as display
import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['figure.figsize'] = (12,12)
mpl.rcParams['axes.grid'] = False
import numpy as np
import PIL.Image
import time
import functools
def tensor_to_image(tensor):
tensor = tensor*255
tensor = np.array(tensor, dtype=np.uint8)
if np.ndim(tensor)>3:
assert tensor.shape[0] == 1
tensor = tensor[0]
return PIL.Image.fromarray(tensor)
画像をダウンロードし、スタイル画像とコンテンツ画像を選択します。
content_path = tf.keras.utils.get_file('YellowLabradorLooking_new.jpg', 'https://storage.googleapis.com/download.tensorflow.org/example_images/YellowLabradorLooking_new.jpg')
style_path = tf.keras.utils.get_file('kandinsky5.jpg','https://storage.googleapis.com/download.tensorflow.org/example_images/Vassily_Kandinsky%2C_1913_-_Composition_7.jpg')
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/YellowLabradorLooking_new.jpg 90112/83281 [================================] - 0s 0us/step Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/Vassily_Kandinsky%2C_1913_-_Composition_7.jpg 196608/195196 [==============================] - 0s 0us/step
入力を視覚化する
画像をロードし、その最大サイズを512ピクセルに制限する関数を定義します。
def load_img(path_to_img):
max_dim = 512
img = tf.io.read_file(path_to_img)
img = tf.image.decode_image(img, channels=3)
img = tf.image.convert_image_dtype(img, tf.float32)
shape = tf.cast(tf.shape(img)[:-1], tf.float32)
long_dim = max(shape)
scale = max_dim / long_dim
new_shape = tf.cast(shape * scale, tf.int32)
img = tf.image.resize(img, new_shape)
img = img[tf.newaxis, :]
return img
画像を表示する簡単な関数を作成します。
def imshow(image, title=None):
if len(image.shape) > 3:
image = tf.squeeze(image, axis=0)
plt.imshow(image)
if title:
plt.title(title)
content_image = load_img(content_path)
style_image = load_img(style_path)
plt.subplot(1, 2, 1)
imshow(content_image, 'Content Image')
plt.subplot(1, 2, 2)
imshow(style_image, 'Style Image')
TFハブを使用した高速スタイル転送
このチュートリアルでは、画像コンテンツを特定のスタイルに最適化する、元のスタイル転送アルゴリズムを示します。詳細に入る前に、 TensorFlowハブモデルがこれをどのように行うかを見てみましょう。
import tensorflow_hub as hub
hub_model = hub.load('https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2')
stylized_image = hub_model(tf.constant(content_image), tf.constant(style_image))[0]
tensor_to_image(stylized_image)
コンテンツとスタイルの表現を定義する
モデルの中間レイヤーを使用して、画像のコンテンツとスタイルの表現を取得します。ネットワークの入力レイヤーから始めて、最初のいくつかのレイヤーのアクティブ化は、エッジやテクスチャなどの低レベルの機能を表します。ネットワークをステップスルーすると、最後のいくつかのレイヤーは、より高いレベルの機能(ホイールや目などのオブジェクトパーツ)を表します。この場合、事前にトレーニングされた画像分類ネットワークであるVGG19ネットワークアーキテクチャを使用しています。これらの中間レイヤーは、画像からコンテンツとスタイルの表現を定義するために必要です。入力画像の場合、これらの中間レイヤーで対応するスタイルとコンテンツターゲットの表現を一致させてみてください。
VGG19をロードし、イメージでテスト実行して、正しく使用されていることを確認します。
x = tf.keras.applications.vgg19.preprocess_input(content_image*255)
x = tf.image.resize(x, (224, 224))
vgg = tf.keras.applications.VGG19(include_top=True, weights='imagenet')
prediction_probabilities = vgg(x)
prediction_probabilities.shape
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg19/vgg19_weights_tf_dim_ordering_tf_kernels.h5 574717952/574710816 [==============================] - 3s 0us/step TensorShape([1, 1000])
predicted_top_5 = tf.keras.applications.vgg19.decode_predictions(prediction_probabilities.numpy())[0]
[(class_name, prob) for (number, class_name, prob) in predicted_top_5]
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/imagenet_class_index.json 40960/35363 [==================================] - 0s 0us/step [('Labrador_retriever', 0.49317107), ('golden_retriever', 0.23665291), ('kuvasz', 0.03635751), ('Chesapeake_Bay_retriever', 0.024182765), ('Greater_Swiss_Mountain_dog', 0.018646102)]
次に、分類ヘッドなしでVGG19
をロードし、レイヤー名をリストします。
vgg = tf.keras.applications.VGG19(include_top=False, weights='imagenet')
print()
for layer in vgg.layers:
print(layer.name)
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg19/vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5 80142336/80134624 [==============================] - 2s 0us/step input_2 block1_conv1 block1_conv2 block1_pool block2_conv1 block2_conv2 block2_pool block3_conv1 block3_conv2 block3_conv3 block3_conv4 block3_pool block4_conv1 block4_conv2 block4_conv3 block4_conv4 block4_pool block5_conv1 block5_conv2 block5_conv3 block5_conv4 block5_pool
ネットワークから中間レイヤーを選択して、画像のスタイルとコンテンツを表します。
content_layers = ['block5_conv2']
style_layers = ['block1_conv1',
'block2_conv1',
'block3_conv1',
'block4_conv1',
'block5_conv1']
num_content_layers = len(content_layers)
num_style_layers = len(style_layers)
スタイルとコンテンツの中間レイヤー
では、事前にトレーニングされた画像分類ネットワーク内のこれらの中間出力により、スタイルとコンテンツの表現を定義できるのはなぜですか?
大まかに言えば、ネットワークが画像分類(このネットワークが行うように訓練されている)を実行するためには、画像を理解する必要があります。これには、生の画像を入力ピクセルとして取得し、生の画像のピクセルを画像内に存在する特徴の複雑な理解に変換する内部表現を構築する必要があります。
これは、畳み込みニューラルネットワークがうまく一般化できる理由でもあります。それらは、バックグラウンドノイズやその他の迷惑にとらわれないクラス(猫と犬など)内の不変性と定義機能をキャプチャできます。したがって、生の画像がモデルに入力される場所と出力分類ラベルの間のどこかで、モデルは複雑な特徴抽出器として機能します。モデルの中間レイヤーにアクセスすることで、入力画像のコンテンツとスタイルを説明できます。
モデルを構築する
tf.keras.applications
のネットワークは、 tf.keras.applications
機能APIを使用して中間層の値を簡単に抽出できるように設計されています。
機能APIを使用してモデルを定義するには、入力と出力を指定します。
model = Model(inputs, outputs)
次の関数は、中間層の出力のリストを返すVGG19モデルを作成します。
def vgg_layers(layer_names):
""" Creates a vgg model that returns a list of intermediate output values."""
# Load our model. Load pretrained VGG, trained on imagenet data
vgg = tf.keras.applications.VGG19(include_top=False, weights='imagenet')
vgg.trainable = False
outputs = [vgg.get_layer(name).output for name in layer_names]
model = tf.keras.Model([vgg.input], outputs)
return model
そして、モデルを作成するには:
style_extractor = vgg_layers(style_layers)
style_outputs = style_extractor(style_image*255)
#Look at the statistics of each layer's output
for name, output in zip(style_layers, style_outputs):
print(name)
print(" shape: ", output.numpy().shape)
print(" min: ", output.numpy().min())
print(" max: ", output.numpy().max())
print(" mean: ", output.numpy().mean())
print()
block1_conv1 shape: (1, 336, 512, 64) min: 0.0 max: 835.5256 mean: 33.97525 block2_conv1 shape: (1, 168, 256, 128) min: 0.0 max: 4625.8857 mean: 199.82687 block3_conv1 shape: (1, 84, 128, 256) min: 0.0 max: 8789.239 mean: 230.78099 block4_conv1 shape: (1, 42, 64, 512) min: 0.0 max: 21566.135 mean: 791.24005 block5_conv1 shape: (1, 21, 32, 512) min: 0.0 max: 3189.2542 mean: 59.179478
スタイルを計算する
画像の内容は、中間特徴マップの値で表されます。
結局のところ、画像のスタイルは、さまざまな機能マップ間の手段と相関関係によって説明できます。特徴ベクトルの外積を各位置でそれ自体と取り、その外積をすべての場所で平均することにより、この情報を含むグラム行列を計算します。このグラム行列は、特定のレイヤーについて次のように計算できます。
これは、 tf.linalg.einsum
関数を使用して簡潔に実装できます。
def gram_matrix(input_tensor):
result = tf.linalg.einsum('bijc,bijd->bcd', input_tensor, input_tensor)
input_shape = tf.shape(input_tensor)
num_locations = tf.cast(input_shape[1]*input_shape[2], tf.float32)
return result/(num_locations)
スタイルとコンテンツを抽出する
スタイルとコンテンツテンサーを返すモデルを作成します。
class StyleContentModel(tf.keras.models.Model):
def __init__(self, style_layers, content_layers):
super(StyleContentModel, self).__init__()
self.vgg = vgg_layers(style_layers + content_layers)
self.style_layers = style_layers
self.content_layers = content_layers
self.num_style_layers = len(style_layers)
self.vgg.trainable = False
def call(self, inputs):
"Expects float input in [0,1]"
inputs = inputs*255.0
preprocessed_input = tf.keras.applications.vgg19.preprocess_input(inputs)
outputs = self.vgg(preprocessed_input)
style_outputs, content_outputs = (outputs[:self.num_style_layers],
outputs[self.num_style_layers:])
style_outputs = [gram_matrix(style_output)
for style_output in style_outputs]
content_dict = {content_name:value
for content_name, value
in zip(self.content_layers, content_outputs)}
style_dict = {style_name:value
for style_name, value
in zip(self.style_layers, style_outputs)}
return {'content':content_dict, 'style':style_dict}
画像で呼び出されると、このモデルはstyle_layers
グラム行列(スタイル)とcontent_layers
コンテンツを返します。
extractor = StyleContentModel(style_layers, content_layers)
results = extractor(tf.constant(content_image))
print('Styles:')
for name, output in sorted(results['style'].items()):
print(" ", name)
print(" shape: ", output.numpy().shape)
print(" min: ", output.numpy().min())
print(" max: ", output.numpy().max())
print(" mean: ", output.numpy().mean())
print()
print("Contents:")
for name, output in sorted(results['content'].items()):
print(" ", name)
print(" shape: ", output.numpy().shape)
print(" min: ", output.numpy().min())
print(" max: ", output.numpy().max())
print(" mean: ", output.numpy().mean())
Styles: block1_conv1 shape: (1, 64, 64) min: 0.0055228462 max: 28014.562 mean: 263.79025 block2_conv1 shape: (1, 128, 128) min: 0.0 max: 61479.49 mean: 9100.949 block3_conv1 shape: (1, 256, 256) min: 0.0 max: 545623.44 mean: 7660.976 block4_conv1 shape: (1, 512, 512) min: 0.0 max: 4320502.0 mean: 134288.84 block5_conv1 shape: (1, 512, 512) min: 0.0 max: 110005.34 mean: 1487.0381 Contents: block5_conv2 shape: (1, 26, 32, 512) min: 0.0 max: 2410.8796 mean: 13.764149
勾配降下を実行します
このスタイルとコンテンツエクストラクタを使用すると、スタイル転送アルゴリズムを実装できます。これを行うには、各ターゲットに対する画像の出力の平均二乗誤差を計算し、これらの損失の加重和を取ります。
スタイルとコンテンツのターゲット値を設定します。
style_targets = extractor(style_image)['style']
content_targets = extractor(content_image)['content']
最適化する画像を含むtf.Variable
を定義します。これをすばやく行うには、コンテンツイメージで初期化します( tf.Variable
はコンテンツイメージと同じ形状である必要があります)。
image = tf.Variable(content_image)
これはフロート画像なので、ピクセル値を0から1の間に保つ関数を定義します。
def clip_0_1(image):
return tf.clip_by_value(image, clip_value_min=0.0, clip_value_max=1.0)
オプティマイザを作成します。この論文ではLBFGSを推奨していますが、 Adam
も問題なく機能します。
opt = tf.optimizers.Adam(learning_rate=0.02, beta_1=0.99, epsilon=1e-1)
これを最適化するには、2つの損失の加重組み合わせを使用して、合計損失を取得します。
style_weight=1e-2
content_weight=1e4
def style_content_loss(outputs):
style_outputs = outputs['style']
content_outputs = outputs['content']
style_loss = tf.add_n([tf.reduce_mean((style_outputs[name]-style_targets[name])**2)
for name in style_outputs.keys()])
style_loss *= style_weight / num_style_layers
content_loss = tf.add_n([tf.reduce_mean((content_outputs[name]-content_targets[name])**2)
for name in content_outputs.keys()])
content_loss *= content_weight / num_content_layers
loss = style_loss + content_loss
return loss
tf.GradientTape
を使用して画像を更新します。
@tf.function()
def train_step(image):
with tf.GradientTape() as tape:
outputs = extractor(image)
loss = style_content_loss(outputs)
grad = tape.gradient(loss, image)
opt.apply_gradients([(grad, image)])
image.assign(clip_0_1(image))
次に、いくつかの手順を実行してテストします。
train_step(image)
train_step(image)
train_step(image)
tensor_to_image(image)
動作しているので、より長い最適化を実行します。
import time
start = time.time()
epochs = 10
steps_per_epoch = 100
step = 0
for n in range(epochs):
for m in range(steps_per_epoch):
step += 1
train_step(image)
print(".", end='')
display.clear_output(wait=True)
display.display(tensor_to_image(image))
print("Train step: {}".format(step))
end = time.time()
print("Total time: {:.1f}".format(end-start))
Train step: 1000 Total time: 20.3
全変動損失
この基本的な実装の欠点の1つは、高頻度のアーティファクトが多数生成されることです。画像の高周波成分に明示的な正則化項を使用して、これらを減らします。スタイル転送では、これはしばしば総変動損失と呼ばれます:
def high_pass_x_y(image):
x_var = image[:,:,1:,:] - image[:,:,:-1,:]
y_var = image[:,1:,:,:] - image[:,:-1,:,:]
return x_var, y_var
x_deltas, y_deltas = high_pass_x_y(content_image)
plt.figure(figsize=(14,10))
plt.subplot(2,2,1)
imshow(clip_0_1(2*y_deltas+0.5), "Horizontal Deltas: Original")
plt.subplot(2,2,2)
imshow(clip_0_1(2*x_deltas+0.5), "Vertical Deltas: Original")
x_deltas, y_deltas = high_pass_x_y(image)
plt.subplot(2,2,3)
imshow(clip_0_1(2*y_deltas+0.5), "Horizontal Deltas: Styled")
plt.subplot(2,2,4)
imshow(clip_0_1(2*x_deltas+0.5), "Vertical Deltas: Styled")
これは、高周波成分がどのように増加したかを示しています。
また、この高周波成分は基本的にエッジ検出器です。 Sobelエッジ検出器から同様の出力を取得できます。次に例を示します。
plt.figure(figsize=(14,10))
sobel = tf.image.sobel_edges(content_image)
plt.subplot(1,2,1)
imshow(clip_0_1(sobel[...,0]/4+0.5), "Horizontal Sobel-edges")
plt.subplot(1,2,2)
imshow(clip_0_1(sobel[...,1]/4+0.5), "Vertical Sobel-edges")
これに関連する正則化損失は、値の2乗の合計です。
def total_variation_loss(image):
x_deltas, y_deltas = high_pass_x_y(image)
return tf.reduce_sum(tf.abs(x_deltas)) + tf.reduce_sum(tf.abs(y_deltas))
total_variation_loss(image).numpy()
149362.55
それはそれが何をするかを示しました。ただし、自分で実装する必要はありません。TensorFlowには標準の実装が含まれています。
tf.image.total_variation(image).numpy()
array([149362.55], dtype=float32)
最適化を再実行します
total_variation_loss
重みを選択してtotal_variation_loss
:
total_variation_weight=30
次に、 train_step
関数にtrain_step
ます。
@tf.function()
def train_step(image):
with tf.GradientTape() as tape:
outputs = extractor(image)
loss = style_content_loss(outputs)
loss += total_variation_weight*tf.image.total_variation(image)
grad = tape.gradient(loss, image)
opt.apply_gradients([(grad, image)])
image.assign(clip_0_1(image))
最適化変数を再初期化します。
image = tf.Variable(content_image)
そして、最適化を実行します。
import time
start = time.time()
epochs = 10
steps_per_epoch = 100
step = 0
for n in range(epochs):
for m in range(steps_per_epoch):
step += 1
train_step(image)
print(".", end='')
display.clear_output(wait=True)
display.display(tensor_to_image(image))
print("Train step: {}".format(step))
end = time.time()
print("Total time: {:.1f}".format(end-start))
Train step: 1000 Total time: 21.4
最後に、結果を保存します。
file_name = 'stylized-image.png'
tensor_to_image(image).save(file_name)
try:
from google.colab import files
except ImportError:
pass
else:
files.download(file_name)
もっと詳しく知る
このチュートリアルでは、元のスタイル転送アルゴリズムを示します。スタイル転送の簡単なアプリケーションについては、このチュートリアルをチェックして、 TensorFlowハブから任意の画像スタイル転送モデルを使用する方法の詳細を確認してください。