![]() |
![]() |
![]() |
![]() |
このチュートリアルには、Alexander Mordvintsev によるこちらのブログ記事で説明された DeepDream の最小限の実装が含まれます。
DeepDream はニューラルネットワークが学習したパターンを視覚化する実験です。子供が雲を見てなんらかの形に解釈しようとするのと同様に、DeepDream は過解釈を行って、画像に見いだせるパターンの精度を強化します。
ネットワークを通じて画像を転送し、特定のレイヤーのアクティベーションに関して画像の勾配を計算することで行われています。画像は、これらのアクティベーションを変更しながら、ネットワークに見られるパターンを強化して、夢の中のようなイメージを作り出します。このプロセスは、InceptionNet と、映画「インセプション」の因んで、「インセプショニズム」と呼ばれています。
では、ニューラルネットワークに「夢を見させて」、画像に見いだすシュールなパターンを強化する方法を実演することにしましょう。
import tensorflow as tf
import numpy as np
import matplotlib as mpl
import IPython.display as display
import PIL.Image
from tensorflow.keras.preprocessing import image
ドリーム化する画像を選択する
このチュートリアルでは、ラブラドールの画像を使用しましょう。
url = 'https://storage.googleapis.com/download.tensorflow.org/example_images/YellowLabradorLooking_new.jpg'
# Download an image and read it into a NumPy array.
def download(url, max_dim=None):
name = url.split('/')[-1]
image_path = tf.keras.utils.get_file(name, origin=url)
img = PIL.Image.open(image_path)
if max_dim:
img.thumbnail((max_dim, max_dim))
return np.array(img)
# Normalize an image
def deprocess(img):
img = 255*(img + 1.0)/2.0
return tf.cast(img, tf.uint8)
# Display an image
def show(img):
display.display(PIL.Image.fromarray(np.array(img)))
# Downsizing the image makes it easier to work with.
original_img = download(url, max_dim=500)
show(original_img)
display.display(display.HTML('Image cc-by: <a "href=https://commons.wikimedia.org/wiki/File:Felis_catus-cat_on_snow.jpg">Von.grzanka</a>'))
特徴抽出モデルを準備する
事前トレーニング済みの画像分類モデルをダウンロードして準備します。もともと DeepDream で使用されたモデルに似た InceptionV3 を使用します。任意の事前トレーニング済みのモデルを使用することができますが、レイヤー名を変更する場合は、以下のように調整する必要があります。
base_model = tf.keras.applications.InceptionV3(include_top=False, weights='imagenet')
2021-08-14 06:18:44.581540: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero 2021-08-14 06:18:44.588184: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero 2021-08-14 06:18:44.589137: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero 2021-08-14 06:18:44.590701: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX512F FMA To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags. 2021-08-14 06:18:44.591179: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero 2021-08-14 06:18:44.592055: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero 2021-08-14 06:18:44.592876: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero 2021-08-14 06:18:45.166523: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero 2021-08-14 06:18:45.167590: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero 2021-08-14 06:18:45.168510: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero 2021-08-14 06:18:45.169320: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1510] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 14648 MB memory: -> device: 0, name: Tesla V100-SXM2-16GB, pci bus id: 0000:00:05.0, compute capability: 7.0 Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/inception_v3/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5 87916544/87910968 [==============================] - 1s 0us/step 87924736/87910968 [==============================] - 1s 0us/step
DeepDream の考え方は、レイヤーを選択して、画像がレイヤーを徐々に「刺激する」ように「損失」を最大化することです。追加する特徴量の複雑さは、あなたが選択するレイヤーによって異なり、レイヤーが低ければストロークや単純なパターンを生成し、レイヤーが深くなるほど、画像または画像全体の特徴がより洗練されることになります。
InceptionV3 アーキテクチャは非常に大型です(モデルアーキテクチャのグラフについては、TensorFlow の research リポジトリをご覧ください)。DeepDream では、対象のレイヤーは畳み込みが連結されている場所です。こういったレイヤーは InceptionV3 には 11 個あり、'mixed0' から 'mixed10' の名前が付けられています。異なるレイヤーを使用すると、異なった夢のような画像が生成されます。レイヤーが深くなるほどより高度な特徴(目や顔など)に対応し、浅いほどよりシンプルな特徴(エッジ、形状、テクスチャなど)に対応します。以下で選択するレイヤーを自由に試してみてください。ただし、レイヤーが深くなるほど(インデックスが高いレイヤー)、勾配の計算がより深くなるため、トレーニングに時間がかかることに注意してください。
# Maximize the activations of these layers
names = ['mixed3', 'mixed5']
layers = [base_model.get_layer(name).output for name in names]
# Create the feature extraction model
dream_model = tf.keras.Model(inputs=base_model.input, outputs=layers)
損失を計算する
損失は、選択されたレイヤーのアクティベーションの和です。損失はレイヤーごとに正規化されるため、より大きなレイヤーからの貢献は小さなレイヤーを上回らないようになっています。通常、損失は、勾配降下法で最小化する量ですが、DeepDream では、勾配上昇法によってこの損失を最大化します。
def calc_loss(img, model):
# Pass forward the image through the model to retrieve the activations.
# Converts the image into a batch of size 1.
img_batch = tf.expand_dims(img, axis=0)
layer_activations = model(img_batch)
if len(layer_activations) == 1:
layer_activations = [layer_activations]
losses = []
for act in layer_activations:
loss = tf.math.reduce_mean(act)
losses.append(loss)
return tf.reduce_sum(losses)
勾配上昇法
選択したレイヤーの損失を計算したら、後は画像に関して勾配を計算し、それを元の画像に追加するだけです。
画像に勾配を追加すると、ネットワークが見るパターンの精度が上がります。各ステップで、ネットワークの特定のレイヤーのアクティベーションを徐々に刺激する画像を作成することになります。
これを行うメソッドは、パフォーマンスを得られるように tf.function
でラッピングされます。input_signature
を使用するため、さまざまな画像サイズまたは steps
/step_size
値で関数が再トレースされないようになっています。詳細は、具象関数ガイドをご覧ください。
class DeepDream(tf.Module):
def __init__(self, model):
self.model = model
@tf.function(
input_signature=(
tf.TensorSpec(shape=[None,None,3], dtype=tf.float32),
tf.TensorSpec(shape=[], dtype=tf.int32),
tf.TensorSpec(shape=[], dtype=tf.float32),)
)
def __call__(self, img, steps, step_size):
print("Tracing")
loss = tf.constant(0.0)
for n in tf.range(steps):
with tf.GradientTape() as tape:
# This needs gradients relative to `img`
# `GradientTape` only watches `tf.Variable`s by default
tape.watch(img)
loss = calc_loss(img, self.model)
# Calculate the gradient of the loss with respect to the pixels of the input image.
gradients = tape.gradient(loss, img)
# Normalize the gradients.
gradients /= tf.math.reduce_std(gradients) + 1e-8
# In gradient ascent, the "loss" is maximized so that the input image increasingly "excites" the layers.
# You can update the image by directly adding the gradients (because they're the same shape!)
img = img + gradients*step_size
img = tf.clip_by_value(img, -1, 1)
return loss, img
deepdream = DeepDream(dream_model)
メインのループ
def run_deep_dream_simple(img, steps=100, step_size=0.01):
# Convert from uint8 to the range expected by the model.
img = tf.keras.applications.inception_v3.preprocess_input(img)
img = tf.convert_to_tensor(img)
step_size = tf.convert_to_tensor(step_size)
steps_remaining = steps
step = 0
while steps_remaining:
if steps_remaining>100:
run_steps = tf.constant(100)
else:
run_steps = tf.constant(steps_remaining)
steps_remaining -= run_steps
step += run_steps
loss, img = deepdream(img, run_steps, tf.constant(step_size))
display.clear_output(wait=True)
show(deprocess(img))
print ("Step {}, loss {}".format(step, loss))
result = deprocess(img)
display.clear_output(wait=True)
show(result)
return result
dream_img = run_deep_dream_simple(img=original_img,
steps=100, step_size=0.01)
オクターブを実行する
ここまでで非常に素晴らしいものではありますが、この最初の試行にはいくつかの問題があります。
- 出力にノイズがある(
tf.image.total_variation
損失で解消可能)。 - 画像解像度が低い。
- パターンが同じ粒度で発生しているように見える。
上記のすべての問題を解決するには、1 つのアプローチとして、異なるスケールで勾配上昇法を適用することが挙げられます。こうすれば、より小さなスケールで生成されたパターンをより高いスケールのパターンに統合して、追加の詳細で満たすことができます。
これを行うには、上述の勾配上昇法を実行してから、画像のサイズを増加し(これをオクターブと呼びます)、このプロセスを複数のオクターブで繰り返します。
import time
start = time.time()
OCTAVE_SCALE = 1.30
img = tf.constant(np.array(original_img))
base_shape = tf.shape(img)[:-1]
float_base_shape = tf.cast(base_shape, tf.float32)
for n in range(-2, 3):
new_shape = tf.cast(float_base_shape*(OCTAVE_SCALE**n), tf.int32)
img = tf.image.resize(img, new_shape).numpy()
img = run_deep_dream_simple(img=img, steps=50, step_size=0.01)
display.clear_output(wait=True)
img = tf.image.resize(img, base_shape)
img = tf.image.convert_image_dtype(img/255.0, dtype=tf.uint8)
show(img)
end = time.time()
end-start
6.839528560638428
オプション: タイルでスケールアップする
画像サイズが大きくなるにつれ、勾配計算の実行に必要な時間とメモリ量も高まるということに注意する必要があります。上記のオクターブ実装は、非常に大きな画像や多数のオクターブでは機能しません。
この問題を回避するには、画像をタイルに分割して、各タイルに対して勾配を計算することができます。
それぞれのタイル計算を行う前に画像にランダムシフトを適用すると、タイルの継ぎ目が現れなくなります。
ランダムシフトの実装から始めましょう。
def random_roll(img, maxroll):
# Randomly shift the image to avoid tiled boundaries.
shift = tf.random.uniform(shape=[2], minval=-maxroll, maxval=maxroll, dtype=tf.int32)
img_rolled = tf.roll(img, shift=shift, axis=[0,1])
return shift, img_rolled
shift, img_rolled = random_roll(np.array(original_img), 512)
show(img_rolled)
以下は、前に定義した deepdream
関数のタイルバージョンです。
class TiledGradients(tf.Module):
def __init__(self, model):
self.model = model
@tf.function(
input_signature=(
tf.TensorSpec(shape=[None,None,3], dtype=tf.float32),
tf.TensorSpec(shape=[], dtype=tf.int32),)
)
def __call__(self, img, tile_size=512):
shift, img_rolled = random_roll(img, tile_size)
# Initialize the image gradients to zero.
gradients = tf.zeros_like(img_rolled)
# Skip the last tile, unless there's only one tile.
xs = tf.range(0, img_rolled.shape[0], tile_size)[:-1]
if not tf.cast(len(xs), bool):
xs = tf.constant([0])
ys = tf.range(0, img_rolled.shape[1], tile_size)[:-1]
if not tf.cast(len(ys), bool):
ys = tf.constant([0])
for x in xs:
for y in ys:
# Calculate the gradients for this tile.
with tf.GradientTape() as tape:
# This needs gradients relative to `img_rolled`.
# `GradientTape` only watches `tf.Variable`s by default.
tape.watch(img_rolled)
# Extract a tile out of the image.
img_tile = img_rolled[x:x+tile_size, y:y+tile_size]
loss = calc_loss(img_tile, self.model)
# Update the image gradients for this tile.
gradients = gradients + tape.gradient(loss, img_rolled)
# Undo the random shift applied to the image and its gradients.
gradients = tf.roll(gradients, shift=-shift, axis=[0,1])
# Normalize the gradients.
gradients /= tf.math.reduce_std(gradients) + 1e-8
return gradients
get_tiled_gradients = TiledGradients(dream_model)
これを合わせると、スケーラブルなオクターブ対応の DeepDream 実装が得られます。
def run_deep_dream_with_octaves(img, steps_per_octave=100, step_size=0.01,
octaves=range(-2,3), octave_scale=1.3):
base_shape = tf.shape(img)
img = tf.keras.preprocessing.image.img_to_array(img)
img = tf.keras.applications.inception_v3.preprocess_input(img)
initial_shape = img.shape[:-1]
img = tf.image.resize(img, initial_shape)
for octave in octaves:
# Scale the image based on the octave
new_size = tf.cast(tf.convert_to_tensor(base_shape[:-1]), tf.float32)*(octave_scale**octave)
img = tf.image.resize(img, tf.cast(new_size, tf.int32))
for step in range(steps_per_octave):
gradients = get_tiled_gradients(img)
img = img + gradients*step_size
img = tf.clip_by_value(img, -1, 1)
if step % 10 == 0:
display.clear_output(wait=True)
show(deprocess(img))
print ("Octave {}, Step {}".format(octave, step))
result = deprocess(img)
return result
img = run_deep_dream_with_octaves(img=original_img, step_size=0.01)
display.clear_output(wait=True)
img = tf.image.resize(img, base_shape)
img = tf.image.convert_image_dtype(img/255.0, dtype=tf.uint8)
show(img)
断然に良くなりました!オクターブ、オクターブスケール、アクティベーションされたレイヤーをいろいろ試して、DeepDream 化された画像の変化を確認してみてください。
このチュートリアルで紹介した考え方をさらに拡大した、ニューラルネットワークの視覚化と解釈を行う TensorFlow Lucid というものもありますので、ぜひお試しください。