CelebA Progressive GAN モデルで人工顔を生成する

TensorFlow.org で表示 Google Colabで実行 GitHub でソースを表示 ノートブックをダウンロード TF Hub モデルを参照

この Colab では、生成敵対的ネットワーク(GAN)に基づく TF-Hub モジュールの使用を実現します。このモジュールは、潜在空間と呼ばれる N 次元のベクトルから RGB 画像へのマッピングを行います。

次の 2 つの例が提供されています。

  • 潜在空間から画像へのマッピング
  • 特定のターゲット画像がある場合、ターゲット画像に似た画像を生成する潜在ベクトルを求めるために勾配降下を使用する。

オプションの前提条件

その他のモデル

こちらでは、現在 tfhub.dev にホストされている、画像を生成できるすべてのモデルをご覧いただけます。

セットアップ

# Install imageio for creating animations.
pip -q install imageio
pip -q install scikit-image
pip install git+https://github.com/tensorflow/docs

Imports and function definitions

2024-01-11 19:47:45.233809: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-01-11 19:47:45.233856: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-01-11 19:47:45.235413: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered

潜在空間の補間

ランダムなベクトル

2 つのランダムに初期化されたベクトル間の潜在空間の補間です。トレーニング済みの Progressive GAN を含む TF-Hub モジュール progan-128 を使用します。

progan = hub.load("https://tfhub.dev/google/progan-128/1").signatures['default']
def interpolate_between_vectors():
  v1 = tf.random.normal([latent_dim])
  v2 = tf.random.normal([latent_dim])

  # Creates a tensor with 25 steps of interpolation between v1 and v2.
  vectors = interpolate_hypersphere(v1, v2, 50)

  # Uses module to generate images from the latent space.
  interpolated_images = progan(vectors)['default']

  return interpolated_images

interpolated_images = interpolate_between_vectors()
animate(interpolated_images)

gif

潜在空間の最も近いベクトルを見つける

ターゲット画像を修正します。例として、モジュールから生成された画像を使用するか、独自の画像をアップロードします。

image_from_module_space = True  # @param { isTemplate:true, type:"boolean" }

def get_module_space_image():
  vector = tf.random.normal([1, latent_dim])
  images = progan(vector)['default'][0]
  return images

def upload_image():
  uploaded = files.upload()
  image = imageio.imread(uploaded[list(uploaded.keys())[0]])
  return transform.resize(image, [128, 128])

if image_from_module_space:
  target_image = get_module_space_image()
else:
  target_image = upload_image()

display_image(target_image)

png

ターゲット画像と潜在空間変数によって生成された画像の間の損失関数を定義した後、勾配降下を使用して、損失を最小限に抑える変数を見つけることができます。

tf.random.set_seed(42)
initial_vector = tf.random.normal([1, latent_dim])
display_image(progan(initial_vector)['default'][0])

png

def find_closest_latent_vector(initial_vector, num_optimization_steps,
                               steps_per_image):
  images = []
  losses = []

  vector = tf.Variable(initial_vector)  
  optimizer = tf.optimizers.Adam(learning_rate=0.01)
  loss_fn = tf.losses.MeanAbsoluteError(reduction="sum")

  for step in range(num_optimization_steps):
    if (step % 100)==0:
      print()
    print('.', end='')
    with tf.GradientTape() as tape:
      image = progan(vector.read_value())['default'][0]
      if (step % steps_per_image) == 0:
        images.append(image.numpy())
      target_image_difference = loss_fn(image, target_image[:,:,:3])
      # The latent vectors were sampled from a normal distribution. We can get
      # more realistic images if we regularize the length of the latent vector to 
      # the average length of vector from this distribution.
      regularizer = tf.abs(tf.norm(vector) - np.sqrt(latent_dim))

      loss = target_image_difference + regularizer
      losses.append(loss.numpy())
    grads = tape.gradient(loss, [vector])
    optimizer.apply_gradients(zip(grads, [vector]))

  return images, losses


num_optimization_steps=200
steps_per_image=5
images, loss = find_closest_latent_vector(initial_vector, num_optimization_steps, steps_per_image)
.
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1705002501.170772  126050 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
...................................................................................................
....................................................................................................
plt.plot(loss)
plt.ylim([0,max(plt.ylim())])
(0.0, 6696.34951171875)

png

animate(np.stack(images))

gif

結果をターゲットと比較します。

display_image(np.concatenate([images[-1], target_image], axis=1))

png

上記の例を使って試す

画像がモジュールの空間から得たものである場合、降下は急であり、合理的なサンプルに収束します。モジュール空間からではない画像に降下法を試してみましょう。降下は、画像が合理的に、トレーニング画像の空間に近い場合にのみ収束します。

より現実的な画像への降下を高速化するには、次の項目を試すことができます。

  • 画像微分(二次微分など)に別の損失を使用する
  • 潜在ベクトルに別のレギュラライザーを使用する
  • 複数の実行において、ランダムなベクトルから初期化する
  • など