딥드림 (DeepDream)

TensorFlow.org에서 보기 구글 코랩(Colab)에서 실행하기 깃허브(GitHub) 소스 보기 Download notebook

본 튜토리얼은 Alexander Mordvintsev가 이 블로그 포스트에서 설명한 딥드림의 간단한 구현 방법에 대해 소개합니다.

딥드림은 신경망이 학습한 패턴을 시각화하는 실험입니다. 어린 아이가 구름을 보며 임의의 모양을 상상하는 것처럼, 딥드림은 주어진 이미지 내에 있는 패턴을 향상시키고 과잉해석합니다.

이 효과는 입력 이미지를 신경망을 통해 순전파 한 후, 특정 층의 활성화값에 대한 이미지의 그래디언트(gradient)를 계산함으로써 구현할 수 있습니다. 딥드림 알고리즘은 층의 활성화값을 최대화하도록 이미지를 수정하는데, 이는 신경망으로 하여금 특정 패턴을 과잉해석 하도록 합니다. 이로써 입력 이미지를 기반으로 하는 몽환적인 이미지를 만들어낼 수 있게 됩니다. 이 과정은 InceptionNet 모델과 영화 인셉션에서 이름을 따와 "Inceptionism"이라고 부릅니다.

그럼 이제 어떻게 하면 신경망으로 하여금 이미지에 있는 초현실적인 패턴을 향상시키고 “꿈”을 꾸도록 만들 수 있는지 알아보도록 하겠습니다.

Dogception

import tensorflow as tf
2022-12-15 01:01:25.870129: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2022-12-15 01:01:25.870248: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2022-12-15 01:01:25.870262: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
import numpy as np

import matplotlib as mpl

import IPython.display as display
import PIL.Image

변환(dream-ify)할 이미지 선택하기

이 튜토리얼에서는 래브라도의 이미지를 사용하겠습니다.

url = 'https://storage.googleapis.com/download.tensorflow.org/example_images/YellowLabradorLooking_new.jpg'
# Download an image and read it into a NumPy array.
def download(url, max_dim=None):
  name = url.split('/')[-1]
  image_path = tf.keras.utils.get_file(name, origin=url)
  img = PIL.Image.open(image_path)
  if max_dim:
    img.thumbnail((max_dim, max_dim))
  return np.array(img)

# Normalize an image
def deprocess(img):
  img = 255*(img + 1.0)/2.0
  return tf.cast(img, tf.uint8)

# Display an image
def show(img):
  display.display(PIL.Image.fromarray(np.array(img)))


# Downsizing the image makes it easier to work with.
original_img = download(url, max_dim=500)
show(original_img)
display.display(display.HTML('Image cc-by: <a "href=https://commons.wikimedia.org/wiki/File:Felis_catus-cat_on_snow.jpg">Von.grzanka</a>'))

png

특성 추출 모델 (feature extraction model) 준비하기

사전 훈련된 이미지 분류 모델을 내려받습니다. 본 튜토리얼에서는 딥드림에서 원래 사용된 모델과 유사한 InceptionV3를 사용합니다. 다른 사전 학습된 모델을 사용해도 괜찮지만, 코드 구현 시 층의 이름을 수정해야 합니다.

base_model = tf.keras.applications.InceptionV3(include_top=False, weights='imagenet')
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/inception_v3/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5
87910968/87910968 [==============================] - 1s 0us/step

딥드림은 활성화시킬 하나 혹은 그 이상의 층을 선택한 후 "손실"을 최대화하도록 이미지를 수정함으로써 선택한 층을 "흥분"시키는 원리를 기반으로 합니다. 얼마나 복잡한 특성이 나타날지는 선택한 층에 따라 다르게 됩니다. 낮은 층을 선택한다면 획 또는 간단한 패턴이 향상되고, 깊은 층을 선택한다면 이미지 내의 복잡한 패턴이나 심지어 물체의 모습도 생성할 수 있습니다.

InceptionV3는 상당히 큰 모델입니다 (모델 구조의 그래프는 텐서플로 연구 저장소에서 확인할 수 있습니다). 모델의 많은 층들 중, 딥드림을 구현하기 위해 살펴보아야 할 곳은 합성곱층들이 연결된 부분입니다. InceptionV3에는 'mixed0'부터 'mixed10'까지 총 11개의 이러한 합성곱층이 있습니다. 이 중 어떤 층을 선택하느냐에 따라서 딥드림 이미지의 모습이 결정됩니다. 깊은 층은 눈이나 얼굴과 같은 고차원 특성(higher-level features)에 반응하는 반면, 낮은 층은 선분이나 모양, 질감과 같은 저차원 특성에 반응합니다. 임의의 층을 선택해 자유롭게 실험해보는 것도 가능합니다. 다만 깊은 층(인덱스가 높은 층)은 훈련을 위한 그래디언트 계산에 시간이 오래 걸릴 수 있습니다.

# Maximize the activations of these layers
names = ['mixed3', 'mixed5']
layers = [base_model.get_layer(name).output for name in names]

# Create the feature extraction model
dream_model = tf.keras.Model(inputs=base_model.input, outputs=layers)

손실 계산하기

손실은 선택한 층들의 활성화값의 총 합으로 계산됩니다. 층의 크기와 상관 없이 모든 활성화값이 동일하게 고려될 수 있도록 각 층의 손실을 정규화합니다. 일반적으로, 손실은 경사하강법으로 최소화하고자 하는 수치입니다. 하지만 딥드림에서는 예외적으로 이 손실을 경사상승법(gradient ascent)을 통해 최대화할 것입니다.

def calc_loss(img, model):
  # Pass forward the image through the model to retrieve the activations.
  # Converts the image into a batch of size 1.
  img_batch = tf.expand_dims(img, axis=0)
  layer_activations = model(img_batch)
  if len(layer_activations) == 1:
    layer_activations = [layer_activations]

  losses = []
  for act in layer_activations:
    loss = tf.math.reduce_mean(act)
    losses.append(loss)

  return  tf.reduce_sum(losses)

경사상승법

선택한 층의 손실을 구했다면, 이제 남은 순서는 입력 이미지에 대한 그래디언트를 계산하여 원본 이미지에 추가하는 것입니다.

원본 이미지에 그래디언트를 더하는 것은 신경망이 보는 이미지 내의 패턴을 향상시키는 일에 해당합니다. 훈련이 진행될수록 신경망에서 선택한 층을 더욱더 활성화시키는 이미지를 생성할 수 있습니다.

이 작업을 수행하는 아래의 함수는 성능을 최적화하기 위해 tf.function으로 감쌉니다. 이 함수는 input_signature를 이용해 함수가 다른 이미지 크기 혹은 step/step_size값에 대해 트레이싱(tracing)되지 않도록 합니다. 보다 자세한 설명을 위해서는 Concrete functions guide를 참조합니다.

class DeepDream(tf.Module):
  def __init__(self, model):
    self.model = model

  @tf.function(
      input_signature=(
        tf.TensorSpec(shape=[None,None,3], dtype=tf.float32),
        tf.TensorSpec(shape=[], dtype=tf.int32),
        tf.TensorSpec(shape=[], dtype=tf.float32),)
  )
  def __call__(self, img, steps, step_size):
      print("Tracing")
      loss = tf.constant(0.0)
      for n in tf.range(steps):
        with tf.GradientTape() as tape:
          # This needs gradients relative to `img`
          # `GradientTape` only watches `tf.Variable`s by default
          tape.watch(img)
          loss = calc_loss(img, self.model)

        # Calculate the gradient of the loss with respect to the pixels of the input image.
        gradients = tape.gradient(loss, img)

        # Normalize the gradients.
        gradients /= tf.math.reduce_std(gradients) + 1e-8 

        # In gradient ascent, the "loss" is maximized so that the input image increasingly "excites" the layers.
        # You can update the image by directly adding the gradients (because they're the same shape!)
        img = img + gradients*step_size
        img = tf.clip_by_value(img, -1, 1)

      return loss, img
deepdream = DeepDream(dream_model)

주 루프

def run_deep_dream_simple(img, steps=100, step_size=0.01):
  # Convert from uint8 to the range expected by the model.
  img = tf.keras.applications.inception_v3.preprocess_input(img)
  img = tf.convert_to_tensor(img)
  step_size = tf.convert_to_tensor(step_size)
  steps_remaining = steps
  step = 0
  while steps_remaining:
    if steps_remaining>100:
      run_steps = tf.constant(100)
    else:
      run_steps = tf.constant(steps_remaining)
    steps_remaining -= run_steps
    step += run_steps

    loss, img = deepdream(img, run_steps, tf.constant(step_size))

    display.clear_output(wait=True)
    show(deprocess(img))
    print ("Step {}, loss {}".format(step, loss))


  result = deprocess(img)
  display.clear_output(wait=True)
  show(result)

  return result
dream_img = run_deep_dream_simple(img=original_img, 
                                  steps=100, step_size=0.01)

png

한 옥타브 (octave) 올라가기

지금 생성된 이미지도 상당히 인상적이지만, 위의 시도는 몇 가지 문제점들을 안고 있습니다:

  1. 생성된 이미지에 노이즈(noise)가 많이 끼어있습니다 (이 문제는 tf.image.total_variation loss로 해결할 수 있습니다).
  2. 생성된 이미지의 해상도가 낮습니다.
  3. 패턴들이 모두 균일한 입도(granularity)로 나타나고 있습니다.

이 문제들을 해결할 수 있는 한 가지 방법은 경사상승법의 스케일(scale)를 달리하여 여러 차례 적용하는 것입니다. 이는 작은 스케일에서 생성된 패턴들이 큰 스케일에서 생성된 패턴들에 녹아들어 더 많은 디테일을 형성할 수 있도록 해줍니다.

이 작업을 실행하기 위해서는 이전에 구현한 경사상승법을 사용한 후, 이미지의 크기(이를 옥타브라고 부릅니다)를 키우고 여러 옥타브에 대해 이 과정을 반복합니다.

import time
start = time.time()

OCTAVE_SCALE = 1.30

img = tf.constant(np.array(original_img))
base_shape = tf.shape(img)[:-1]
float_base_shape = tf.cast(base_shape, tf.float32)

for n in range(-2, 3):
  new_shape = tf.cast(float_base_shape*(OCTAVE_SCALE**n), tf.int32)

  img = tf.image.resize(img, new_shape).numpy()

  img = run_deep_dream_simple(img=img, steps=50, step_size=0.01)

display.clear_output(wait=True)
img = tf.image.resize(img, base_shape)
img = tf.image.convert_image_dtype(img/255.0, dtype=tf.uint8)
show(img)

end = time.time()
end-start

png

8.16125226020813

선택 사항: 타일(tile)을 이용해 이미지 확장하기

한 가지 고려할 사항은 이미지의 크기가 커질수록 그래디언트 계산에 소요되는 시간과 메모리가 늘어난다는 점입니다. 따라서 위에서 구현한 방식은 옥타브 수가 많거나 입력 이미지의 해상도가 높은 상황에는 적합하지 않습니다.

입력 이미지를 여러 타일로 나눠 각 타일에 대해 그래디언트를 계산하면 이 문제를 피할 수 있습니다.

각 타일별 그래디언트를 계산하기 전에 이미지를 랜덤하게 이동하면 타일 사이의 이음새가 나타나는 것을 방지할 수 있습니다.

이미지를 랜덤하게 이동시키는 작업부터 시작합니다:

def random_roll(img, maxroll):
  # Randomly shift the image to avoid tiled boundaries.
  shift = tf.random.uniform(shape=[2], minval=-maxroll, maxval=maxroll, dtype=tf.int32)
  img_rolled = tf.roll(img, shift=shift, axis=[0,1])
  return shift, img_rolled
shift, img_rolled = random_roll(np.array(original_img), 512)
show(img_rolled)

png

앞서 정의한 deepdream 함수에 타일 기능을 추가합니다:

class TiledGradients(tf.Module):
  def __init__(self, model):
    self.model = model

  @tf.function(
      input_signature=(
        tf.TensorSpec(shape=[None,None,3], dtype=tf.float32),
        tf.TensorSpec(shape=[2], dtype=tf.int32),
        tf.TensorSpec(shape=[], dtype=tf.int32),)
  )
  def __call__(self, img, img_size, tile_size=512):
    shift, img_rolled = random_roll(img, tile_size)

    # Initialize the image gradients to zero.
    gradients = tf.zeros_like(img_rolled)

    # Skip the last tile, unless there's only one tile.
    xs = tf.range(0, img_size[1], tile_size)[:-1]
    if not tf.cast(len(xs), bool):
      xs = tf.constant([0])
    ys = tf.range(0, img_size[0], tile_size)[:-1]
    if not tf.cast(len(ys), bool):
      ys = tf.constant([0])

    for x in xs:
      for y in ys:
        # Calculate the gradients for this tile.
        with tf.GradientTape() as tape:
          # This needs gradients relative to `img_rolled`.
          # `GradientTape` only watches `tf.Variable`s by default.
          tape.watch(img_rolled)

          # Extract a tile out of the image.
          img_tile = img_rolled[y:y+tile_size, x:x+tile_size]
          loss = calc_loss(img_tile, self.model)

        # Update the image gradients for this tile.
        gradients = gradients + tape.gradient(loss, img_rolled)

    # Undo the random shift applied to the image and its gradients.
    gradients = tf.roll(gradients, shift=-shift, axis=[0,1])

    # Normalize the gradients.
    gradients /= tf.math.reduce_std(gradients) + 1e-8 

    return gradients
get_tiled_gradients = TiledGradients(dream_model)

이 모든 것을 종합하면 확장 가능한 (scalabe) 옥타브 기반의 딥드림 구현이 완성됩니다:

def run_deep_dream_with_octaves(img, steps_per_octave=100, step_size=0.01, 
                                octaves=range(-2,3), octave_scale=1.3):
  base_shape = tf.shape(img)
  img = tf.keras.utils.img_to_array(img)
  img = tf.keras.applications.inception_v3.preprocess_input(img)

  initial_shape = img.shape[:-1]
  img = tf.image.resize(img, initial_shape)
  for octave in octaves:
    # Scale the image based on the octave
    new_size = tf.cast(tf.convert_to_tensor(base_shape[:-1]), tf.float32)*(octave_scale**octave)
    new_size = tf.cast(new_size, tf.int32)
    img = tf.image.resize(img, new_size)

    for step in range(steps_per_octave):
      gradients = get_tiled_gradients(img, new_size)
      img = img + gradients*step_size
      img = tf.clip_by_value(img, -1, 1)

      if step % 10 == 0:
        display.clear_output(wait=True)
        show(deprocess(img))
        print ("Octave {}, Step {}".format(octave, step))

  result = deprocess(img)
  return result
img = run_deep_dream_with_octaves(img=original_img, step_size=0.01)

display.clear_output(wait=True)
img = tf.image.resize(img, base_shape)
img = tf.image.convert_image_dtype(img/255.0, dtype=tf.uint8)
show(img)

png

지금까지 본 이미지보다 더 흥미로운 결과가 생성되었습니다! 이제는 옥타브의 수, 스케일, 그리고 활성화 층을 달리해가며 어떤 딥드림 이미지를 만들 수 있는지 실험해볼 차례입니다.

이 튜토리얼에서 소개된 기법 외에 신경망의 활성화를 해석하고 시각화할 수 있는 더 많은 방법에 대해 알아보고 싶다면 텐서플로 루시드를 참고합니다.