Learned data compression

Overview

This notebook shows how to do lossy data compression using neural networks and TensorFlow Compression.

Lossy compression involves making a trade-off between rate, the expected number of bits needed to encode a sample, and distortion, the expected error in the reconstruction of the sample.
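
To make the two quantities concrete, here is a toy sketch that uses no neural network at all. It quantizes Gaussian samples with a uniform scalar quantizer of step size `step`, estimates the rate as the empirical entropy of the quantized symbols (bits per sample), and measures distortion as the mean absolute error. The Gaussian source, the quantizer and the helper name `rate_distortion` are illustrative assumptions, not part of TensorFlow Compression.

```python
import numpy as np

def rate_distortion(samples, step):
    """Rate (bits/sample) and distortion (mean abs error) of a uniform quantizer."""
    # Map each sample to the index of its quantization bin.
    symbols = np.round(samples / step).astype(np.int64)
    reconstruction = symbols * step
    # The empirical entropy of the symbols approximates the bits an ideal
    # entropy coder would need per sample.
    _, counts = np.unique(symbols, return_counts=True)
    probs = counts / counts.sum()
    rate = float(-np.sum(probs * np.log2(probs)))
    distortion = float(np.mean(np.abs(samples - reconstruction)))
    return rate, distortion

rng = np.random.default_rng(0)
samples = rng.normal(size=100_000)
for step in (0.1, 0.5, 2.0):
    rate, distortion = rate_distortion(samples, step)
    print(f"step={step}: rate={rate:.2f} bits, distortion={distortion:.3f}")
```

Coarser steps spend fewer bits per sample but reconstruct less accurately. The learned model below faces the same trade-off, except that it is steered by a weight \(\lambda\) on the distortion term rather than by a step size.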

The examples below use an autoencoder-like model to compress images from the MNIST dataset. The method is based on the paper End-to-end Optimized Image Compression.

More background on learned data compression can be found in this paper targeted at people familiar with classical data compression, or this survey targeted at a machine learning audience.

Setup

Install TensorFlow Compression via pip.

# Installs the latest version of TFC compatible with the installed TF version.

read MAJOR MINOR <<< "$(pip show tensorflow | perl -p -0777 -e 's/.*Version: (\d+)\.(\d+).*/\1 \2/sg')"
pip install "tensorflow-compression<$MAJOR.$(($MINOR+1))"
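
The shell one-liner above reads the installed TensorFlow version and pins TensorFlow Compression below the next TensorFlow minor release. The same version arithmetic in plain Python, as a sketch (the helper name `tfc_requirement` is hypothetical):

```python
import re

def tfc_requirement(tf_version: str) -> str:
    """Return a pip requirement for TFC that matches a TensorFlow version string."""
    # Extract MAJOR and MINOR, as the perl substitution does.
    match = re.match(r"(\d+)\.(\d+)", tf_version)
    if match is None:
        raise ValueError(f"Unrecognized version: {tf_version!r}")
    major, minor = int(match.group(1)), int(match.group(2))
    # Allow any TFC release below the next TensorFlow minor version.
    return f"tensorflow-compression<{major}.{minor + 1}"

print(tfc_requirement("2.14.1"))  # tensorflow-compression<2.15
```

In a live environment, the version string could come from `importlib.metadata.version("tensorflow")`.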

Import library dependencies.

import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_compression as tfc
import tensorflow_datasets as tfds

Define the trainer model.

Because the model resembles an autoencoder, and it needs to behave differently during training and inference, the setup is a little different from that of, say, a classifier.

The training model consists of three parts:

  • the analysis (or encoder) transform, converting from the image into a latent space,
  • the synthesis (or decoder) transform, converting from the latent space back into image space, and
  • a prior and entropy model, modeling the marginal probabilities of the latents.

First, define the transforms:

def make_analysis_transform(latent_dims):
  """Creates the analysis (encoder) transform."""
  return tf.keras.Sequential([
      tf.keras.layers.Conv2D(
          20, 5, use_bias=True, strides=2, padding="same",
          activation="leaky_relu", name="conv_1"),
      tf.keras.layers.Conv2D(
          50, 5, use_bias=True, strides=2, padding="same",
          activation="leaky_relu", name="conv_2"),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(
          500, use_bias=True, activation="leaky_relu", name="fc_1"),
      tf.keras.layers.Dense(
          latent_dims, use_bias=True, activation=None, name="fc_2"),
  ], name="analysis_transform")

def make_synthesis_transform():
  """Creates the synthesis (decoder) transform."""
  return tf.keras.Sequential([
      tf.keras.layers.Dense(
          500, use_bias=True, activation="leaky_relu", name="fc_1"),
      tf.keras.layers.Dense(
          2450, use_bias=True, activation="leaky_relu", name="fc_2"),
      tf.keras.layers.Reshape((7, 7, 50)),
      tf.keras.layers.Conv2DTranspose(
          20, 5, use_bias=True, strides=2, padding="same",
          activation="leaky_relu", name="conv_1"),
      tf.keras.layers.Conv2DTranspose(
          1, 5, use_bias=True, strides=2, padding="same",
          activation="leaky_relu", name="conv_2"),
  ], name="synthesis_transform")
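
The synthesis transform's `Dense(2450)` and `Reshape((7, 7, 50))` layers mirror the shape of the last convolutional output in the analysis transform. With `padding="same"`, a stride-2 `Conv2D` produces `ceil(n / 2)` outputs per spatial dimension, and a stride-2 `Conv2DTranspose` produces `2 * n`. A quick check of that arithmetic in plain Python (the helper names are hypothetical):

```python
import math

def conv_same_output(size: int, stride: int) -> int:
    # Spatial size after a Conv2D with padding="same".
    return math.ceil(size / stride)

def conv_transpose_same_output(size: int, stride: int) -> int:
    # Spatial size after a Conv2DTranspose with padding="same".
    return size * stride

# Analysis transform: 28x28x1 -> conv_1 (20 filters) -> conv_2 (50 filters).
side = conv_same_output(conv_same_output(28, 2), 2)
flattened = side * side * 50
print(side, flattened)  # 7 2450

# Synthesis transform: 7x7x50 -> conv_1 -> conv_2 restores 28x28.
restored = conv_transpose_same_output(conv_transpose_same_output(7, 2), 2)
print(restored)  # 28
```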

The trainer holds an instance of both transforms, as well as the parameters of the prior.

Its call method is set up to compute:

  • rate, an estimate of the average number of bits needed to represent each digit in the batch, and
  • distortion, the mean absolute difference between the pixels of the original digits and their reconstructions.

class MNISTCompressionTrainer(tf.keras.Model):
  """Model that trains a compressor/decompressor for MNIST."""

  def __init__(self, latent_dims):
    super().__init__()
    self.analysis_transform = make_analysis_transform(latent_dims)
    self.synthesis_transform = make_synthesis_transform()
    self.prior_log_scales = tf.Variable(tf.zeros((latent_dims,)))

  @property
  def prior(self):
    return tfc.NoisyLogistic(loc=0., scale=tf.exp(self.prior_log_scales))

  def call(self, x, training):
    """Computes rate and distortion losses."""
    # Ensure inputs are floats in the range [0, 1].
    x = tf.cast(x, self.compute_dtype) / 255.
    x = tf.reshape(x, (-1, 28, 28, 1))

    # Compute latent space representation y, perturb it and model its entropy,
    # then compute the reconstructed pixel-level representation x_hat.
    y = self.analysis_transform(x)
    entropy_model = tfc.ContinuousBatchedEntropyModel(
        self.prior, coding_rank=1, compression=False)
    y_tilde, rate = entropy_model(y, training=training)
    x_tilde = self.synthesis_transform(y_tilde)

    # Average number of bits per MNIST digit.
    rate = tf.reduce_mean(rate)

    # Mean absolute difference across pixels.
    distortion = tf.reduce_mean(abs(x - x_tilde))

    return dict(rate=rate, distortion=distortion)

Compute rate and distortion.

Let's walk through this step by step, using one image from the validation set. First, load the MNIST dataset for training and validation:

training_dataset, validation_dataset = tfds.load(
    "mnist",
    split=["train", "test"],
    shuffle_files=True,
    as_supervised=True,
    with_info=False,
)

And extract one image \(x\):

(x, _), = validation_dataset.take(1)

plt.imshow(tf.squeeze(x))
print(f"Data type: {x.dtype}")
print(f"Shape: {x.shape}")
Data type: <dtype: 'uint8'>
Shape: (28, 28, 1)

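To get the latent representation \(y\), we need to cast \(x\) to float32, scale it to the range [0, 1], add a batch dimension, and pass it through the analysis transform. Since this transform is freshly created and untrained, the values of \(y\) are not yet meaningful.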

x = tf.cast(x, tf.float32) / 255.
x = tf.reshape(x, (-1, 28, 28, 1))
y = make_analysis_transform(10)(x)

print("y:", y)
y: tf.Tensor(
[[-0.01224455 -0.09719235 -0.08213592  0.0354024  -0.01443382  0.02162577
   0.02967148  0.00232092  0.00181769  0.00430147]], shape=(1, 10), dtype=float32)

The latents will be quantized at test time. To model this in a differentiable way during training, we add uniform noise in the interval \((-.5, .5)\) and call the result \(\tilde y\). This is the same terminology as used in the paper End-to-end Optimized Image Compression.

y_tilde = y + tf.random.uniform(y.shape, -.5, .5)

print("y_tilde:", y_tilde)
y_tilde: tf.Tensor(
[[ 0.38850513 -0.3627419  -0.10774756  0.03274522 -0.4638402   0.2084787
   0.00465386  0.45559373  0.45222282 -0.46006057]], shape=(1, 10), dtype=float32)
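
Why uniform noise? Rounding to the nearest integer moves each latent by at most 0.5, and for latents spread over many integer bins the rounding error behaves roughly like uniform noise on \((-.5, .5)\). A NumPy sketch comparing the two (the latent distribution here is an arbitrary assumption):

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical latents, spread over several integer bins.
y = rng.normal(scale=5.0, size=100_000)

# Test time: hard quantization (zero gradient almost everywhere).
y_hat = np.round(y)
# Training time: additive uniform noise (gradient 1 with respect to y).
y_tilde = y + rng.uniform(-0.5, 0.5, size=y.shape)

quantization_error = y_hat - y
noise = y_tilde - y
print("max |y_hat - y|:", np.abs(quantization_error).max())
print("error std:", quantization_error.std(), "noise std:", noise.std())
```

Both standard deviations come out close to \(1/\sqrt{12} \approx 0.289\), which is why the noisy latents are a reasonable stand-in for the quantized ones during training.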

The "prior" is a probability density that we train to model the marginal distribution of the noisy latents. For example, it could be a set of independent logistic distributions with different scales for each latent dimension. tfc.NoisyLogistic accounts for the fact that the latents have additive noise. As the scale approaches zero, a logistic distribution approaches a Dirac delta (spike), but the added noise causes the "noisy" distribution to approach the uniform distribution instead.

prior = tfc.NoisyLogistic(loc=0., scale=tf.linspace(.01, 2., 10))

_ = tf.linspace(-6., 6., 501)[:, None]
plt.plot(_, prior.prob(_));
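
A logistic distribution convolved with uniform noise on \((-.5, .5)\) has a closed-form density, a difference of logistic CDFs: \(p(y) = F(y + .5) - F(y - .5)\). The NumPy sketch below implements that formula directly (it is not TFC's implementation) and checks the limiting behavior described above:

```python
import numpy as np

def logistic_cdf(x, loc=0.0, scale=1.0):
    # Logistic CDF, written with tanh to avoid overflow for small scales.
    return 0.5 * (1.0 + np.tanh((x - loc) / (2.0 * scale)))

def noisy_logistic_prob(y, loc=0.0, scale=1.0):
    # Density of a logistic variable plus independent uniform(-.5, .5) noise.
    return logistic_cdf(y + 0.5, loc, scale) - logistic_cdf(y - 0.5, loc, scale)

grid = np.linspace(-30.0, 30.0, 60_001)
step = grid[1] - grid[0]
for scale in (0.01, 1.0):
    density = noisy_logistic_prob(grid, scale=scale)
    mass = density.sum() * step  # Riemann-sum approximation of the integral
    print(f"scale={scale}: p(0)={noisy_logistic_prob(0.0, scale=scale):.3f}, "
          f"p(1)={noisy_logistic_prob(1.0, scale=scale):.3f}, total mass={mass:.3f}")
```

For the small scale, the density is about 1 inside \((-.5, .5)\) and about 0 outside, i.e. nearly uniform, while both densities still integrate to 1.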

During training, tfc.ContinuousBatchedEntropyModel adds uniform noise, and uses the noise and the prior to compute a (differentiable) upper bound on the rate (the average number of bits necessary to encode the latent representation). That bound can be minimized as a loss.

entropy_model = tfc.ContinuousBatchedEntropyModel(
    prior, coding_rank=1, compression=False)
y_tilde, rate = entropy_model(y, training=True)

print("rate:", rate)
print("y_tilde:", y_tilde)
rate: tf.Tensor([18.526083], shape=(1,), dtype=float32)
y_tilde: tf.Tensor(
[[ 0.0090554  -0.38909417 -0.4069785   0.18274103  0.2406526  -0.11575054
   0.28057152  0.30737367 -0.13117756 -0.22494133]], shape=(1, 10), dtype=float32)
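
Conceptually, the rate for one digit is the sum over latent dimensions of \(-\log_2 p(\tilde y_i)\) under the noisy prior. A NumPy sketch of that computation, reusing the closed-form noisy logistic density (this mirrors the idea behind the entropy model's rate, not its exact implementation; the helper names are hypothetical):

```python
import numpy as np

def noisy_logistic_prob(y, scale):
    # Noisy logistic density F(y + .5) - F(y - .5), with F the logistic CDF.
    def cdf(x):
        return 0.5 * (1.0 + np.tanh(x / (2.0 * scale)))
    return cdf(y + 0.5) - cdf(y - 0.5)

def rate_in_bits(y_tilde, scales):
    """Bits per example: -sum_i log2 p_i(y_tilde_i) over the last axis."""
    probs = noisy_logistic_prob(y_tilde, scales)
    return -np.sum(np.log2(probs), axis=-1)

scales = np.linspace(0.01, 2.0, 10)  # same scales as the example prior above
rng = np.random.default_rng(0)
# Hypothetical latents at zero, plus the uniform training noise.
y_tilde = rng.uniform(-0.5, 0.5, size=(1, 10))
print(rate_in_bits(y_tilde, scales))
```

A dimension whose prior is a narrow spike around zero costs almost nothing while its value stays near zero, whereas a wide prior costs a few bits even for the same value.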

Lastly, the noisy latents are passed back through the synthesis transform to produce an image reconstruction \(\tilde x\). Distortion is the error between the original image and its reconstruction. With the transforms untrained, the reconstruction is, of course, not very useful.

x_tilde = make_synthesis_transform()(y_tilde)

# Mean absolute difference across pixels.
distortion = tf.reduce_mean(abs(x - x_tilde))
print("distortion:", distortion)

x_tilde = tf.saturate_cast(x_tilde[0] * 255, tf.uint8)
plt.imshow(tf.squeeze(x_tilde))
print(f"Data type: {x_tilde.dtype}")
print(f"Shape: {x_tilde.shape}")
distortion: tf.Tensor(0.17078552, shape=(), dtype=float32)
Data type: <dtype: 'uint8'>
Shape: (28, 28, 1)
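
Because the pixels are scaled to [0, 1] before the distortion is computed, the mean absolute difference is a fraction of the 8-bit range. A small sketch that converts it back to gray levels (the helper names are hypothetical):

```python
import numpy as np

def mean_absolute_error(original, reconstruction):
    # Distortion as used in this notebook: mean |x - x_tilde| over all pixels.
    return float(np.mean(np.abs(original - reconstruction)))

def to_gray_levels(distortion):
    # Undo the division by 255 to express the error in 8-bit intensity steps.
    return distortion * 255.0

original = np.zeros((28, 28, 1))
reconstruction = np.full((28, 28, 1), 0.1)
d = mean_absolute_error(original, reconstruction)
print(d, to_gray_levels(d))
```

For example, the untrained model's distortion of about 0.171 above corresponds to an average error of roughly 44 gray levels per pixel.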

For every batch of digits, calling the MNISTCompressionTrainer produces the rate and distortion as an average over that batch:

(example_batch, _), = validation_dataset.batch(32).take(1)
trainer = MNISTCompressionTrainer(10)
example_output = trainer(example_batch)

print("rate: ", example_output["rate"])
print("distortion: ", example_output["distortion"])
rate:  tf.Tensor(20.296253, shape=(), dtype=float32)
distortion:  tf.Tensor(0.14659302, shape=(), dtype=float32)

In the next section, we set up the model to do gradient descent on these two losses.

Train the model.

We compile the trainer so that it optimizes the rate–distortion Lagrangian, that is, a sum of rate and distortion in which one of the terms (here, distortion) is weighted by a Lagrange parameter \(\lambda\).

This loss function affects the different parts of the model differently:

  • The analysis transform is trained to produce a latent representation that achieves the desired trade-off between rate and distortion.
  • The synthesis transform is trained to minimize distortion, given the latent representation.
  • The parameters of the prior are trained to minimize the rate given the latent representation. This is identical to fitting the prior to the marginal distribution of latents in a maximum likelihood sense.

def pass_through_loss(_, x):
  # Since rate and distortion are unsupervised, the loss doesn't need a target.
  return x

def make_mnist_compression_trainer(lmbda, latent_dims=50):
  trainer = MNISTCompressionTrainer(latent_dims)
  trainer.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    # Just pass through rate and distortion as losses/metrics.
    loss=dict(rate=pass_through_loss, distortion=pass_through_loss),
    metrics=dict(rate=pass_through_loss, distortion=pass_through_loss),
    loss_weights=dict(rate=1., distortion=lmbda),
  )
  return trainer
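
With `loss_weights=dict(rate=1., distortion=lmbda)`, the total loss that Keras minimizes is \(L = R + \lambda D\). As a sanity check, we can plug in the final-epoch validation metrics displayed in the `lmbda=2000` training log (rate ≈ 39.89 bits, distortion ≈ 0.0360). The helper name is hypothetical:

```python
def rate_distortion_loss(rate, distortion, lmbda):
    """Rate-distortion Lagrangian: rate weighted by 1, distortion by lmbda."""
    return rate + lmbda * distortion

# Final-epoch validation metrics for lmbda=2000, as displayed in the training log.
val_rate, val_distortion = 39.8947, 0.0360
print(rate_distortion_loss(val_rate, val_distortion, lmbda=2000))
```

This gives about 111.89, versus the logged `val_loss` of 111.9197. The small gap is consistent with the distortion being displayed to only four decimals, which allows an error of up to 0.00005 × 2000 = 0.1 in the weighted term.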

Next, train the model. The labels are not needed here, since we only want to compress the images, so we drop them using a map and instead add "dummy" targets for rate and distortion.

def add_rd_targets(image, label):
  # Training is unsupervised, so labels aren't necessary here. However, we
  # need to add "dummy" targets for rate and distortion.
  return image, dict(rate=0., distortion=0.)

def train_mnist_model(lmbda):
  trainer = make_mnist_compression_trainer(lmbda)
  trainer.fit(
      training_dataset.map(add_rd_targets).batch(128).prefetch(8),
      epochs=15,
      validation_data=validation_dataset.map(add_rd_targets).batch(128).cache(),
      validation_freq=1,
      verbose=1,
  )
  return trainer

trainer = train_mnist_model(lmbda=2000)
Epoch 1/15
469/469 [==============================] - ETA: 0s - loss: 216.9254 - distortion_loss: 0.0584 - rate_loss: 100.1568 - distortion_pass_through_loss: 0.0584 - rate_pass_through_loss: 100.1521
WARNING:absl:Computing quantization offsets using offset heuristic within a tf.function. Ideally, the offset heuristic should only be used to determine offsets once after training. Depending on the prior, estimating the offset might be computationally expensive.
469/469 [==============================] - 13s 22ms/step - loss: 216.9254 - distortion_loss: 0.0584 - rate_loss: 100.1568 - distortion_pass_through_loss: 0.0584 - rate_pass_through_loss: 100.1521 - val_loss: 176.1546 - val_distortion_loss: 0.0419 - val_rate_loss: 92.4025 - val_distortion_pass_through_loss: 0.0419 - val_rate_pass_through_loss: 92.4103
Epoch 2/15
469/469 [==============================] - 10s 20ms/step - loss: 165.6864 - distortion_loss: 0.0409 - rate_loss: 83.9722 - distortion_pass_through_loss: 0.0409 - rate_pass_through_loss: 83.9679 - val_loss: 155.8378 - val_distortion_loss: 0.0399 - val_rate_loss: 76.1234 - val_distortion_pass_through_loss: 0.0399 - val_rate_pass_through_loss: 76.1243
Epoch 3/15
469/469 [==============================] - 10s 20ms/step - loss: 150.9844 - distortion_loss: 0.0400 - rate_loss: 71.0840 - distortion_pass_through_loss: 0.0399 - rate_pass_through_loss: 71.0809 - val_loss: 144.5865 - val_distortion_loss: 0.0402 - val_rate_loss: 64.2065 - val_distortion_pass_through_loss: 0.0402 - val_rate_pass_through_loss: 64.2150
Epoch 4/15
469/469 [==============================] - 9s 20ms/step - loss: 142.6878 - distortion_loss: 0.0398 - rate_loss: 63.0357 - distortion_pass_through_loss: 0.0398 - rate_pass_through_loss: 63.0338 - val_loss: 136.4241 - val_distortion_loss: 0.0403 - val_rate_loss: 55.7424 - val_distortion_pass_through_loss: 0.0403 - val_rate_pass_through_loss: 55.7691
Epoch 5/15
469/469 [==============================] - 9s 20ms/step - loss: 137.4838 - distortion_loss: 0.0396 - rate_loss: 58.2308 - distortion_pass_through_loss: 0.0396 - rate_pass_through_loss: 58.2295 - val_loss: 132.1830 - val_distortion_loss: 0.0412 - val_rate_loss: 49.8589 - val_distortion_pass_through_loss: 0.0412 - val_rate_pass_through_loss: 49.8711
Epoch 6/15
469/469 [==============================] - 9s 20ms/step - loss: 133.8861 - distortion_loss: 0.0394 - rate_loss: 55.1402 - distortion_pass_through_loss: 0.0394 - rate_pass_through_loss: 55.1388 - val_loss: 127.9782 - val_distortion_loss: 0.0415 - val_rate_loss: 45.0345 - val_distortion_pass_through_loss: 0.0415 - val_rate_pass_through_loss: 45.0470
Epoch 7/15
469/469 [==============================] - 9s 20ms/step - loss: 130.8392 - distortion_loss: 0.0389 - rate_loss: 52.9458 - distortion_pass_through_loss: 0.0389 - rate_pass_through_loss: 52.9443 - val_loss: 124.2168 - val_distortion_loss: 0.0408 - val_rate_loss: 42.7138 - val_distortion_pass_through_loss: 0.0408 - val_rate_pass_through_loss: 42.7179
Epoch 8/15
469/469 [==============================] - 9s 20ms/step - loss: 128.3935 - distortion_loss: 0.0386 - rate_loss: 51.1929 - distortion_pass_through_loss: 0.0386 - rate_pass_through_loss: 51.1917 - val_loss: 121.7899 - val_distortion_loss: 0.0406 - val_rate_loss: 40.6796 - val_distortion_pass_through_loss: 0.0405 - val_rate_pass_through_loss: 40.6837
Epoch 9/15
469/469 [==============================] - 9s 20ms/step - loss: 125.8994 - distortion_loss: 0.0381 - rate_loss: 49.6562 - distortion_pass_through_loss: 0.0381 - rate_pass_through_loss: 49.6556 - val_loss: 119.4453 - val_distortion_loss: 0.0391 - val_rate_loss: 41.2929 - val_distortion_pass_through_loss: 0.0391 - val_rate_pass_through_loss: 41.3038
Epoch 10/15
469/469 [==============================] - 9s 20ms/step - loss: 123.7347 - distortion_loss: 0.0377 - rate_loss: 48.2651 - distortion_pass_through_loss: 0.0377 - rate_pass_through_loss: 48.2641 - val_loss: 117.1507 - val_distortion_loss: 0.0387 - val_rate_loss: 39.7558 - val_distortion_pass_through_loss: 0.0387 - val_rate_pass_through_loss: 39.7709
Epoch 11/15
469/469 [==============================] - 9s 20ms/step - loss: 121.6866 - distortion_loss: 0.0373 - rate_loss: 47.0960 - distortion_pass_through_loss: 0.0373 - rate_pass_through_loss: 47.0950 - val_loss: 115.9093 - val_distortion_loss: 0.0379 - val_rate_loss: 40.1509 - val_distortion_pass_through_loss: 0.0379 - val_rate_pass_through_loss: 40.1661
Epoch 12/15
469/469 [==============================] - 9s 20ms/step - loss: 119.8163 - distortion_loss: 0.0369 - rate_loss: 46.0885 - distortion_pass_through_loss: 0.0369 - rate_pass_through_loss: 46.0875 - val_loss: 115.0018 - val_distortion_loss: 0.0372 - val_rate_loss: 40.5832 - val_distortion_pass_through_loss: 0.0372 - val_rate_pass_through_loss: 40.6000
Epoch 13/15
469/469 [==============================] - 9s 20ms/step - loss: 118.4900 - distortion_loss: 0.0366 - rate_loss: 45.2273 - distortion_pass_through_loss: 0.0366 - rate_pass_through_loss: 45.2263 - val_loss: 113.9260 - val_distortion_loss: 0.0372 - val_rate_loss: 39.5021 - val_distortion_pass_through_loss: 0.0372 - val_rate_pass_through_loss: 39.5158
Epoch 14/15
469/469 [==============================] - 9s 19ms/step - loss: 116.9452 - distortion_loss: 0.0362 - rate_loss: 44.5765 - distortion_pass_through_loss: 0.0362 - rate_pass_through_loss: 44.5763 - val_loss: 113.4185 - val_distortion_loss: 0.0365 - val_rate_loss: 40.5167 - val_distortion_pass_through_loss: 0.0365 - val_rate_pass_through_loss: 40.5147
Epoch 15/15
469/469 [==============================] - 9s 19ms/step - loss: 115.8957 - distortion_loss: 0.0359 - rate_loss: 44.0527 - distortion_pass_through_loss: 0.0359 - rate_pass_through_loss: 44.0523 - val_loss: 111.9197 - val_distortion_loss: 0.0360 - val_rate_loss: 39.8947 - val_distortion_pass_through_loss: 0.0360 - val_rate_pass_through_loss: 39.8986

Compress some MNIST images.

For compression and decompression at test time, we split the trained model in two parts:

  • The encoder side consists of the analysis transform and the entropy model.
  • The decoder side consists of the synthesis transform and the same entropy model.

At test time, the latents will not have additive noise, but they will be quantized and then losslessly compressed, so we give them new names. Following End-to-end Optimized Image Compression, we call the quantized latents \(\hat y\) and the image reconstruction \(\hat x\).

class MNISTCompressor(tf.keras.Model):
  """Compresses MNIST images to strings."""

  def __init__(self, analysis_transform, entropy_model):
    super().__init__()
    self.analysis_transform = analysis_transform
    self.entropy_model = entropy_model

  def call(self, x):
    # Ensure inputs are floats in the range [0, 1].
    x = tf.cast(x, self.compute_dtype) / 255.
    y = self.analysis_transform(x)
    # Also return the exact information content of each digit.
    _, bits = self.entropy_model(y, training=False)
    return self.entropy_model.compress(y), bits

class MNISTDecompressor(tf.keras.Model):
  """Decompresses MNIST images from strings."""

  def __init__(self, entropy_model, synthesis_transform):
    super().__init__()
    self.entropy_model = entropy_model
    self.synthesis_transform = synthesis_transform

  def call(self, string):
    y_hat = self.entropy_model.decompress(string, ())
    x_hat = self.synthesis_transform(y_hat)
    # Scale and cast back to 8-bit integer.
    return tf.saturate_cast(tf.round(x_hat * 255.), tf.uint8)
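
At test time, each quantized latent \(\hat y_i\) is an integer, and the probability assigned to an integer \(k\) is the prior mass of the bin \((k - .5, k + .5)\), which equals the noisy density evaluated at \(k\). The NumPy sketch below computes that probability mass function and the resulting information content \(-\sum_i \log_2 P(\hat y_i)\). It assumes a logistic prior centered at zero and ignores any quantization offset TFC may apply, so treat it as an illustration of the idea rather than of the library's exact behavior (the helper names are hypothetical):

```python
import numpy as np

def logistic_cdf(x, scale):
    # Numerically stable logistic CDF.
    return 0.5 * (1.0 + np.tanh(x / (2.0 * scale)))

def bin_probability(k, scale):
    # Mass of the prior on the quantization bin (k - .5, k + .5).
    return logistic_cdf(k + 0.5, scale) - logistic_cdf(k - 0.5, scale)

def information_content(y, scales):
    """Bits needed for the rounded latents y_hat under the prior."""
    y_hat = np.round(y)
    return float(-np.sum(np.log2(bin_probability(y_hat, scales))))

scales = np.array([0.5, 1.0, 2.0])
support = np.arange(-100, 101)
print("total mass:", bin_probability(support, 1.0).sum())
print("bits:", information_content(np.array([0.2, -1.4, 3.1]), scales))
```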

When instantiated with compression=True, the entropy model converts the learned prior into tables for a range coding algorithm. When calling compress(), this algorithm is invoked to convert the latent space vector into bit sequences. The length of each binary string approximates the information content of the latent (the negative log likelihood of the latent under the prior).

The entropy model for compression and decompression must be the same instance, because the range coding tables need to be exactly identical on both sides. Otherwise, decoding errors can occur.

def make_mnist_codec(trainer, **kwargs):
  # The entropy model must be created with `compression=True` and the same
  # instance must be shared between compressor and decompressor.
  entropy_model = tfc.ContinuousBatchedEntropyModel(
      trainer.prior, coding_rank=1, compression=True, **kwargs)
  compressor = MNISTCompressor(trainer.analysis_transform, entropy_model)
  decompressor = MNISTDecompressor(entropy_model, trainer.synthesis_transform)
  return compressor, decompressor

compressor, decompressor = make_mnist_codec(trainer)
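
Range coders operate on integer frequency tables rather than on real-valued probabilities. The sketch below shows one way to turn a pmf over a finite set of symbols into such a table. The 12-bit precision, the "at least one count per symbol" rule and the helper names are assumptions for illustration; TFC's own table construction and tail handling differ in detail.

```python
import numpy as np

def pmf_to_frequencies(pmf, precision=12):
    """Quantize a pmf to integer frequencies summing to 2**precision.

    Every symbol keeps a frequency of at least 1, so each stays encodable.
    Assumes the number of symbols is much smaller than 2**precision.
    """
    total = 1 << precision
    pmf = np.asarray(pmf, dtype=np.float64)
    pmf = pmf / pmf.sum()
    freqs = np.maximum(1, np.round(pmf * total).astype(np.int64))
    # Absorb the rounding error into the most probable symbol.
    freqs[np.argmax(freqs)] += total - freqs.sum()
    return freqs

def frequencies_to_cdf(freqs):
    # Cumulative table: symbol s occupies the interval [cdf[s], cdf[s + 1]).
    return np.concatenate([[0], np.cumsum(freqs)])

pmf = [0.001, 0.05, 0.3, 0.5, 0.149]
freqs = pmf_to_frequencies(pmf)
cdf = frequencies_to_cdf(freqs)
print(freqs, cdf[-1])
```

Because encoder and decoder derive their intervals from the same integer table, any mismatch between their tables would corrupt decoding, which is why the entropy model instance must be shared.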

Grab 16 images from the validation dataset. You can select a different subset by changing the argument to skip.

(originals, _), = validation_dataset.batch(16).skip(3).take(1)
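
`batch(16)` groups the validation images into batches of 16, `skip(3)` drops the first three batches, and `take(1)` keeps the next one. Assuming the split is read in a fixed order, the selection above covers zero-based indices 48 to 63. A plain-Python imitation of that pipeline (the helper `batch_skip_take` is hypothetical):

```python
def batch_skip_take(items, batch_size, skip, take):
    """Mimic dataset.batch(batch_size).skip(skip).take(take) on a Python list."""
    # Like tf.data's batch() without drop_remainder, keep a final partial batch.
    batches = [items[i:i + batch_size] for i in range(0, len(items), batch_size)]
    return batches[skip:skip + take]

indices = list(range(10_000))  # the MNIST test split has 10,000 images
(selected,) = batch_skip_take(indices, batch_size=16, skip=3, take=1)
print(selected[0], selected[-1])  # 48 63
```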

Compress them to strings, and keep track of the information content of each one in bits.

strings, entropies = compressor(originals)

print(f"String representation of first digit in hexadecimal: 0x{strings[0].numpy().hex()}")
print(f"Number of bits actually needed to represent it: {entropies[0]:0.2f}")
String representation of first digit in hexadecimal: 0x0d1866b8f1
Number of bits actually needed to represent it: 37.49

Decompress the images back from the strings.

reconstructions = decompressor(strings)
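
The next cell calls a plotting helper, `display_digits`, which is not defined in this section. A minimal sketch of such a helper, assuming Matplotlib and a 4x4 grid, places each original beside its reconstruction and shows the hexadecimal string and bit count as the panel title. The layout is an assumption, not the notebook's original figure:

```python
import matplotlib.pyplot as plt
import numpy as np

def display_digits(originals, strings, entropies, reconstructions):
    """Show each original digit beside its reconstruction, with code info.

    Accepts TensorFlow tensors or NumPy arrays (converted with np.asarray).
    Returns the Matplotlib figure.
    """
    originals = np.asarray(originals).squeeze(-1)              # (N, 28, 28)
    reconstructions = np.asarray(reconstructions).squeeze(-1)  # (N, 28, 28)
    strings = np.asarray(strings)                              # bytes objects
    entropies = np.asarray(entropies, dtype=np.float64)

    fig, axes = plt.subplots(4, 4, figsize=(12, 6))
    for i, ax in enumerate(axes.ravel()):
        ax.axis("off")
        if i >= len(originals):
            continue
        # A thin black gap separates the original (left) from its reconstruction.
        gap = np.zeros((28, 4), dtype=originals.dtype)
        ax.imshow(np.concatenate([originals[i], gap, reconstructions[i]], axis=1),
                  cmap="gray")
        ax.set_title(f"0x{bytes(strings[i]).hex()}\n{entropies[i]:0.2f} bits",
                     fontsize="small")
    fig.tight_layout()
    return fig
```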

Display each of the 16 original digits together with its compressed binary representation, and the reconstructed digit.

display_digits(originals, strings, entropies, reconstructions)

Note that the length of the encoded string differs from the information content of each digit.

This is because the range coding process works with discrete probabilities, and has a small amount of overhead. So, especially for short strings, the correspondence is only approximate. However, range coding is asymptotically optimal: in the limit, the expected bit count will approach the cross entropy (the expected information content), for which the rate term in the training model is an upper bound.
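
As a rough check on the first digit compressed above: its string `0x0d1866b8f1` is 5 bytes, or 40 bits, while its information content is 37.49 bits. Rounding 37.49 bits up to whole bytes already gives 5 bytes, so the gap of about 2.5 bits is consistent with byte granularity plus the small coding overhead just described. The arithmetic in Python:

```python
import math

hex_string = "0d1866b8f1"    # first compressed digit, printed above
information_content = 37.49  # bits, printed above

actual_bits = len(bytes.fromhex(hex_string)) * 8
whole_bytes_needed = math.ceil(information_content / 8)
print(actual_bits)                        # 40
print(whole_bytes_needed)                 # 5
print(actual_bits - information_content)  # gap in bits, about 2.51
```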

The rate–distortion trade-off

Above, the model was trained for a specific trade-off (given by lmbda=2000) between the average number of bits used to represent each digit and the incurred error in the reconstruction.

What happens when we repeat the experiment with different values?

Let's start by reducing \(\lambda\) to 500.

def train_and_visualize_model(lmbda):
  trainer = train_mnist_model(lmbda=lmbda)
  compressor, decompressor = make_mnist_codec(trainer)
  strings, entropies = compressor(originals)
  reconstructions = decompressor(strings)
  display_digits(originals, strings, entropies, reconstructions)

train_and_visualize_model(lmbda=500)
Epoch 1/15
469/469 [==============================] - ETA: 0s - loss: 127.8276 - distortion_loss: 0.0705 - rate_loss: 92.5567 - distortion_pass_through_loss: 0.0705 - rate_pass_through_loss: 92.5504
WARNING:absl:Computing quantization offsets using offset heuristic within a tf.function. Ideally, the offset heuristic should only be used to determine offsets once after training. Depending on the prior, estimating the offset might be computationally expensive.
469/469 [==============================] - 12s 21ms/step - loss: 127.8276 - distortion_loss: 0.0705 - rate_loss: 92.5567 - distortion_pass_through_loss: 0.0705 - rate_pass_through_loss: 92.5504 - val_loss: 108.3089 - val_distortion_loss: 0.0574 - val_rate_loss: 79.6299 - val_distortion_pass_through_loss: 0.0574 - val_rate_pass_through_loss: 79.6331
Epoch 2/15
469/469 [==============================] - 9s 20ms/step - loss: 97.6987 - distortion_loss: 0.0548 - rate_loss: 70.2879 - distortion_pass_through_loss: 0.0548 - rate_pass_through_loss: 70.2826 - val_loss: 86.5739 - val_distortion_loss: 0.0598 - val_rate_loss: 56.6536 - val_distortion_pass_through_loss: 0.0598 - val_rate_pass_through_loss: 56.6590
Epoch 3/15
469/469 [==============================] - 9s 20ms/step - loss: 81.6310 - distortion_loss: 0.0570 - rate_loss: 53.1313 - distortion_pass_through_loss: 0.0570 - rate_pass_through_loss: 53.1278 - val_loss: 72.4750 - val_distortion_loss: 0.0688 - val_rate_loss: 38.0917 - val_distortion_pass_through_loss: 0.0687 - val_rate_pass_through_loss: 38.1038
Epoch 4/15
469/469 [==============================] - 9s 20ms/step - loss: 71.8922 - distortion_loss: 0.0601 - rate_loss: 41.8549 - distortion_pass_through_loss: 0.0601 - rate_pass_through_loss: 41.8529 - val_loss: 64.1137 - val_distortion_loss: 0.0785 - val_rate_loss: 24.8831 - val_distortion_pass_through_loss: 0.0784 - val_rate_pass_through_loss: 24.8904
Epoch 5/15
469/469 [==============================] - 9s 20ms/step - loss: 66.2340 - distortion_loss: 0.0629 - rate_loss: 34.7989 - distortion_pass_through_loss: 0.0629 - rate_pass_through_loss: 34.7976 - val_loss: 58.3210 - val_distortion_loss: 0.0801 - val_rate_loss: 18.2739 - val_distortion_pass_through_loss: 0.0801 - val_rate_pass_through_loss: 18.2635
Epoch 6/15
469/469 [==============================] - 9s 20ms/step - loss: 62.6940 - distortion_loss: 0.0649 - rate_loss: 30.2491 - distortion_pass_through_loss: 0.0649 - rate_pass_through_loss: 30.2479 - val_loss: 54.5973 - val_distortion_loss: 0.0814 - val_rate_loss: 13.8960 - val_distortion_pass_through_loss: 0.0813 - val_rate_pass_through_loss: 13.9088
Epoch 7/15
469/469 [==============================] - 9s 20ms/step - loss: 60.2058 - distortion_loss: 0.0663 - rate_loss: 27.0736 - distortion_pass_through_loss: 0.0663 - rate_pass_through_loss: 27.0724 - val_loss: 51.4746 - val_distortion_loss: 0.0775 - val_rate_loss: 12.7405 - val_distortion_pass_through_loss: 0.0775 - val_rate_pass_through_loss: 12.7322
Epoch 8/15
469/469 [==============================] - 9s 20ms/step - loss: 58.0009 - distortion_loss: 0.0666 - rate_loss: 24.7138 - distortion_pass_through_loss: 0.0666 - rate_pass_through_loss: 24.7133 - val_loss: 49.4696 - val_distortion_loss: 0.0734 - val_rate_loss: 12.7863 - val_distortion_pass_through_loss: 0.0734 - val_rate_pass_through_loss: 12.7662
Epoch 9/15
469/469 [==============================] - 9s 20ms/step - loss: 55.9776 - distortion_loss: 0.0662 - rate_loss: 22.8791 - distortion_pass_through_loss: 0.0662 - rate_pass_through_loss: 22.8781 - val_loss: 48.1887 - val_distortion_loss: 0.0700 - val_rate_loss: 13.1711 - val_distortion_pass_through_loss: 0.0701 - val_rate_pass_through_loss: 13.1664
Epoch 10/15
469/469 [==============================] - 9s 20ms/step - loss: 54.0630 - distortion_loss: 0.0652 - rate_loss: 21.4487 - distortion_pass_through_loss: 0.0652 - rate_pass_through_loss: 21.4477 - val_loss: 47.5509 - val_distortion_loss: 0.0689 - val_rate_loss: 13.0906 - val_distortion_pass_through_loss: 0.0689 - val_rate_pass_through_loss: 13.0824
Epoch 11/15
469/469 [==============================] - 9s 20ms/step - loss: 52.5058 - distortion_loss: 0.0641 - rate_loss: 20.4323 - distortion_pass_through_loss: 0.0641 - rate_pass_through_loss: 20.4320 - val_loss: 47.0983 - val_distortion_loss: 0.0660 - val_rate_loss: 14.0991 - val_distortion_pass_through_loss: 0.0660 - val_rate_pass_through_loss: 14.0968
Epoch 12/15
469/469 [==============================] - 9s 20ms/step - loss: 51.2286 - distortion_loss: 0.0632 - rate_loss: 19.6388 - distortion_pass_through_loss: 0.0632 - rate_pass_through_loss: 19.6387 - val_loss: 46.6349 - val_distortion_loss: 0.0643 - val_rate_loss: 14.4766 - val_distortion_pass_through_loss: 0.0643 - val_rate_pass_through_loss: 14.4723
Epoch 13/15
469/469 [==============================] - 9s 20ms/step - loss: 50.2295 - distortion_loss: 0.0624 - rate_loss: 19.0260 - distortion_pass_through_loss: 0.0624 - rate_pass_through_loss: 19.0255 - val_loss: 46.3414 - val_distortion_loss: 0.0644 - val_rate_loss: 14.1262 - val_distortion_pass_through_loss: 0.0644 - val_rate_pass_through_loss: 14.1229
Epoch 14/15
469/469 [==============================] - 9s 19ms/step - loss: 49.4662 - distortion_loss: 0.0619 - rate_loss: 18.5085 - distortion_pass_through_loss: 0.0619 - rate_pass_through_loss: 18.5085 - val_loss: 46.0598 - val_distortion_loss: 0.0640 - val_rate_loss: 14.0695 - val_distortion_pass_through_loss: 0.0640 - val_rate_pass_through_loss: 14.0671
Epoch 15/15
469/469 [==============================] - 9s 20ms/step - loss: 48.8475 - distortion_loss: 0.0615 - rate_loss: 18.0763 - distortion_pass_through_loss: 0.0615 - rate_pass_through_loss: 18.0759 - val_loss: 45.9036 - val_distortion_loss: 0.0638 - val_rate_loss: 13.9967 - val_distortion_pass_through_loss: 0.0639 - val_rate_pass_through_loss: 13.9826

The bit rate of our code goes down, as does the fidelity of the digits. However, most of the digits remain recognizable.
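
One way to see why a smaller \(\lambda\) lowers the rate: for a fixed \(\lambda\), training seeks the operating point that minimizes \(R + \lambda D\). Taking the two final validation operating points logged so far (about 39.89 bits at distortion 0.0360 for `lmbda=2000`, and about 14.00 bits at distortion 0.0638 for `lmbda=500`), the sketch below checks which point each \(\lambda\) prefers. It compares only these two points and is not a full rate–distortion curve:

```python
def rd_cost(rate, distortion, lmbda):
    # Rate-distortion Lagrangian, as in the training loss.
    return rate + lmbda * distortion

# Final validation (rate in bits, distortion) pairs from the training logs.
operating_points = {
    "lmbda=2000 model": (39.8947, 0.0360),
    "lmbda=500 model": (13.9967, 0.0638),
}

for lmbda in (2000, 500):
    costs = {name: rd_cost(r, d, lmbda) for name, (r, d) in operating_points.items()}
    best = min(costs, key=costs.get)
    print(f"lambda={lmbda}: prefers the {best}")
```

With \(\lambda = 2000\) the distortion term dominates, so the higher-rate, lower-distortion point wins; with \(\lambda = 500\) bits become relatively more expensive, and the cheaper code wins.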

Let's reduce \(\lambda\) further.

train_and_visualize_model(lmbda=300)
Epoch 1/15
469/469 [==============================] - ETA: 0s - loss: 113.9453 - distortion_loss: 0.0765 - rate_loss: 91.0087 - distortion_pass_through_loss: 0.0764 - rate_pass_through_loss: 91.0019
WARNING:absl:Computing quantization offsets using offset heuristic within a tf.function. Ideally, the offset heuristic should only be used to determine offsets once after training. Depending on the prior, estimating the offset might be computationally expensive.
469/469 [==============================] - 11s 20ms/step - loss: 113.9453 - distortion_loss: 0.0765 - rate_loss: 91.0087 - distortion_pass_through_loss: 0.0764 - rate_pass_through_loss: 91.0019 - val_loss: 96.5798 - val_distortion_loss: 0.0668 - val_rate_loss: 76.5345 - val_distortion_pass_through_loss: 0.0669 - val_rate_pass_through_loss: 76.5314
Epoch 2/15
469/469 [==============================] - 9s 19ms/step - loss: 85.7681 - distortion_loss: 0.0609 - rate_loss: 67.4997 - distortion_pass_through_loss: 0.0609 - rate_pass_through_loss: 67.4941 - val_loss: 73.9959 - val_distortion_loss: 0.0764 - val_rate_loss: 51.0617 - val_distortion_pass_through_loss: 0.0764 - val_rate_pass_through_loss: 51.0703
Epoch 3/15
469/469 [==============================] - 9s 20ms/step - loss: 68.7888 - distortion_loss: 0.0645 - rate_loss: 49.4513 - distortion_pass_through_loss: 0.0645 - rate_pass_through_loss: 49.4474 - val_loss: 57.9014 - val_distortion_loss: 0.0860 - val_rate_loss: 32.1110 - val_distortion_pass_through_loss: 0.0860 - val_rate_pass_through_loss: 32.1095
Epoch 4/15
469/469 [==============================] - 9s 19ms/step - loss: 58.2579 - distortion_loss: 0.0691 - rate_loss: 37.5138 - distortion_pass_through_loss: 0.0691 - rate_pass_through_loss: 37.5116 - val_loss: 49.0095 - val_distortion_loss: 0.0996 - val_rate_loss: 19.1215 - val_distortion_pass_through_loss: 0.0997 - val_rate_pass_through_loss: 19.1154
Epoch 5/15
469/469 [==============================] - 9s 19ms/step - loss: 52.0423 - distortion_loss: 0.0736 - rate_loss: 29.9553 - distortion_pass_through_loss: 0.0736 - rate_pass_through_loss: 29.9535 - val_loss: 42.9857 - val_distortion_loss: 0.1058 - val_rate_loss: 11.2561 - val_distortion_pass_through_loss: 0.1058 - val_rate_pass_through_loss: 11.2495
Epoch 6/15
469/469 [==============================] - 9s 19ms/step - loss: 48.1614 - distortion_loss: 0.0773 - rate_loss: 24.9860 - distortion_pass_through_loss: 0.0773 - rate_pass_through_loss: 24.9847 - val_loss: 39.5561 - val_distortion_loss: 0.1074 - val_rate_loss: 7.3465 - val_distortion_pass_through_loss: 0.1074 - val_rate_pass_through_loss: 7.3374
Epoch 7/15
469/469 [==============================] - 9s 20ms/step - loss: 45.4303 - distortion_loss: 0.0800 - rate_loss: 21.4250 - distortion_pass_through_loss: 0.0800 - rate_pass_through_loss: 21.4242 - val_loss: 36.2512 - val_distortion_loss: 0.1000 - val_rate_loss: 6.2472 - val_distortion_pass_through_loss: 0.1001 - val_rate_pass_through_loss: 6.2349
Epoch 8/15
469/469 [==============================] - 9s 19ms/step - loss: 43.2415 - distortion_loss: 0.0816 - rate_loss: 18.7648 - distortion_pass_through_loss: 0.0816 - rate_pass_through_loss: 18.7640 - val_loss: 34.6368 - val_distortion_loss: 0.0951 - val_rate_loss: 6.0988 - val_distortion_pass_through_loss: 0.0951 - val_rate_pass_through_loss: 6.0970
Epoch 9/15
469/469 [==============================] - 9s 20ms/step - loss: 41.3811 - distortion_loss: 0.0823 - rate_loss: 16.6768 - distortion_pass_through_loss: 0.0823 - rate_pass_through_loss: 16.6763 - val_loss: 33.9006 - val_distortion_loss: 0.0924 - val_rate_loss: 6.1897 - val_distortion_pass_through_loss: 0.0923 - val_rate_pass_through_loss: 6.1920
Epoch 10/15
469/469 [==============================] - 9s 19ms/step - loss: 39.6697 - distortion_loss: 0.0818 - rate_loss: 15.1418 - distortion_pass_through_loss: 0.0818 - rate_pass_through_loss: 15.1415 - val_loss: 33.1051 - val_distortion_loss: 0.0859 - val_rate_loss: 7.3338 - val_distortion_pass_through_loss: 0.0859 - val_rate_pass_through_loss: 7.3265
Epoch 11/15
469/469 [==============================] - 9s 19ms/step - loss: 38.1470 - distortion_loss: 0.0804 - rate_loss: 14.0248 - distortion_pass_through_loss: 0.0804 - rate_pass_through_loss: 14.0245 - val_loss: 32.6648 - val_distortion_loss: 0.0827 - val_rate_loss: 7.8640 - val_distortion_pass_through_loss: 0.0827 - val_rate_pass_through_loss: 7.8598
Epoch 12/15
469/469 [==============================] - 9s 19ms/step - loss: 36.9021 - distortion_loss: 0.0790 - rate_loss: 13.2012 - distortion_pass_through_loss: 0.0790 - rate_pass_through_loss: 13.2009 - val_loss: 32.3153 - val_distortion_loss: 0.0811 - val_rate_loss: 7.9952 - val_distortion_pass_through_loss: 0.0811 - val_rate_pass_through_loss: 7.9855
Epoch 13/15
469/469 [==============================] - 9s 20ms/step - loss: 35.9055 - distortion_loss: 0.0778 - rate_loss: 12.5702 - distortion_pass_through_loss: 0.0778 - rate_pass_through_loss: 12.5696 - val_loss: 32.1251 - val_distortion_loss: 0.0800 - val_rate_loss: 8.1255 - val_distortion_pass_through_loss: 0.0800 - val_rate_pass_through_loss: 8.1183
Epoch 14/15
469/469 [==============================] - 9s 20ms/step - loss: 35.1264 - distortion_loss: 0.0768 - rate_loss: 12.0743 - distortion_pass_through_loss: 0.0768 - rate_pass_through_loss: 12.0742 - val_loss: 31.9446 - val_distortion_loss: 0.0774 - val_rate_loss: 8.7320 - val_distortion_pass_through_loss: 0.0774 - val_rate_pass_through_loss: 8.7187
Epoch 15/15
469/469 [==============================] - 9s 19ms/step - loss: 34.5168 - distortion_loss: 0.0761 - rate_loss: 11.6742 - distortion_pass_through_loss: 0.0761 - rate_pass_through_loss: 11.6742 - val_loss: 31.8502 - val_distortion_loss: 0.0768 - val_rate_loss: 8.8203 - val_distortion_pass_through_loss: 0.0768 - val_rate_pass_through_loss: 8.8122


The strings are now much shorter, on the order of one byte per digit. However, this comes at a cost: more digits are becoming unrecognizable.
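The "about one byte per digit" figure can be checked against the final epoch of the training log. This short sketch rests on two assumptions. First, `rate_loss` is reported in bits per digit. Second, the training loss has the form rate + λ · distortion. Under those assumptions, the sketch recovers the trade-off weight λ implied by the log and converts the rate to bytes. The metric values are copied from the epoch 15 log line. The helper names are hypothetical.

```python
# Final-epoch (15/15) metrics copied from the Keras log above.
train = {"loss": 34.5168, "distortion": 0.0761, "rate": 11.6742}
val = {"loss": 31.8502, "distortion": 0.0768, "rate": 8.8203}


def implied_lambda(m):
    """Solve loss = rate + lmbda * distortion for lmbda (assumed loss form)."""
    return (m["loss"] - m["rate"]) / m["distortion"]


def bytes_per_digit(m):
    """Assumes rate is in bits per digit; there are 8 bits per byte."""
    return m["rate"] / 8


for name, m in [("train", train), ("val", val)]:
    print(f"{name}: lambda ~ {implied_lambda(m):.1f}, "
          f"~{bytes_per_digit(m):.2f} bytes/digit")
```

Both splits give λ close to 300. The validation rate works out to roughly 1.1 bytes per digit, which matches the description above. Small deviations in λ come from the log rounding distortion to four decimal places.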

This demonstrates that the model is agnostic to human perception of error: it only measures the absolute deviation in pixel values. To achieve better perceived image quality, we would need to replace the pixel loss with a perceptual loss.
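To make the point about pixel losses concrete, here is a minimal NumPy sketch that is separate from TensorFlow Compression. It builds a toy 28×28 "digit" (a single vertical stroke) and two reconstructions with exactly the same mean absolute error. One reconstruction adds a faint uniform haze and leaves the stroke intact. The other erases half of the stroke. The image and all names are hypothetical illustrations.

```python
import numpy as np


def mean_absolute_error(x, x_hat):
    """Pixel-wise absolute deviation, averaged over all pixels."""
    return float(np.mean(np.abs(x - x_hat)))


# Toy 28x28 image: a vertical stroke of 20 rows x 4 columns = 80 pixels.
x = np.zeros((28, 28))
x[4:24, 12:16] = 1.0

# Reconstruction B: erase the lower half of the stroke (40 pixels off by 1.0).
erased = x.copy()
erased[14:24, 12:16] = 0.0

# Reconstruction A: shift every pixel toward gray by c, so each pixel is off by c.
c = 40 / x.size
hazy = np.where(x > 0.5, x - c, x + c)

print("hazy:  ", mean_absolute_error(x, hazy))
print("erased:", mean_absolute_error(x, erased))
```

Both errors are about 0.051. A person would still read the hazy image as the same stroke, but the erased one is visibly broken. A pixel loss cannot tell these two cases apart.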

Use the decoder as a generative model.

If we feed the decoder random bits, it effectively samples from the distribution that the model has learned for representing digits.
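Why do random bits produce samples? An entropy decoder splits the unit interval according to the prior. A uniformly random bit string therefore decodes to symbols distributed according to that prior. The sketch below uses a toy exact arithmetic decoder over a hypothetical three-symbol prior to show this effect. It is not TFC's range coder, which works at finite precision with the learned prior over the latent representation. The prior, the function name, and the use of `fractions.Fraction` are all illustrative assumptions.

```python
import os
import random
from fractions import Fraction

# Hypothetical discrete prior over three latent values; must sum to exactly 1.
PRIOR = [Fraction(1, 2), Fraction(1, 4), Fraction(1, 4)]


def decode_symbols(data, probs, num_symbols):
    """Toy exact arithmetic decoder.

    Interprets `data` as a binary fraction u in [0, 1) and repeatedly picks
    the sub-interval (sized by `probs`) that contains u. If u is uniform,
    each decoded symbol is distributed according to `probs`.
    """
    u = Fraction(int.from_bytes(data, "big"), 2 ** (8 * len(data)))
    low, width = Fraction(0), Fraction(1)
    symbols = []
    for _ in range(num_symbols):
        cum = Fraction(0)
        for symbol, p in enumerate(probs):
            sub_low = low + width * cum
            if u < sub_low + width * p:
                # u falls in this symbol's slice: emit it and zoom in.
                symbols.append(symbol)
                low, width = sub_low, width * p
                break
            cum += p
    return symbols


# Random bytes, as in the notebook, decode to samples from the prior.
print(decode_symbols(os.urandom(8), PRIOR, 10))
```

In this toy, `width` eventually falls below `2 ** -(8 * len(data))`. After that point, later symbols are no longer independent draws. This gives some intuition for why the random strings need to be long enough.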

First, re-instantiate the compressor/decompressor with the sanity check disabled. That check detects whether the input string has been completely decoded, and an arbitrary random string generally won't be consumed exactly.

compressor, decompressor = make_mnist_codec(trainer, decode_sanity_check=False)

Now, feed sufficiently long random strings into the decompressor so that it can decode (that is, sample) digits from them.

import os

# 16 random 8-byte strings, each decoded as if it were a compressed digit.
strings = tf.constant([os.urandom(8) for _ in range(16)])
samples = decompressor(strings)

# Show the 16 decoded samples in a borderless 4x4 grid.
fig, axes = plt.subplots(4, 4, sharex=True, sharey=True, figsize=(5, 5))
axes = axes.ravel()
for i in range(len(axes)):
  axes[i].imshow(tf.squeeze(samples[i]))
  axes[i].axis("off")
plt.subplots_adjust(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)
