学習データの圧縮

TensorFlow.org で表示 Google Colab で実行 GitHub でソースを表示 ノートブックをダウンロード

概要

このノートブックでは、ニューラルネットワークと TensorFlow Compression を使って非可逆データ圧縮を行う方法を説明します。

非可逆圧縮には、レート、サンプルの安藤かに必要な期待されるビット数、およびサンプルの再構築における期待誤差を示すひずみ間のトレードオフが伴います。

以下の例では、オートエンコーダのようなモデルを使用して、MNIST データセットの画像を圧縮します。この手法は、『End-to-end Optimized Image Compression』という論文を基としています。

学習データの圧縮に関する背景情報については、古典的なデータ圧縮に精通した人を対象としたこちらの論文か、機械学習分野のユーザーを対象としたこちらの調査をご覧ください。

セットアップ

pip で Tensorflow Compression をインストールします。

# Installs the latest version of TFC compatible with the installed TF version.

read MAJOR MINOR <<< "$(pip show tensorflow | perl -p -0777 -e 's/.*Version: (\d+)\.(\d+).*/\1 \2/sg')"
pip install "tensorflow-compression<$MAJOR.$(($MINOR+1))"

ライブラリ依存関係をインポートします。

import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_compression as tfc
import tensorflow_datasets as tfds
2022-12-14 20:45:16.609619: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2022-12-14 20:45:16.609721: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2022-12-14 20:45:16.609732: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.

トレーナーモデルを定義する

このモデルはオートエンコーダに似ているため、またトレーニングと推論中に異なる機能を実行する必要があるため、このセットアップは、たとえば分類器などとは少し異なります。

トレーニングモデルは、以下の 3 つで構成されています。

  • 分析(またはエンコーダ)変換: 画像を潜在空間に変換します。
  • 合成(またはデコーダ)変換: 潜在空間から画像空間に変換します。
  • 事前確率とエントロピーモデル: 潜在空間の周辺分布をモデル化します。

まず、変換を定義します。

def make_analysis_transform(latent_dims):
  """Creates the analysis (encoder) transform."""
  return tf.keras.Sequential([
      tf.keras.layers.Conv2D(
          20, 5, use_bias=True, strides=2, padding="same",
          activation="leaky_relu", name="conv_1"),
      tf.keras.layers.Conv2D(
          50, 5, use_bias=True, strides=2, padding="same",
          activation="leaky_relu", name="conv_2"),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(
          500, use_bias=True, activation="leaky_relu", name="fc_1"),
      tf.keras.layers.Dense(
          latent_dims, use_bias=True, activation=None, name="fc_2"),
  ], name="analysis_transform")
def make_synthesis_transform():
  """Creates the synthesis (decoder) transform."""
  return tf.keras.Sequential([
      tf.keras.layers.Dense(
          500, use_bias=True, activation="leaky_relu", name="fc_1"),
      tf.keras.layers.Dense(
          2450, use_bias=True, activation="leaky_relu", name="fc_2"),
      tf.keras.layers.Reshape((7, 7, 50)),
      tf.keras.layers.Conv2DTranspose(
          20, 5, use_bias=True, strides=2, padding="same",
          activation="leaky_relu", name="conv_1"),
      tf.keras.layers.Conv2DTranspose(
          1, 5, use_bias=True, strides=2, padding="same",
          activation="leaky_relu", name="conv_2"),
  ], name="synthesis_transform")

トレーナーは、両方の変換のインスタンスと事前確率のパラメータを保有します。

その call メソッドは、以下を計算するようにセットアップされます。

  • レート: 数字のバッチを表現するために必要なビット数の推定
  • ひずみ: 元の数字と再構築された数字のピクセルの平均絶対差
class MNISTCompressionTrainer(tf.keras.Model):
  """Model that trains a compressor/decompressor for MNIST."""

  def __init__(self, latent_dims):
    super().__init__()
    self.analysis_transform = make_analysis_transform(latent_dims)
    self.synthesis_transform = make_synthesis_transform()
    self.prior_log_scales = tf.Variable(tf.zeros((latent_dims,)))

  @property
  def prior(self):
    return tfc.NoisyLogistic(loc=0., scale=tf.exp(self.prior_log_scales))

  def call(self, x, training):
    """Computes rate and distortion losses."""
    # Ensure inputs are floats in the range (0, 1).
    x = tf.cast(x, self.compute_dtype) / 255.
    x = tf.reshape(x, (-1, 28, 28, 1))

    # Compute latent space representation y, perturb it and model its entropy,
    # then compute the reconstructed pixel-level representation x_hat.
    y = self.analysis_transform(x)
    entropy_model = tfc.ContinuousBatchedEntropyModel(
        self.prior, coding_rank=1, compression=False)
    y_tilde, rate = entropy_model(y, training=training)
    x_tilde = self.synthesis_transform(y_tilde)

    # Average number of bits per MNIST digit.
    rate = tf.reduce_mean(rate)

    # Mean absolute difference across pixels.
    distortion = tf.reduce_mean(abs(x - x_tilde))

    return dict(rate=rate, distortion=distortion)

レートとひずみを計算する

では、トレーニングセットの画像を 1 つ使用して、順を追って説明します。トレーニングと検証用の MNIST データセットを読み込みます。

training_dataset, validation_dataset = tfds.load(
    "mnist",
    split=["train", "test"],
    shuffle_files=True,
    as_supervised=True,
    with_info=False,
)

1 つの画像 \(x\) を抽出します。

(x, _), = validation_dataset.take(1)

plt.imshow(tf.squeeze(x))
print(f"Data type: {x.dtype}")
print(f"Shape: {x.shape}")
Data type: <dtype: 'uint8'>
Shape: (28, 28, 1)
2022-12-14 20:45:23.216321: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

png

潜在の表現 \(y\) を取得するには、float32 にキャストして batch 次元を追加し、それを分析変換に通す必要があります。

x = tf.cast(x, tf.float32) / 255.
x = tf.reshape(x, (-1, 28, 28, 1))
y = make_analysis_transform(10)(x)

print("y:", y)
y: tf.Tensor(
[[ 0.03988452 -0.02631121 -0.05344866 -0.04364791  0.06735273 -0.00989169
  -0.05671643 -0.01362787 -0.0330795  -0.03137782]], shape=(1, 10), dtype=float32)

潜在は、テスト時に量子化されます。これをトレーニング中に区別可能な方法でモデル化するために、\((-.5, .5)\) の間隔で一様ノイズを追加し、その結果を \(\tilde y\) をとします。これは、『End-to-end Optimized Image Compression』論文で使用されているのと同じです。

y_tilde = y + tf.random.uniform(y.shape, -.5, .5)

print("y_tilde:", y_tilde)
y_tilde: tf.Tensor(
[[-0.00847201 -0.19141883  0.01641406 -0.3539529   0.10089394  0.23102134
   0.12243177  0.16727103  0.22074556  0.08011779]], shape=(1, 10), dtype=float32)

「事前確率」は、ノイズを含む潜在の周辺分布をモデル化するためにトレーニングする分布の密度です。たとえば、潜在次元ごとに異なるスケールを持つ独立した一連のロジスティック分布であることがあります。tfc.NoisyLogistic は、潜在には追加ノイズがあるという事実を考慮します。スケールがゼロに近づくにつれ、ロジスティック分布はディラックのデルタ(スパイク)に近づくものですが、追加ノイズにより、「ノイズの多い」分布は一様分布に近づきます。

prior = tfc.NoisyLogistic(loc=0., scale=tf.linspace(.01, 2., 10))

_ = tf.linspace(-6., 6., 501)[:, None]
plt.plot(_, prior.prob(_));

png

トレーニング中、tfc.ContinuousBatchedEntropyModel は一様ノイズを追加し、そのノイズと事前確率を使用して(区別可能な)レート(潜在表現をエンコードするために必要な平均ビット数)の上限を計算します。この上限は、損失として最小化できます。

entropy_model = tfc.ContinuousBatchedEntropyModel(
    prior, coding_rank=1, compression=False)
y_tilde, rate = entropy_model(y, training=True)

print("rate:", rate)
print("y_tilde:", y_tilde)
rate: tf.Tensor([18.430876], shape=(1,), dtype=float32)
y_tilde: tf.Tensor(
[[ 0.00818415  0.4172811  -0.05954609  0.3539252  -0.02196757  0.2851495
  -0.00319849 -0.15237509 -0.46015334  0.07735881]], shape=(1, 10), dtype=float32)

最後に、ノイズのある潜在が合成変換を通過し、画像の再構築 \(\tilde x\) が生成されます。明らかに、変換はトレーニングされていないため、この再構築にはあまり利用価値がありません。

x_tilde = make_synthesis_transform()(y_tilde)

# Mean absolute difference across pixels.
distortion = tf.reduce_mean(abs(x - x_tilde))
print("distortion:", distortion)

x_tilde = tf.saturate_cast(x_tilde[0] * 255, tf.uint8)
plt.imshow(tf.squeeze(x_tilde))
print(f"Data type: {x_tilde.dtype}")
print(f"Shape: {x_tilde.shape}")
distortion: tf.Tensor(0.17073585, shape=(), dtype=float32)
Data type: <dtype: 'uint8'>
Shape: (28, 28, 1)

png

数字のバッチごとに MNISTCompressionTrainer を呼び出すと、レートとそのバッチの平均としてのひずみが生成されます。

(example_batch, _), = validation_dataset.batch(32).take(1)
trainer = MNISTCompressionTrainer(10)
example_output = trainer(example_batch)

print("rate: ", example_output["rate"])
print("distortion: ", example_output["distortion"])
rate:  tf.Tensor(20.296253, shape=(), dtype=float32)
distortion:  tf.Tensor(0.14659302, shape=(), dtype=float32)
2022-12-14 20:45:25.322149: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

次のセクションでは、これらの 2 つの損失で勾配降下を行うようにモデルをセットアップします。

モデルをトレーニングする

レートとひずみのラグアンジアン、つまりレートとひずみの和を最適化するようにトレーナーをコンパイルします。ここで、いずれかの項はラグランジュ関数パラメータ \(\lambda\) で重み付けされます。

この損失関数は、モデルのさまざまな箇所に異なる影響を与えます。

  • 分析変換は、レートとひずみの目的のトレードオフを達成する潜在表現を生成するようにトレーニングされます。
  • 合成変換は、特定の潜在表現でひずみを最小化するようにトレーニングされます。
  • 事前確率のパラメータは、特定の潜在表現でレートを最小化するようにトレーニングされます。これは、事前確率を最大尤度において潜在の周辺分布に適合するのと同じです。
def pass_through_loss(_, x):
  # Since rate and distortion are unsupervised, the loss doesn't need a target.
  return x

def make_mnist_compression_trainer(lmbda, latent_dims=50):
  trainer = MNISTCompressionTrainer(latent_dims)
  trainer.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    # Just pass through rate and distortion as losses/metrics.
    loss=dict(rate=pass_through_loss, distortion=pass_through_loss),
    metrics=dict(rate=pass_through_loss, distortion=pass_through_loss),
    loss_weights=dict(rate=1., distortion=lmbda),
  )
  return trainer

次に、モデルをトレーニングします。ここでは、画像を圧縮するだけであるため、人間による注釈付けは必要ありません。そのため、map を使って注釈を削除し、レートとひずみの「ダミー」ターゲットを追加します。

def add_rd_targets(image, label):
  # Training is unsupervised, so labels aren't necessary here. However, we
  # need to add "dummy" targets for rate and distortion.
  return image, dict(rate=0., distortion=0.)

def train_mnist_model(lmbda):
  trainer = make_mnist_compression_trainer(lmbda)
  trainer.fit(
      training_dataset.map(add_rd_targets).batch(128).prefetch(8),
      epochs=15,
      validation_data=validation_dataset.map(add_rd_targets).batch(128).cache(),
      validation_freq=1,
      verbose=1,
  )
  return trainer

trainer = train_mnist_model(lmbda=2000)
Epoch 1/15
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
467/469 [============================>.] - ETA: 0s - loss: 217.8335 - distortion_loss: 0.0589 - rate_loss: 100.0286 - distortion_pass_through_loss: 0.0589 - rate_pass_through_loss: 100.0286
WARNING:absl:Computing quantization offsets using offset heuristic within a tf.function. Ideally, the offset heuristic should only be used to determine offsets once after training. Depending on the prior, estimating the offset might be computationally expensive.
469/469 [==============================] - 8s 8ms/step - loss: 217.6788 - distortion_loss: 0.0588 - rate_loss: 99.9971 - distortion_pass_through_loss: 0.0588 - rate_pass_through_loss: 99.9924 - val_loss: 178.6942 - val_distortion_loss: 0.0434 - val_rate_loss: 91.9543 - val_distortion_pass_through_loss: 0.0434 - val_rate_pass_through_loss: 91.9594
Epoch 2/15
469/469 [==============================] - 3s 6ms/step - loss: 166.7406 - distortion_loss: 0.0415 - rate_loss: 83.8010 - distortion_pass_through_loss: 0.0415 - rate_pass_through_loss: 83.7965 - val_loss: 157.1429 - val_distortion_loss: 0.0409 - val_rate_loss: 75.3245 - val_distortion_pass_through_loss: 0.0409 - val_rate_pass_through_loss: 75.3312
Epoch 3/15
469/469 [==============================] - 3s 6ms/step - loss: 151.3889 - distortion_loss: 0.0402 - rate_loss: 70.9442 - distortion_pass_through_loss: 0.0402 - rate_pass_through_loss: 70.9411 - val_loss: 144.4897 - val_distortion_loss: 0.0403 - val_rate_loss: 63.9764 - val_distortion_pass_through_loss: 0.0402 - val_rate_pass_through_loss: 63.9807
Epoch 4/15
469/469 [==============================] - 3s 6ms/step - loss: 142.9433 - distortion_loss: 0.0400 - rate_loss: 63.0224 - distortion_pass_through_loss: 0.0400 - rate_pass_through_loss: 63.0202 - val_loss: 137.4160 - val_distortion_loss: 0.0411 - val_rate_loss: 55.3039 - val_distortion_pass_through_loss: 0.0411 - val_rate_pass_through_loss: 55.2885
Epoch 5/15
469/469 [==============================] - 3s 6ms/step - loss: 137.3771 - distortion_loss: 0.0395 - rate_loss: 58.3224 - distortion_pass_through_loss: 0.0395 - rate_pass_through_loss: 58.3205 - val_loss: 132.2905 - val_distortion_loss: 0.0417 - val_rate_loss: 48.9274 - val_distortion_pass_through_loss: 0.0417 - val_rate_pass_through_loss: 48.9382
Epoch 6/15
469/469 [==============================] - 3s 6ms/step - loss: 133.5226 - distortion_loss: 0.0391 - rate_loss: 55.3185 - distortion_pass_through_loss: 0.0391 - rate_pass_through_loss: 55.3175 - val_loss: 127.0724 - val_distortion_loss: 0.0404 - val_rate_loss: 46.3232 - val_distortion_pass_through_loss: 0.0404 - val_rate_pass_through_loss: 46.3234
Epoch 7/15
469/469 [==============================] - 3s 6ms/step - loss: 130.3693 - distortion_loss: 0.0386 - rate_loss: 53.1581 - distortion_pass_through_loss: 0.0386 - rate_pass_through_loss: 53.1566 - val_loss: 123.5252 - val_distortion_loss: 0.0403 - val_rate_loss: 42.8875 - val_distortion_pass_through_loss: 0.0403 - val_rate_pass_through_loss: 42.8826
Epoch 8/15
469/469 [==============================] - 3s 6ms/step - loss: 128.0058 - distortion_loss: 0.0383 - rate_loss: 51.4280 - distortion_pass_through_loss: 0.0383 - rate_pass_through_loss: 51.4268 - val_loss: 121.3483 - val_distortion_loss: 0.0400 - val_rate_loss: 41.3487 - val_distortion_pass_through_loss: 0.0400 - val_rate_pass_through_loss: 41.3539
Epoch 9/15
469/469 [==============================] - 3s 6ms/step - loss: 125.6857 - distortion_loss: 0.0379 - rate_loss: 49.9369 - distortion_pass_through_loss: 0.0379 - rate_pass_through_loss: 49.9354 - val_loss: 119.4494 - val_distortion_loss: 0.0398 - val_rate_loss: 39.8691 - val_distortion_pass_through_loss: 0.0398 - val_rate_pass_through_loss: 39.8512
Epoch 10/15
469/469 [==============================] - 3s 6ms/step - loss: 123.4883 - distortion_loss: 0.0375 - rate_loss: 48.5796 - distortion_pass_through_loss: 0.0375 - rate_pass_through_loss: 48.5789 - val_loss: 118.5806 - val_distortion_loss: 0.0391 - val_rate_loss: 40.3094 - val_distortion_pass_through_loss: 0.0392 - val_rate_pass_through_loss: 40.3033
Epoch 11/15
469/469 [==============================] - 3s 6ms/step - loss: 121.5731 - distortion_loss: 0.0371 - rate_loss: 47.4418 - distortion_pass_through_loss: 0.0371 - rate_pass_through_loss: 47.4408 - val_loss: 115.8420 - val_distortion_loss: 0.0380 - val_rate_loss: 39.9038 - val_distortion_pass_through_loss: 0.0380 - val_rate_pass_through_loss: 39.8994
Epoch 12/15
469/469 [==============================] - 3s 6ms/step - loss: 119.7753 - distortion_loss: 0.0367 - rate_loss: 46.3968 - distortion_pass_through_loss: 0.0367 - rate_pass_through_loss: 46.3957 - val_loss: 114.8861 - val_distortion_loss: 0.0373 - val_rate_loss: 40.2797 - val_distortion_pass_through_loss: 0.0373 - val_rate_pass_through_loss: 40.2883
Epoch 13/15
469/469 [==============================] - 3s 6ms/step - loss: 118.1635 - distortion_loss: 0.0363 - rate_loss: 45.5972 - distortion_pass_through_loss: 0.0363 - rate_pass_through_loss: 45.5967 - val_loss: 114.0300 - val_distortion_loss: 0.0367 - val_rate_loss: 40.5612 - val_distortion_pass_through_loss: 0.0367 - val_rate_pass_through_loss: 40.5718
Epoch 14/15
469/469 [==============================] - 3s 6ms/step - loss: 116.8593 - distortion_loss: 0.0360 - rate_loss: 44.9107 - distortion_pass_through_loss: 0.0360 - rate_pass_through_loss: 44.9097 - val_loss: 112.6166 - val_distortion_loss: 0.0363 - val_rate_loss: 40.0470 - val_distortion_pass_through_loss: 0.0363 - val_rate_pass_through_loss: 40.0628
Epoch 15/15
469/469 [==============================] - 3s 6ms/step - loss: 115.6814 - distortion_loss: 0.0356 - rate_loss: 44.4095 - distortion_pass_through_loss: 0.0356 - rate_pass_through_loss: 44.4091 - val_loss: 112.2964 - val_distortion_loss: 0.0360 - val_rate_loss: 40.3579 - val_distortion_pass_through_loss: 0.0360 - val_rate_pass_through_loss: 40.3735

MNIST 画像を圧縮する

テスト時の圧縮と解凍用に、トレーニング済みのモデルを以下の 2 つに分割します。

  • エンコーダ側には、分析変換とエントロピーモデルが含まれます。
  • デコーダ側には、合成変換と同じエントロピーモデルが含まれます。

テスト時には、潜在に追加ノイズが含まれませんが、量子化されてから非可逆的に圧縮されるため、それらに新しい名前を指定します。それらと再構築の \(\hat x\) と \(\hat y\) をそれぞれに呼び出します(『End-to-end Optimized Image Compression』に従います)。

class MNISTCompressor(tf.keras.Model):
  """Compresses MNIST images to strings."""

  def __init__(self, analysis_transform, entropy_model):
    super().__init__()
    self.analysis_transform = analysis_transform
    self.entropy_model = entropy_model

  def call(self, x):
    # Ensure inputs are floats in the range (0, 1).
    x = tf.cast(x, self.compute_dtype) / 255.
    y = self.analysis_transform(x)
    # Also return the exact information content of each digit.
    _, bits = self.entropy_model(y, training=False)
    return self.entropy_model.compress(y), bits
class MNISTDecompressor(tf.keras.Model):
  """Decompresses MNIST images from strings."""

  def __init__(self, entropy_model, synthesis_transform):
    super().__init__()
    self.entropy_model = entropy_model
    self.synthesis_transform = synthesis_transform

  def call(self, string):
    y_hat = self.entropy_model.decompress(string, ())
    x_hat = self.synthesis_transform(y_hat)
    # Scale and cast back to 8-bit integer.
    return tf.saturate_cast(tf.round(x_hat * 255.), tf.uint8)

compression=True でインスタンス化すると、エントロピーモデルは、学習した事前確率をレンジコーディングアルゴリズムのテーブルに変換します。compress() を呼び出すと、このアルゴリズムが呼び出され、潜在空間ベクトルをビットシーケンスに変換します。各バイナリ文字列の長さは、潜在の情報コンテンツに近似します(事前確率の下の潜在の負の対数尤度)。

圧縮と解凍のエントロピーモデルは、同じインスタンスである必要があります。これは、レンジコーディングテーブルが両側でまったく同じである必要があるためです。そうでない場合、解凍エラーが発生します。

def make_mnist_codec(trainer, **kwargs):
  # The entropy model must be created with `compression=True` and the same
  # instance must be shared between compressor and decompressor.
  entropy_model = tfc.ContinuousBatchedEntropyModel(
      trainer.prior, coding_rank=1, compression=True, **kwargs)
  compressor = MNISTCompressor(trainer.analysis_transform, entropy_model)
  decompressor = MNISTDecompressor(entropy_model, trainer.synthesis_transform)
  return compressor, decompressor

compressor, decompressor = make_mnist_codec(trainer)

検証データセットから 16 個の画像を取得します。skip の引数を変えることで、さまざまなサブセットを選択できます。

(originals, _), = validation_dataset.batch(16).skip(3).take(1)

これらを文字列に圧縮し、それぞれの情報コンテンツをビットで追跡します。

strings, entropies = compressor(originals)

print(f"String representation of first digit in hexadecimal: 0x{strings[0].numpy().hex()}")
print(f"Number of bits actually needed to represent it: {entropies[0]:0.2f}")
String representation of first digit in hexadecimal: 0x39c3f87dec58
Number of bits actually needed to represent it: 44.04

画像を文字列から解凍します。

reconstructions = decompressor(strings)

各 16 個の元の数字を圧縮されたバイナリ表現と再構築された数字と共に表示します。

display_digits(originals, strings, entropies, reconstructions)

png

エンコードされた文字列の長さが各数字の情報コンテンツと異なることに注目してください。

これは、レンジコーディングプロセスが離散確率を使用し、少量のオーバーヘッドがあるためです。そのため、特に短い文字列の場合、対応するものは近似でしかありません。ただし、レンジコーディングは漸近的に最適です。限界では、予想されるビット数は、トレーニングモデルのレート項が上限であるクロスエントロピー(予想される情報コンテンツ)に近づきます。

レートとひずみのトレードオフ

上記では、モデルは、各数字を表現するために使用されるビットの平均数と再構築で発生した誤差の間の特定のトレードオフのためにトレーニングされました(lmbda=2000 で指定)。

異なる値でこの実験を繰り返した場合はどうなるでしょうか?

まずは、\(\lambda\) を 500 に減らしてみましょう。

def train_and_visualize_model(lmbda):
  trainer = train_mnist_model(lmbda=lmbda)
  compressor, decompressor = make_mnist_codec(trainer)
  strings, entropies = compressor(originals)
  reconstructions = decompressor(strings)
  display_digits(originals, strings, entropies, reconstructions)

train_and_visualize_model(lmbda=500)
Epoch 1/15
465/469 [============================>.] - ETA: 0s - loss: 127.4447 - distortion_loss: 0.0695 - rate_loss: 92.6941 - distortion_pass_through_loss: 0.0695 - rate_pass_through_loss: 92.6941
WARNING:absl:Computing quantization offsets using offset heuristic within a tf.function. Ideally, the offset heuristic should only be used to determine offsets once after training. Depending on the prior, estimating the offset might be computationally expensive.
469/469 [==============================] - 6s 6ms/step - loss: 127.2831 - distortion_loss: 0.0694 - rate_loss: 92.5993 - distortion_pass_through_loss: 0.0694 - rate_pass_through_loss: 92.5930 - val_loss: 107.4028 - val_distortion_loss: 0.0549 - val_rate_loss: 79.9750 - val_distortion_pass_through_loss: 0.0549 - val_rate_pass_through_loss: 79.9802
Epoch 2/15
469/469 [==============================] - 3s 6ms/step - loss: 97.1724 - distortion_loss: 0.0538 - rate_loss: 70.2783 - distortion_pass_through_loss: 0.0538 - rate_pass_through_loss: 70.2729 - val_loss: 86.0219 - val_distortion_loss: 0.0594 - val_rate_loss: 56.3162 - val_distortion_pass_through_loss: 0.0594 - val_rate_pass_through_loss: 56.3187
Epoch 3/15
469/469 [==============================] - 3s 6ms/step - loss: 81.1227 - distortion_loss: 0.0560 - rate_loss: 53.1034 - distortion_pass_through_loss: 0.0560 - rate_pass_through_loss: 53.0995 - val_loss: 71.9202 - val_distortion_loss: 0.0686 - val_rate_loss: 37.5965 - val_distortion_pass_through_loss: 0.0687 - val_rate_pass_through_loss: 37.5954
Epoch 4/15
469/469 [==============================] - 3s 6ms/step - loss: 71.5995 - distortion_loss: 0.0595 - rate_loss: 41.8626 - distortion_pass_through_loss: 0.0595 - rate_pass_through_loss: 41.8596 - val_loss: 64.1463 - val_distortion_loss: 0.0786 - val_rate_loss: 24.8592 - val_distortion_pass_through_loss: 0.0786 - val_rate_pass_through_loss: 24.8602
Epoch 5/15
469/469 [==============================] - 3s 6ms/step - loss: 66.1026 - distortion_loss: 0.0624 - rate_loss: 34.8940 - distortion_pass_through_loss: 0.0624 - rate_pass_through_loss: 34.8927 - val_loss: 58.4913 - val_distortion_loss: 0.0795 - val_rate_loss: 18.7568 - val_distortion_pass_through_loss: 0.0795 - val_rate_pass_through_loss: 18.7560
Epoch 6/15
469/469 [==============================] - 3s 6ms/step - loss: 62.6672 - distortion_loss: 0.0646 - rate_loss: 30.3623 - distortion_pass_through_loss: 0.0646 - rate_pass_through_loss: 30.3613 - val_loss: 54.9740 - val_distortion_loss: 0.0818 - val_rate_loss: 14.0641 - val_distortion_pass_through_loss: 0.0818 - val_rate_pass_through_loss: 14.0646
Epoch 7/15
469/469 [==============================] - 3s 6ms/step - loss: 60.1863 - distortion_loss: 0.0660 - rate_loss: 27.2017 - distortion_pass_through_loss: 0.0660 - rate_pass_through_loss: 27.2010 - val_loss: 52.4609 - val_distortion_loss: 0.0806 - val_rate_loss: 12.1524 - val_distortion_pass_through_loss: 0.0806 - val_rate_pass_through_loss: 12.1531
Epoch 8/15
469/469 [==============================] - 3s 6ms/step - loss: 58.0571 - distortion_loss: 0.0665 - rate_loss: 24.8082 - distortion_pass_through_loss: 0.0665 - rate_pass_through_loss: 24.8073 - val_loss: 49.9638 - val_distortion_loss: 0.0771 - val_rate_loss: 11.4078 - val_distortion_pass_through_loss: 0.0771 - val_rate_pass_through_loss: 11.4103
Epoch 9/15
469/469 [==============================] - 3s 6ms/step - loss: 56.1462 - distortion_loss: 0.0665 - rate_loss: 22.8890 - distortion_pass_through_loss: 0.0665 - rate_pass_through_loss: 22.8888 - val_loss: 48.1192 - val_distortion_loss: 0.0704 - val_rate_loss: 12.9410 - val_distortion_pass_through_loss: 0.0704 - val_rate_pass_through_loss: 12.9476
Epoch 10/15
469/469 [==============================] - 3s 6ms/step - loss: 54.1863 - distortion_loss: 0.0657 - rate_loss: 21.3211 - distortion_pass_through_loss: 0.0657 - rate_pass_through_loss: 21.3206 - val_loss: 47.0492 - val_distortion_loss: 0.0674 - val_rate_loss: 13.3331 - val_distortion_pass_through_loss: 0.0674 - val_rate_pass_through_loss: 13.3350
Epoch 11/15
469/469 [==============================] - 3s 6ms/step - loss: 52.4151 - distortion_loss: 0.0647 - rate_loss: 20.0704 - distortion_pass_through_loss: 0.0647 - rate_pass_through_loss: 20.0700 - val_loss: 46.5608 - val_distortion_loss: 0.0665 - val_rate_loss: 13.2897 - val_distortion_pass_through_loss: 0.0665 - val_rate_pass_through_loss: 13.2926
Epoch 12/15
469/469 [==============================] - 3s 6ms/step - loss: 50.9138 - distortion_loss: 0.0636 - rate_loss: 19.1121 - distortion_pass_through_loss: 0.0636 - rate_pass_through_loss: 19.1114 - val_loss: 45.9211 - val_distortion_loss: 0.0645 - val_rate_loss: 13.6699 - val_distortion_pass_through_loss: 0.0645 - val_rate_pass_through_loss: 13.6701
Epoch 13/15
469/469 [==============================] - 3s 6ms/step - loss: 49.7118 - distortion_loss: 0.0626 - rate_loss: 18.4105 - distortion_pass_through_loss: 0.0626 - rate_pass_through_loss: 18.4100 - val_loss: 45.6058 - val_distortion_loss: 0.0628 - val_rate_loss: 14.1970 - val_distortion_pass_through_loss: 0.0628 - val_rate_pass_through_loss: 14.1988
Epoch 14/15
469/469 [==============================] - 3s 6ms/step - loss: 48.7698 - distortion_loss: 0.0617 - rate_loss: 17.9120 - distortion_pass_through_loss: 0.0617 - rate_pass_through_loss: 17.9119 - val_loss: 45.1903 - val_distortion_loss: 0.0612 - val_rate_loss: 14.5991 - val_distortion_pass_through_loss: 0.0612 - val_rate_pass_through_loss: 14.6004
Epoch 15/15
469/469 [==============================] - 3s 6ms/step - loss: 48.0780 - distortion_loss: 0.0611 - rate_loss: 17.5208 - distortion_pass_through_loss: 0.0611 - rate_pass_through_loss: 17.5206 - val_loss: 45.0955 - val_distortion_loss: 0.0615 - val_rate_loss: 14.3562 - val_distortion_pass_through_loss: 0.0615 - val_rate_pass_through_loss: 14.3626

png

コードのビットレートが下がり、数字の信頼性も下がります。ただし、ほとんどの数字は認識可能のままです。

さらに \(\lambda\) を減らしてみましょう。

train_and_visualize_model(lmbda=300)
Epoch 1/15
466/469 [============================>.] - ETA: 0s - loss: 113.7090 - distortion_loss: 0.0753 - rate_loss: 91.1310 - distortion_pass_through_loss: 0.0753 - rate_pass_through_loss: 91.1310
WARNING:absl:Computing quantization offsets using offset heuristic within a tf.function. Ideally, the offset heuristic should only be used to determine offsets once after training. Depending on the prior, estimating the offset might be computationally expensive.
469/469 [==============================] - 6s 6ms/step - loss: 113.6090 - distortion_loss: 0.0752 - rate_loss: 91.0583 - distortion_pass_through_loss: 0.0752 - rate_pass_through_loss: 91.0516 - val_loss: 96.5233 - val_distortion_loss: 0.0679 - val_rate_loss: 76.1659 - val_distortion_pass_through_loss: 0.0679 - val_rate_pass_through_loss: 76.1617
Epoch 2/15
469/469 [==============================] - 3s 6ms/step - loss: 85.8572 - distortion_loss: 0.0613 - rate_loss: 67.4572 - distortion_pass_through_loss: 0.0613 - rate_pass_through_loss: 67.4516 - val_loss: 74.5242 - val_distortion_loss: 0.0793 - val_rate_loss: 50.7241 - val_distortion_pass_through_loss: 0.0793 - val_rate_pass_through_loss: 50.7321
Epoch 3/15
469/469 [==============================] - 3s 6ms/step - loss: 68.9031 - distortion_loss: 0.0650 - rate_loss: 49.4174 - distortion_pass_through_loss: 0.0650 - rate_pass_through_loss: 49.4135 - val_loss: 59.5138 - val_distortion_loss: 0.0954 - val_rate_loss: 30.8864 - val_distortion_pass_through_loss: 0.0954 - val_rate_pass_through_loss: 30.8949
Epoch 4/15
469/469 [==============================] - 3s 6ms/step - loss: 58.3479 - distortion_loss: 0.0696 - rate_loss: 37.4610 - distortion_pass_through_loss: 0.0696 - rate_pass_through_loss: 37.4585 - val_loss: 49.3487 - val_distortion_loss: 0.1029 - val_rate_loss: 18.4801 - val_distortion_pass_through_loss: 0.1028 - val_rate_pass_through_loss: 18.4910
Epoch 5/15
469/469 [==============================] - 3s 6ms/step - loss: 52.0953 - distortion_loss: 0.0740 - rate_loss: 29.8993 - distortion_pass_through_loss: 0.0740 - rate_pass_through_loss: 29.8978 - val_loss: 42.9612 - val_distortion_loss: 0.1054 - val_rate_loss: 11.3369 - val_distortion_pass_through_loss: 0.1054 - val_rate_pass_through_loss: 11.3445
Epoch 6/15
469/469 [==============================] - 3s 6ms/step - loss: 48.1743 - distortion_loss: 0.0775 - rate_loss: 24.9172 - distortion_pass_through_loss: 0.0775 - rate_pass_through_loss: 24.9160 - val_loss: 38.8429 - val_distortion_loss: 0.1035 - val_rate_loss: 7.7809 - val_distortion_pass_through_loss: 0.1035 - val_rate_pass_through_loss: 7.7837
Epoch 7/15
469/469 [==============================] - 3s 6ms/step - loss: 45.4033 - distortion_loss: 0.0800 - rate_loss: 21.4013 - distortion_pass_through_loss: 0.0800 - rate_pass_through_loss: 21.4004 - val_loss: 36.4476 - val_distortion_loss: 0.1025 - val_rate_loss: 5.7000 - val_distortion_pass_through_loss: 0.1025 - val_rate_pass_through_loss: 5.7030
Epoch 8/15
469/469 [==============================] - 3s 6ms/step - loss: 43.1902 - distortion_loss: 0.0815 - rate_loss: 18.7450 - distortion_pass_through_loss: 0.0815 - rate_pass_through_loss: 18.7442 - val_loss: 34.4560 - val_distortion_loss: 0.0938 - val_rate_loss: 6.3266 - val_distortion_pass_through_loss: 0.0938 - val_rate_pass_through_loss: 6.3243
Epoch 9/15
469/469 [==============================] - 3s 6ms/step - loss: 41.1994 - distortion_loss: 0.0816 - rate_loss: 16.7293 - distortion_pass_through_loss: 0.0816 - rate_pass_through_loss: 16.7284 - val_loss: 33.6424 - val_distortion_loss: 0.0906 - val_rate_loss: 6.4604 - val_distortion_pass_through_loss: 0.0906 - val_rate_pass_through_loss: 6.4591
Epoch 10/15
469/469 [==============================] - 3s 6ms/step - loss: 39.5689 - distortion_loss: 0.0811 - rate_loss: 15.2472 - distortion_pass_through_loss: 0.0811 - rate_pass_through_loss: 15.2467 - val_loss: 32.8275 - val_distortion_loss: 0.0851 - val_rate_loss: 7.3065 - val_distortion_pass_through_loss: 0.0851 - val_rate_pass_through_loss: 7.3087
Epoch 11/15
469/469 [==============================] - 3s 6ms/step - loss: 38.1226 - distortion_loss: 0.0800 - rate_loss: 14.1267 - distortion_pass_through_loss: 0.0800 - rate_pass_through_loss: 14.1260 - val_loss: 32.5285 - val_distortion_loss: 0.0841 - val_rate_loss: 7.2859 - val_distortion_pass_through_loss: 0.0841 - val_rate_pass_through_loss: 7.2903
Epoch 12/15
469/469 [==============================] - 3s 6ms/step - loss: 36.8895 - distortion_loss: 0.0786 - rate_loss: 13.3042 - distortion_pass_through_loss: 0.0786 - rate_pass_through_loss: 13.3038 - val_loss: 32.3595 - val_distortion_loss: 0.0830 - val_rate_loss: 7.4672 - val_distortion_pass_through_loss: 0.0830 - val_rate_pass_through_loss: 7.4686
Epoch 13/15
469/469 [==============================] - 3s 6ms/step - loss: 35.9890 - distortion_loss: 0.0778 - rate_loss: 12.6432 - distortion_pass_through_loss: 0.0778 - rate_pass_through_loss: 12.6427 - val_loss: 32.1735 - val_distortion_loss: 0.0807 - val_rate_loss: 7.9718 - val_distortion_pass_through_loss: 0.0807 - val_rate_pass_through_loss: 7.9728
Epoch 14/15
469/469 [==============================] - 3s 6ms/step - loss: 35.2725 - distortion_loss: 0.0771 - rate_loss: 12.1506 - distortion_pass_through_loss: 0.0771 - rate_pass_through_loss: 12.1504 - val_loss: 31.8718 - val_distortion_loss: 0.0777 - val_rate_loss: 8.5510 - val_distortion_pass_through_loss: 0.0778 - val_rate_pass_through_loss: 8.5569
Epoch 15/15
469/469 [==============================] - 3s 6ms/step - loss: 34.6280 - distortion_loss: 0.0764 - rate_loss: 11.6954 - distortion_pass_through_loss: 0.0764 - rate_pass_through_loss: 11.6957 - val_loss: 31.8556 - val_distortion_loss: 0.0772 - val_rate_loss: 8.7024 - val_distortion_pass_through_loss: 0.0772 - val_rate_pass_through_loss: 8.7201

png

数字当たり 1 バイトの順に、文字列がさらに短くなり始めました。ただし、これにはコストが伴い、さらに多くの数字が認識できなくなってしまいました。

これは、このモデルが人間による誤差の認識に左右されず、ピクセル値の観点で絶対偏差を測定していることを示します。画像の品質をさらに高めるには、ピクセル損失を知覚損失に置き換える必要があります。

デコーダーを生成モデルとして使用する

デコーダーにランダムなビットを供給すると、これは、モデルが数字を表すことを学習した分布から効果的にサンプリングされます。

まず、入力文字列が完全にデコードされていないかどうかを検出するサニティチェックを行わずに、コンプレッサー/デコンプレッサーを再インスタンス化します。

compressor, decompressor = make_mnist_codec(trainer, decode_sanity_check=False)

次に、十分な長さのランダムな文字列をデコンプレッサに入力して、それらから数字をデコード/サンプリングできるようにします。

import os

strings = tf.constant([os.urandom(8) for _ in range(16)])
samples = decompressor(strings)

fig, axes = plt.subplots(4, 4, sharex=True, sharey=True, figsize=(5, 5))
axes = axes.ravel()
for i in range(len(axes)):
  axes[i].imshow(tf.squeeze(samples[i]))
  axes[i].axis("off")
plt.subplots_adjust(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)

png