Learned data compression


Overview

This notebook shows how to do lossy data compression using neural networks and TensorFlow Compression.

Lossy compression involves making a trade-off between rate, the expected number of bits needed to encode a sample, and distortion, the expected error in the reconstruction of the sample.

The examples below use an autoencoder-like model to compress images from the MNIST dataset. The method is based on the paper End-to-end Optimized Image Compression.

More background on learned data compression can be found in this paper targeted at people familiar with classical data compression, or this survey targeted at a machine learning audience.

Setup

Install TensorFlow Compression via pip.

# Installs the latest version of TFC compatible with the installed TF version.
pip install tensorflow-compression~=$(pip show tensorflow | perl -p -0777 -e 's/.*Version: (\d+\.\d+).*/\1.0/sg')
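To make the intent of the shell one-liner easier to follow, here is a rough Python sketch of the same idea: read the installed TensorFlow version, keep its major and minor components, and build a compatible-release requirement for TensorFlow Compression. The function name `compatible_tfc_spec` and the sample `pip show` text are illustrative assumptions, and the sketch accepts multi-digit version components.

```python
import re


def compatible_tfc_spec(pip_show_output: str) -> str:
    """Builds a pip requirement pinning TFC to the installed TF's major.minor."""
    match = re.search(r"^Version: (\d+)\.(\d+)", pip_show_output, flags=re.MULTILINE)
    if match is None:
        raise ValueError("no 'Version:' line found in the pip show output")
    major, minor = match.groups()
    # '~=X.Y.0' is pip's compatible-release operator: >= X.Y.0 and == X.Y.*
    return f"tensorflow-compression~={major}.{minor}.0"


# Hypothetical excerpt of what `pip show tensorflow` prints.
example = "Name: tensorflow\nVersion: 2.9.1\nSummary: ..."
print(compatible_tfc_spec(example))
```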

Import library dependencies.

import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_compression as tfc
import tensorflow_datasets as tfds

Define the trainer model.

Because the model resembles an autoencoder, and we need to perform a different set of functions during training and inference, the setup is a little different from, say, a classifier.

The training model consists of three parts:

  • the analysis (or encoder) transform, converting from the image into a latent space,
  • the synthesis (or decoder) transform, converting from the latent space back into image space, and
  • a prior and entropy model, modeling the marginal probabilities of the latents.

First, define the transforms:

def make_analysis_transform(latent_dims):
  """Creates the analysis (encoder) transform."""
  return tf.keras.Sequential([
      tf.keras.layers.Conv2D(
          20, 5, use_bias=True, strides=2, padding="same",
          activation="leaky_relu", name="conv_1"),
      tf.keras.layers.Conv2D(
          50, 5, use_bias=True, strides=2, padding="same",
          activation="leaky_relu", name="conv_2"),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(
          500, use_bias=True, activation="leaky_relu", name="fc_1"),
      tf.keras.layers.Dense(
          latent_dims, use_bias=True, activation=None, name="fc_2"),
  ], name="analysis_transform")

def make_synthesis_transform():
  """Creates the synthesis (decoder) transform."""
  return tf.keras.Sequential([
      tf.keras.layers.Dense(
          500, use_bias=True, activation="leaky_relu", name="fc_1"),
      tf.keras.layers.Dense(
          2450, use_bias=True, activation="leaky_relu", name="fc_2"),
      tf.keras.layers.Reshape((7, 7, 50)),
      tf.keras.layers.Conv2DTranspose(
          20, 5, use_bias=True, strides=2, padding="same",
          activation="leaky_relu", name="conv_1"),
      tf.keras.layers.Conv2DTranspose(
          1, 5, use_bias=True, strides=2, padding="same",
          activation="leaky_relu", name="conv_2"),
  ], name="synthesis_transform")
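The layer sizes in the two transforms are tied together by the spatial shapes. As a quick sanity check (a sketch of the shape arithmetic only, with helper names chosen for illustration), two stride-2 convolutions with "same" padding reduce a 28×28 image to 7×7, which explains the 2450 units and the Reshape((7, 7, 50)) on the synthesis side:

```python
import math


def conv_same_out(size: int, stride: int) -> int:
    """Spatial output size of a strided convolution with padding="same"."""
    return math.ceil(size / stride)


def conv_transpose_same_out(size: int, stride: int) -> int:
    """Spatial output size of a strided transposed convolution with padding="same"."""
    return size * stride


# Analysis side: 28x28 input passed through two stride-2 convolutions.
side = conv_same_out(conv_same_out(28, 2), 2)
flattened = side * side * 50  # conv_2 produces 50 channels
print(side, flattened)

# Synthesis side: 7x7 feature map passed through two stride-2 transposed convolutions.
restored = conv_transpose_same_out(conv_transpose_same_out(7, 2), 2)
print(restored)
```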

The trainer holds an instance of each transform, as well as the parameters of the prior.

Its call method is set up to compute:

  • rate, an estimate of the number of bits needed to represent the batch of digits, and
  • distortion, the mean absolute difference between the pixels of the original digits and their reconstructions.

class MNISTCompressionTrainer(tf.keras.Model):
  """Model that trains a compressor/decompressor for MNIST."""

  def __init__(self, latent_dims):
    super().__init__()
    self.analysis_transform = make_analysis_transform(latent_dims)
    self.synthesis_transform = make_synthesis_transform()
    self.prior_log_scales = tf.Variable(tf.zeros((latent_dims,)))

  @property
  def prior(self):
    return tfc.NoisyLogistic(loc=0., scale=tf.exp(self.prior_log_scales))

  def call(self, x, training):
    """Computes rate and distortion losses."""
    # Ensure inputs are floats in the range (0, 1).
    x = tf.cast(x, self.compute_dtype) / 255.
    x = tf.reshape(x, (-1, 28, 28, 1))

    # Compute latent space representation y, perturb it and model its entropy,
    # then compute the reconstructed pixel-level representation x_hat.
    y = self.analysis_transform(x)
    entropy_model = tfc.ContinuousBatchedEntropyModel(
        self.prior, coding_rank=1, compression=False)
    y_tilde, rate = entropy_model(y, training=training)
    x_tilde = self.synthesis_transform(y_tilde)

    # Average number of bits per MNIST digit.
    rate = tf.reduce_mean(rate)

    # Mean absolute difference across pixels.
    distortion = tf.reduce_mean(abs(x - x_tilde))

    return dict(rate=rate, distortion=distortion)

Compute rate and distortion.

Let's walk through this step by step, using one image from the training set. Load the MNIST dataset for training and validation:

training_dataset, validation_dataset = tfds.load(
    "mnist",
    split=["train", "test"],
    shuffle_files=True,
    as_supervised=True,
    with_info=False,
)

And extract one image \(x\):

(x, _), = validation_dataset.take(1)

plt.imshow(tf.squeeze(x))
print(f"Data type: {x.dtype}")
print(f"Shape: {x.shape}")
Data type: <dtype: 'uint8'>
Shape: (28, 28, 1)

To get the latent representation \(y\), we need to cast the image to float32, scale it to the range (0, 1), add a batch dimension, and pass it through the analysis transform.

x = tf.cast(x, tf.float32) / 255.
x = tf.reshape(x, (-1, 28, 28, 1))
y = make_analysis_transform(10)(x)

print("y:", y)
y: tf.Tensor(
[[-0.03596018  0.02976677  0.04654873 -0.04329199 -0.02063924  0.07961531
   0.06324859 -0.02948999 -0.08282233 -0.0033717 ]], shape=(1, 10), dtype=float32)

The latents will be quantized at test time. To model this in a differentiable way during training, we add uniform noise in the interval \((-.5, .5)\) and call the result \(\tilde y\). This is the same terminology as used in the paper End-to-end Optimized Image Compression.

y_tilde = y + tf.random.uniform(y.shape, -.5, .5)

print("y_tilde:", y_tilde)
y_tilde: tf.Tensor(
[[ 0.26787832 -0.336036   -0.36863232 -0.01119763 -0.21013504  0.46640497
   0.49586838  0.08665098  0.4023888   0.15769626]], shape=(1, 10), dtype=float32)
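The reason uniform noise is a reasonable stand-in for quantization can be illustrated without TensorFlow. The sketch below (plain Python, with made-up latent values) compares the error introduced by rounding to the nearest integer with additive noise drawn from \((-.5, .5)\): both stay within half a unit and have a similar mean squared magnitude.

```python
import random

random.seed(0)

# Hypothetical latent values, spread over a few quantization bins.
latents = [random.uniform(-3.0, 3.0) for _ in range(1000)]

# Test time: hard rounding to the nearest integer.
rounding_errors = [round(y) - y for y in latents]

# Training time: additive uniform noise as a differentiable substitute.
noise = [random.uniform(-0.5, 0.5) for _ in latents]


def mean_square(values):
    return sum(v * v for v in values) / len(values)


print("max |rounding error|:", max(abs(e) for e in rounding_errors))
print("max |noise|:         ", max(abs(n) for n in noise))
print("mean squared rounding error:", mean_square(rounding_errors))
print("mean squared noise:         ", mean_square(noise))
```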

The "prior" is a probability density that we train to model the marginal distribution of the noisy latents. For example, it could be a set of independent logistic distributions with different scales for each latent dimension. tfc.NoisyLogistic accounts for the fact that the latents have additive noise. As the scale approaches zero, a logistic distribution approaches a Dirac delta (spike), but the added noise causes the "noisy" distribution to approach the uniform distribution instead.

prior = tfc.NoisyLogistic(loc=0., scale=tf.linspace(.01, 2., 10))

_ = tf.linspace(-6., 6., 501)[:, None]
plt.plot(_, prior.prob(_));
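The limiting behavior described above follows from the math: adding uniform noise on \((-.5, .5)\) to a variable with CDF \(F\) gives the density \(F(y + .5) - F(y - .5)\). The following is a plain-Python sketch of that formula for a logistic \(F\); it illustrates the math only and is not meant to describe how tfc.NoisyLogistic is implemented. The function names are chosen for illustration.

```python
import math


def sigmoid(t: float) -> float:
    """Numerically stable logistic CDF for a standard logistic variable."""
    if t >= 0:
        return 1.0 / (1.0 + math.exp(-t))
    e = math.exp(t)
    return e / (1.0 + e)


def noisy_logistic_pdf(y: float, loc: float, scale: float) -> float:
    """Density of a logistic variable plus uniform noise on (-0.5, 0.5)."""
    upper = sigmoid((y + 0.5 - loc) / scale)
    lower = sigmoid((y - 0.5 - loc) / scale)
    return upper - lower


# For a small scale the density is close to 1 inside (-0.5, 0.5) and close
# to 0 outside, i.e. close to the uniform density; larger scales are flatter.
for scale in (0.01, 1.0, 2.0):
    values = [noisy_logistic_pdf(y, 0.0, scale) for y in (0.0, 0.4, 1.0)]
    print(scale, [round(v, 3) for v in values])
```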

During training, tfc.ContinuousBatchedEntropyModel adds uniform noise, and uses the noise and the prior to compute a (differentiable) upper bound on the rate (the average number of bits necessary to encode the latent representation). That bound can be minimized as a loss.

entropy_model = tfc.ContinuousBatchedEntropyModel(
    prior, coding_rank=1, compression=False)
y_tilde, rate = entropy_model(y, training=True)

print("rate:", rate)
print("y_tilde:", y_tilde)
rate: tf.Tensor([18.188126], shape=(1,), dtype=float32)
y_tilde: tf.Tensor(
[[-0.44882008  0.02260551 -0.2765367  -0.5319677   0.09461814 -0.18998818
   0.24599524 -0.44672346 -0.40305805  0.02805377]], shape=(1, 10), dtype=float32)
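The rate reported here is measured in bits. As a minimal sketch, assuming the rate of one example is the negative base-2 log-likelihood of its noisy latents under the prior, summed over the latent dimensions (the single axis covered by coding_rank=1), the bookkeeping looks like this. The density values are made up for illustration.

```python
import math


def information_bits(densities):
    """Sums -log2(p) over the latent dimensions of one example."""
    return sum(-math.log2(p) for p in densities)


# Hypothetical prior densities evaluated at three noisy latent values.
print(information_bits([0.5, 0.25, 0.125]))  # 1 + 2 + 3 = 6.0 bits
```

Latent values that the prior considers unlikely contribute more bits, which is why fitting the prior to the latents lowers the rate.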

Lastly, the noisy latents are passed back through the synthesis transform to produce an image reconstruction \(\tilde x\). Distortion is the error between the original image and its reconstruction. With the transforms still untrained, the reconstruction is not very useful.

x_tilde = make_synthesis_transform()(y_tilde)

# Mean absolute difference across pixels.
distortion = tf.reduce_mean(abs(x - x_tilde))
print("distortion:", distortion)

x_tilde = tf.saturate_cast(x_tilde[0] * 255, tf.uint8)
plt.imshow(tf.squeeze(x_tilde))
print(f"Data type: {x_tilde.dtype}")
print(f"Shape: {x_tilde.shape}")
distortion: tf.Tensor(0.17110279, shape=(), dtype=float32)
Data type: <dtype: 'uint8'>
Shape: (28, 28, 1)

For every batch of digits, calling the MNISTCompressionTrainer produces the rate and distortion as an average over that batch:

(example_batch, _), = validation_dataset.batch(32).take(1)
trainer = MNISTCompressionTrainer(10)
example_output = trainer(example_batch)

print("rate: ", example_output["rate"])
print("distortion: ", example_output["distortion"])
rate:  tf.Tensor(20.296253, shape=(), dtype=float32)
distortion:  tf.Tensor(0.14659302, shape=(), dtype=float32)

In the next section, we set up the model to do gradient descent on these two losses.

Train the model.

We compile the trainer so that it optimizes the rate–distortion Lagrangian, that is, a weighted sum of rate and distortion in which the distortion term is weighted by the Lagrange parameter \(\lambda\).

This loss function affects the different parts of the model differently:

  • The analysis transform is trained to produce a latent representation that achieves the desired trade-off between rate and distortion.
  • The synthesis transform is trained to minimize distortion, given the latent representation.
  • The parameters of the prior are trained to minimize the rate given the latent representation. This is identical to fitting the prior to the marginal distribution of latents in a maximum likelihood sense.

def pass_through_loss(_, x):
  # Since rate and distortion are unsupervised, the loss doesn't need a target.
  return x

def make_mnist_compression_trainer(lmbda, latent_dims=50):
  trainer = MNISTCompressionTrainer(latent_dims)
  trainer.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    # Just pass through rate and distortion as losses/metrics.
    loss=dict(rate=pass_through_loss, distortion=pass_through_loss),
    metrics=dict(rate=pass_through_loss, distortion=pass_through_loss),
    loss_weights=dict(rate=1., distortion=lmbda),
  )
  return trainer
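With loss_weights=dict(rate=1., distortion=lmbda), the total loss that Keras minimizes is the weighted sum of the two pass-through losses. A minimal sketch of that sum, using made-up rate and distortion values and an illustrative function name:

```python
def rd_lagrangian(rate: float, distortion: float, lmbda: float) -> float:
    """Weighted sum matching loss_weights=dict(rate=1., distortion=lmbda)."""
    return 1.0 * rate + lmbda * distortion


# Hypothetical values: 20 bits per digit, mean absolute error of 0.05.
for lmbda in (300, 500, 2000):
    print(lmbda, rd_lagrangian(rate=20.0, distortion=0.05, lmbda=lmbda))
```

A larger \(\lambda\) makes each unit of distortion more expensive relative to a bit of rate, so training favors higher fidelity at the cost of longer codes.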

Next, train the model. The labels are not necessary here, since we just want to compress the images, so we drop them using a map and instead add "dummy" targets for rate and distortion.

def add_rd_targets(image, label):
  # Training is unsupervised, so labels aren't necessary here. However, we
  # need to add "dummy" targets for rate and distortion.
  return image, dict(rate=0., distortion=0.)

def train_mnist_model(lmbda):
  trainer = make_mnist_compression_trainer(lmbda)
  trainer.fit(
      training_dataset.map(add_rd_targets).batch(128).prefetch(8),
      epochs=15,
      validation_data=validation_dataset.map(add_rd_targets).batch(128).cache(),
      validation_freq=1,
      verbose=1,
  )
  return trainer

trainer = train_mnist_model(lmbda=2000)
Epoch 1/15
469/469 [==============================] - 5s 7ms/step - loss: 219.5430 - distortion_loss: 0.0599 - rate_loss: 99.8240 - distortion_pass_through_loss: 0.0598 - rate_pass_through_loss: 99.8193 - val_loss: 176.7449 - val_distortion_loss: 0.0422 - val_rate_loss: 92.3387 - val_distortion_pass_through_loss: 0.0422 - val_rate_pass_through_loss: 92.3541
Epoch 2/15
469/469 [==============================] - 2s 5ms/step - loss: 166.8285 - distortion_loss: 0.0413 - rate_loss: 84.3090 - distortion_pass_through_loss: 0.0413 - rate_pass_through_loss: 84.3050 - val_loss: 157.4475 - val_distortion_loss: 0.0403 - val_rate_loss: 76.8713 - val_distortion_pass_through_loss: 0.0403 - val_rate_pass_through_loss: 76.8923
Epoch 3/15
469/469 [==============================] - 2s 5ms/step - loss: 151.9537 - distortion_loss: 0.0402 - rate_loss: 71.6350 - distortion_pass_through_loss: 0.0402 - rate_pass_through_loss: 71.6317 - val_loss: 145.8981 - val_distortion_loss: 0.0409 - val_rate_loss: 64.1213 - val_distortion_pass_through_loss: 0.0409 - val_rate_pass_through_loss: 64.1344
Epoch 4/15
469/469 [==============================] - 2s 5ms/step - loss: 142.8845 - distortion_loss: 0.0397 - rate_loss: 63.5405 - distortion_pass_through_loss: 0.0397 - rate_pass_through_loss: 63.5382 - val_loss: 136.7687 - val_distortion_loss: 0.0403 - val_rate_loss: 56.0768 - val_distortion_pass_through_loss: 0.0403 - val_rate_pass_through_loss: 56.0828
Epoch 5/15
469/469 [==============================] - 2s 5ms/step - loss: 137.0008 - distortion_loss: 0.0392 - rate_loss: 58.5897 - distortion_pass_through_loss: 0.0392 - rate_pass_through_loss: 58.5873 - val_loss: 131.0233 - val_distortion_loss: 0.0415 - val_rate_loss: 48.0806 - val_distortion_pass_through_loss: 0.0415 - val_rate_pass_through_loss: 48.0994
Epoch 6/15
469/469 [==============================] - 2s 5ms/step - loss: 132.9696 - distortion_loss: 0.0388 - rate_loss: 55.4426 - distortion_pass_through_loss: 0.0388 - rate_pass_through_loss: 55.4409 - val_loss: 126.9303 - val_distortion_loss: 0.0409 - val_rate_loss: 45.0330 - val_distortion_pass_through_loss: 0.0409 - val_rate_pass_through_loss: 45.0385
Epoch 7/15
469/469 [==============================] - 2s 5ms/step - loss: 129.9940 - distortion_loss: 0.0384 - rate_loss: 53.2685 - distortion_pass_through_loss: 0.0384 - rate_pass_through_loss: 53.2674 - val_loss: 124.2637 - val_distortion_loss: 0.0412 - val_rate_loss: 41.9056 - val_distortion_pass_through_loss: 0.0412 - val_rate_pass_through_loss: 41.9085
Epoch 8/15
469/469 [==============================] - 2s 5ms/step - loss: 127.4443 - distortion_loss: 0.0380 - rate_loss: 51.4851 - distortion_pass_through_loss: 0.0380 - rate_pass_through_loss: 51.4838 - val_loss: 120.1026 - val_distortion_loss: 0.0395 - val_rate_loss: 41.1668 - val_distortion_pass_through_loss: 0.0395 - val_rate_pass_through_loss: 41.1857
Epoch 9/15
469/469 [==============================] - 2s 5ms/step - loss: 125.0016 - distortion_loss: 0.0375 - rate_loss: 49.9920 - distortion_pass_through_loss: 0.0375 - rate_pass_through_loss: 49.9913 - val_loss: 118.0363 - val_distortion_loss: 0.0384 - val_rate_loss: 41.1388 - val_distortion_pass_through_loss: 0.0384 - val_rate_pass_through_loss: 41.1450
Epoch 10/15
469/469 [==============================] - 2s 5ms/step - loss: 122.7839 - distortion_loss: 0.0371 - rate_loss: 48.6244 - distortion_pass_through_loss: 0.0371 - rate_pass_through_loss: 48.6231 - val_loss: 116.1021 - val_distortion_loss: 0.0375 - val_rate_loss: 41.0209 - val_distortion_pass_through_loss: 0.0375 - val_rate_pass_through_loss: 41.0351
Epoch 11/15
469/469 [==============================] - 2s 5ms/step - loss: 120.6783 - distortion_loss: 0.0366 - rate_loss: 47.4875 - distortion_pass_through_loss: 0.0366 - rate_pass_through_loss: 47.4865 - val_loss: 115.2810 - val_distortion_loss: 0.0371 - val_rate_loss: 41.1618 - val_distortion_pass_through_loss: 0.0371 - val_rate_pass_through_loss: 41.1847
Epoch 12/15
469/469 [==============================] - 2s 5ms/step - loss: 119.0230 - distortion_loss: 0.0362 - rate_loss: 46.6046 - distortion_pass_through_loss: 0.0362 - rate_pass_through_loss: 46.6035 - val_loss: 113.7711 - val_distortion_loss: 0.0365 - val_rate_loss: 40.6715 - val_distortion_pass_through_loss: 0.0366 - val_rate_pass_through_loss: 40.6991
Epoch 13/15
469/469 [==============================] - 2s 5ms/step - loss: 117.4259 - distortion_loss: 0.0358 - rate_loss: 45.8302 - distortion_pass_through_loss: 0.0358 - rate_pass_through_loss: 45.8292 - val_loss: 113.6195 - val_distortion_loss: 0.0365 - val_rate_loss: 40.7041 - val_distortion_pass_through_loss: 0.0365 - val_rate_pass_through_loss: 40.7239
Epoch 14/15
469/469 [==============================] - 2s 5ms/step - loss: 116.2441 - distortion_loss: 0.0355 - rate_loss: 45.2506 - distortion_pass_through_loss: 0.0355 - rate_pass_through_loss: 45.2500 - val_loss: 113.2898 - val_distortion_loss: 0.0362 - val_rate_loss: 40.9044 - val_distortion_pass_through_loss: 0.0362 - val_rate_pass_through_loss: 40.9304
Epoch 15/15
469/469 [==============================] - 2s 5ms/step - loss: 115.3178 - distortion_loss: 0.0352 - rate_loss: 44.8268 - distortion_pass_through_loss: 0.0352 - rate_pass_through_loss: 44.8259 - val_loss: 112.4323 - val_distortion_loss: 0.0356 - val_rate_loss: 41.2075 - val_distortion_pass_through_loss: 0.0356 - val_rate_pass_through_loss: 41.2298
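Two numbers in this log can be cross-checked by hand. The sketch below assumes the standard MNIST training split of 60,000 images; the remaining values are copied from the final epoch above. It reproduces the 469 steps per epoch and confirms that the reported loss is the rate plus \(\lambda\) times the distortion, up to the rounding of the printed distortion (±0.00005 × 2000 = ±0.1).

```python
import math

num_train_examples = 60_000  # size of the MNIST training split
batch_size = 128
steps_per_epoch = math.ceil(num_train_examples / batch_size)
print(steps_per_epoch)  # 469

# Final-epoch training values copied from the log.
rate, distortion, lmbda = 44.8268, 0.0352, 2000
reported_loss = 115.3178
recomputed_loss = rate + lmbda * distortion
print(recomputed_loss, abs(recomputed_loss - reported_loss))
```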

Compress some MNIST images.

For compression and decompression at test time, we split the trained model in two parts:

  • The encoder side consists of the analysis transform and the entropy model.
  • The decoder side consists of the synthesis transform and the same entropy model.

At test time, the latents will not have additive noise, but they will be quantized and then losslessly compressed, so we give them new names. We call the quantized latents \(\hat y\) and the image reconstruction \(\hat x\) (following End-to-end Optimized Image Compression).

class MNISTCompressor(tf.keras.Model):
  """Compresses MNIST images to strings."""

  def __init__(self, analysis_transform, entropy_model):
    super().__init__()
    self.analysis_transform = analysis_transform
    self.entropy_model = entropy_model

  def call(self, x):
    # Ensure inputs are floats in the range (0, 1).
    x = tf.cast(x, self.compute_dtype) / 255.
    y = self.analysis_transform(x)
    # Also return the exact information content of each digit.
    _, bits = self.entropy_model(y, training=False)
    return self.entropy_model.compress(y), bits

class MNISTDecompressor(tf.keras.Model):
  """Decompresses MNIST images from strings."""

  def __init__(self, entropy_model, synthesis_transform):
    super().__init__()
    self.entropy_model = entropy_model
    self.synthesis_transform = synthesis_transform

  def call(self, string):
    y_hat = self.entropy_model.decompress(string, ())
    x_hat = self.synthesis_transform(y_hat)
    # Scale and cast back to 8-bit integer.
    return tf.saturate_cast(tf.round(x_hat * 255.), tf.uint8)
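The final line of the decompressor matters because the last synthesis layer uses a leaky ReLU, so its outputs are not guaranteed to lie in (0, 1). The following plain-Python sketch mimics the scale, round, and saturate steps for a single pixel value; the helper name `to_uint8` is illustrative only.

```python
def to_uint8(value: float) -> int:
    """Scales a float to 0..255, rounding and clamping out-of-range values."""
    scaled = round(value * 255.0)
    return max(0, min(255, scaled))


# Values below 0 saturate at 0 and values above 1 saturate at 255.
print([to_uint8(v) for v in (-0.1, 0.0, 0.5, 1.0, 1.2)])  # [0, 0, 128, 255, 255]
```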

When instantiated with compression=True, the entropy model converts the learned prior into tables for a range coding algorithm. Calling compress() invokes this algorithm to convert each latent vector into a bit sequence. The length of each binary string approximates the information content of the latent (the negative log likelihood of the latent under the prior).

The entropy model for compression and decompression must be the same instance, because the range coding tables need to be exactly identical on both sides. Otherwise, decoding errors can occur.

def make_mnist_codec(trainer, **kwargs):
  # The entropy model must be created with `compression=True` and the same
  # instance must be shared between compressor and decompressor.
  entropy_model = tfc.ContinuousBatchedEntropyModel(
      trainer.prior, coding_rank=1, compression=True, **kwargs)
  compressor = MNISTCompressor(trainer.analysis_transform, entropy_model)
  decompressor = MNISTDecompressor(entropy_model, trainer.synthesis_transform)
  return compressor, decompressor

compressor, decompressor = make_mnist_codec(trainer)

Grab 16 images from the validation dataset. You can select a different subset by changing the argument to skip.

(originals, _), = validation_dataset.batch(16).skip(3).take(1)
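For orientation, batching by 16 and skipping three batches selects a contiguous block of examples in the dataset's iteration order. The small helper below is a hypothetical illustration of that index arithmetic, not part of the notebook:

```python
def selected_positions(batch_size: int, num_skipped_batches: int) -> range:
    """Positions (in iteration order) of the first batch after the skipped ones."""
    start = batch_size * num_skipped_batches
    return range(start, start + batch_size)


positions = selected_positions(batch_size=16, num_skipped_batches=3)
print(positions[0], positions[-1])  # 48 63
```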

Compress them to strings, and keep track of the information content of each in bits.

strings, entropies = compressor(originals)

print(f"String representation of first digit in hexadecimal: 0x{strings[0].numpy().hex()}")
print(f"Number of bits actually needed to represent it: {entropies[0]:0.2f}")
String representation of first digit in hexadecimal: 0x244e73750429
Number of bits actually needed to represent it: 41.99

Decompress the images back from the strings.

reconstructions = decompressor(strings)

Display each of the 16 original digits together with its compressed binary representation and the reconstructed digit. display_digits is a plotting helper whose definition is not shown here.

display_digits(originals, strings, entropies, reconstructions)

Note that the length of the encoded string differs from the information content of each digit.

This is because the range coding process works with discretized probabilities, and has a small amount of overhead. So, especially for short strings, the correspondence is only approximate. However, range coding is asymptotically optimal: in the limit, the expected bit count will approach the cross entropy (the expected information content), for which the rate term in the training model is an upper bound.
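The gap can be quantified for the first digit using the two values printed earlier. The encoded string occupies a whole number of bytes, so its length in bits is a multiple of eight, while the information content is fractional:

```python
# Hex string and information content as printed for the first digit.
encoded = bytes.fromhex("244e73750429")
information_bits = 41.99

string_bits = 8 * len(encoded)
overhead_bits = string_bits - information_bits
print(string_bits, round(overhead_bits, 2))  # 48 6.01
```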

The rate–distortion trade-off

Above, the model was trained for a specific trade-off (given by lmbda=2000) between the average number of bits used to represent each digit and the incurred error in the reconstruction.

What happens when we repeat the experiment with different values?

Let's start by reducing \(\lambda\) to 500.

def train_and_visualize_model(lmbda):
  trainer = train_mnist_model(lmbda=lmbda)
  compressor, decompressor = make_mnist_codec(trainer)
  strings, entropies = compressor(originals)
  reconstructions = decompressor(strings)
  display_digits(originals, strings, entropies, reconstructions)

train_and_visualize_model(lmbda=500)
Epoch 1/15
469/469 [==============================] - 4s 6ms/step - loss: 127.3147 - distortion_loss: 0.0694 - rate_loss: 92.6021 - distortion_pass_through_loss: 0.0694 - rate_pass_through_loss: 92.5958 - val_loss: 108.2520 - val_distortion_loss: 0.0575 - val_rate_loss: 79.5169 - val_distortion_pass_through_loss: 0.0575 - val_rate_pass_through_loss: 79.5210
Epoch 2/15
469/469 [==============================] - 2s 5ms/step - loss: 97.3892 - distortion_loss: 0.0542 - rate_loss: 70.2840 - distortion_pass_through_loss: 0.0542 - rate_pass_through_loss: 70.2788 - val_loss: 86.1725 - val_distortion_loss: 0.0594 - val_rate_loss: 56.4715 - val_distortion_pass_through_loss: 0.0595 - val_rate_pass_through_loss: 56.4720
Epoch 3/15
469/469 [==============================] - 2s 5ms/step - loss: 81.3976 - distortion_loss: 0.0565 - rate_loss: 53.1271 - distortion_pass_through_loss: 0.0565 - rate_pass_through_loss: 53.1236 - val_loss: 71.7689 - val_distortion_loss: 0.0657 - val_rate_loss: 38.9376 - val_distortion_pass_through_loss: 0.0657 - val_rate_pass_through_loss: 38.9333
Epoch 4/15
469/469 [==============================] - 2s 5ms/step - loss: 71.7039 - distortion_loss: 0.0596 - rate_loss: 41.8972 - distortion_pass_through_loss: 0.0596 - rate_pass_through_loss: 41.8951 - val_loss: 62.9125 - val_distortion_loss: 0.0743 - val_rate_loss: 25.7446 - val_distortion_pass_through_loss: 0.0744 - val_rate_pass_through_loss: 25.7398
Epoch 5/15
469/469 [==============================] - 2s 5ms/step - loss: 65.9941 - distortion_loss: 0.0623 - rate_loss: 34.8288 - distortion_pass_through_loss: 0.0623 - rate_pass_through_loss: 34.8277 - val_loss: 56.9836 - val_distortion_loss: 0.0770 - val_rate_loss: 18.5021 - val_distortion_pass_through_loss: 0.0770 - val_rate_pass_through_loss: 18.4913
Epoch 6/15
469/469 [==============================] - 2s 5ms/step - loss: 62.3329 - distortion_loss: 0.0642 - rate_loss: 30.2477 - distortion_pass_through_loss: 0.0642 - rate_pass_through_loss: 30.2465 - val_loss: 53.4482 - val_distortion_loss: 0.0779 - val_rate_loss: 14.5194 - val_distortion_pass_through_loss: 0.0779 - val_rate_pass_through_loss: 14.5169
Epoch 7/15
469/469 [==============================] - 2s 5ms/step - loss: 59.6933 - distortion_loss: 0.0653 - rate_loss: 27.0410 - distortion_pass_through_loss: 0.0653 - rate_pass_through_loss: 27.0403 - val_loss: 50.4753 - val_distortion_loss: 0.0759 - val_rate_loss: 12.5500 - val_distortion_pass_through_loss: 0.0759 - val_rate_pass_through_loss: 12.5563
Epoch 8/15
469/469 [==============================] - 2s 5ms/step - loss: 57.4602 - distortion_loss: 0.0657 - rate_loss: 24.5984 - distortion_pass_through_loss: 0.0657 - rate_pass_through_loss: 24.5975 - val_loss: 48.6026 - val_distortion_loss: 0.0723 - val_rate_loss: 12.4449 - val_distortion_pass_through_loss: 0.0723 - val_rate_pass_through_loss: 12.4473
Epoch 9/15
469/469 [==============================] - 2s 5ms/step - loss: 55.3320 - distortion_loss: 0.0653 - rate_loss: 22.6798 - distortion_pass_through_loss: 0.0653 - rate_pass_through_loss: 22.6792 - val_loss: 47.3538 - val_distortion_loss: 0.0683 - val_rate_loss: 13.1898 - val_distortion_pass_through_loss: 0.0683 - val_rate_pass_through_loss: 13.1987
Epoch 10/15
469/469 [==============================] - 2s 5ms/step - loss: 53.4230 - distortion_loss: 0.0645 - rate_loss: 21.1713 - distortion_pass_through_loss: 0.0645 - rate_pass_through_loss: 21.1709 - val_loss: 46.7967 - val_distortion_loss: 0.0657 - val_rate_loss: 13.9286 - val_distortion_pass_through_loss: 0.0657 - val_rate_pass_through_loss: 13.9347
Epoch 11/15
469/469 [==============================] - 2s 5ms/step - loss: 51.8072 - distortion_loss: 0.0635 - rate_loss: 20.0425 - distortion_pass_through_loss: 0.0635 - rate_pass_through_loss: 20.0420 - val_loss: 46.0393 - val_distortion_loss: 0.0637 - val_rate_loss: 14.1823 - val_distortion_pass_through_loss: 0.0637 - val_rate_pass_through_loss: 14.1948
Epoch 12/15
469/469 [==============================] - 2s 5ms/step - loss: 50.4670 - distortion_loss: 0.0626 - rate_loss: 19.1538 - distortion_pass_through_loss: 0.0626 - rate_pass_through_loss: 19.1536 - val_loss: 45.7597 - val_distortion_loss: 0.0624 - val_rate_loss: 14.5469 - val_distortion_pass_through_loss: 0.0624 - val_rate_pass_through_loss: 14.5495
Epoch 13/15
469/469 [==============================] - 2s 5ms/step - loss: 49.3925 - distortion_loss: 0.0619 - rate_loss: 18.4531 - distortion_pass_through_loss: 0.0619 - rate_pass_through_loss: 18.4527 - val_loss: 45.5513 - val_distortion_loss: 0.0629 - val_rate_loss: 14.1151 - val_distortion_pass_through_loss: 0.0629 - val_rate_pass_through_loss: 14.1163
Epoch 14/15
469/469 [==============================] - 2s 5ms/step - loss: 48.5220 - distortion_loss: 0.0612 - rate_loss: 17.9393 - distortion_pass_through_loss: 0.0612 - rate_pass_through_loss: 17.9388 - val_loss: 45.3540 - val_distortion_loss: 0.0615 - val_rate_loss: 14.6202 - val_distortion_pass_through_loss: 0.0615 - val_rate_pass_through_loss: 14.6301
Epoch 15/15
469/469 [==============================] - 2s 5ms/step - loss: 47.8747 - distortion_loss: 0.0607 - rate_loss: 17.5226 - distortion_pass_through_loss: 0.0607 - rate_pass_through_loss: 17.5222 - val_loss: 45.0173 - val_distortion_loss: 0.0601 - val_rate_loss: 14.9463 - val_distortion_pass_through_loss: 0.0601 - val_rate_pass_through_loss: 14.9539

The bit rate of our code goes down, as does the fidelity of the digits. However, most of the digits remain recognizable.

Let's reduce \(\lambda\) further.

train_and_visualize_model(lmbda=300)
Epoch 1/15
469/469 [==============================] - 4s 6ms/step - loss: 113.8783 - distortion_loss: 0.0762 - rate_loss: 91.0098 - distortion_pass_through_loss: 0.0762 - rate_pass_through_loss: 91.0032 - val_loss: 96.7003 - val_distortion_loss: 0.0691 - val_rate_loss: 75.9582 - val_distortion_pass_through_loss: 0.0691 - val_rate_pass_through_loss: 75.9580
Epoch 2/15
469/469 [==============================] - 2s 5ms/step - loss: 85.7851 - distortion_loss: 0.0611 - rate_loss: 67.4697 - distortion_pass_through_loss: 0.0610 - rate_pass_through_loss: 67.4642 - val_loss: 73.7545 - val_distortion_loss: 0.0759 - val_rate_loss: 50.9975 - val_distortion_pass_through_loss: 0.0759 - val_rate_pass_through_loss: 50.9936
Epoch 3/15
469/469 [==============================] - 2s 5ms/step - loss: 68.7572 - distortion_loss: 0.0644 - rate_loss: 49.4336 - distortion_pass_through_loss: 0.0644 - rate_pass_through_loss: 49.4297 - val_loss: 58.1989 - val_distortion_loss: 0.0892 - val_rate_loss: 31.4509 - val_distortion_pass_through_loss: 0.0891 - val_rate_pass_through_loss: 31.4538
Epoch 4/15
469/469 [==============================] - 2s 5ms/step - loss: 58.1805 - distortion_loss: 0.0689 - rate_loss: 37.5008 - distortion_pass_through_loss: 0.0689 - rate_pass_through_loss: 37.4983 - val_loss: 48.3966 - val_distortion_loss: 0.0972 - val_rate_loss: 19.2390 - val_distortion_pass_through_loss: 0.0972 - val_rate_pass_through_loss: 19.2377
Epoch 5/15
469/469 [==============================] - 2s 5ms/step - loss: 51.8958 - distortion_loss: 0.0733 - rate_loss: 29.9187 - distortion_pass_through_loss: 0.0733 - rate_pass_through_loss: 29.9171 - val_loss: 42.3829 - val_distortion_loss: 0.1035 - val_rate_loss: 11.3228 - val_distortion_pass_through_loss: 0.1036 - val_rate_pass_through_loss: 11.3198
Epoch 6/15
469/469 [==============================] - 2s 5ms/step - loss: 47.9622 - distortion_loss: 0.0767 - rate_loss: 24.9413 - distortion_pass_through_loss: 0.0767 - rate_pass_through_loss: 24.9403 - val_loss: 38.2396 - val_distortion_loss: 0.0998 - val_rate_loss: 8.2981 - val_distortion_pass_through_loss: 0.0998 - val_rate_pass_through_loss: 8.2938
Epoch 7/15
469/469 [==============================] - 2s 5ms/step - loss: 45.2256 - distortion_loss: 0.0794 - rate_loss: 21.3975 - distortion_pass_through_loss: 0.0794 - rate_pass_through_loss: 21.3967 - val_loss: 35.4203 - val_distortion_loss: 0.0951 - val_rate_loss: 6.8925 - val_distortion_pass_through_loss: 0.0951 - val_rate_pass_through_loss: 6.8879
Epoch 8/15
469/469 [==============================] - 2s 5ms/step - loss: 42.9767 - distortion_loss: 0.0809 - rate_loss: 18.7189 - distortion_pass_through_loss: 0.0809 - rate_pass_through_loss: 18.7183 - val_loss: 34.1213 - val_distortion_loss: 0.0924 - val_rate_loss: 6.4127 - val_distortion_pass_through_loss: 0.0924 - val_rate_pass_through_loss: 6.4081
Epoch 9/15
469/469 [==============================] - 2s 5ms/step - loss: 41.0035 - distortion_loss: 0.0812 - rate_loss: 16.6346 - distortion_pass_through_loss: 0.0812 - rate_pass_through_loss: 16.6344 - val_loss: 32.9052 - val_distortion_loss: 0.0862 - val_rate_loss: 7.0431 - val_distortion_pass_through_loss: 0.0863 - val_rate_pass_through_loss: 7.0373
Epoch 10/15
469/469 [==============================] - 2s 5ms/step - loss: 39.1986 - distortion_loss: 0.0804 - rate_loss: 15.0752 - distortion_pass_through_loss: 0.0804 - rate_pass_through_loss: 15.0745 - val_loss: 32.6588 - val_distortion_loss: 0.0855 - val_rate_loss: 7.0213 - val_distortion_pass_through_loss: 0.0855 - val_rate_pass_through_loss: 7.0164
Epoch 11/15
469/469 [==============================] - 2s 5ms/step - loss: 37.6487 - distortion_loss: 0.0790 - rate_loss: 13.9394 - distortion_pass_through_loss: 0.0790 - rate_pass_through_loss: 13.9389 - val_loss: 32.1648 - val_distortion_loss: 0.0817 - val_rate_loss: 7.6512 - val_distortion_pass_through_loss: 0.0818 - val_rate_pass_through_loss: 7.6501
Epoch 12/15
469/469 [==============================] - 2s 5ms/step - loss: 36.4101 - distortion_loss: 0.0778 - rate_loss: 13.0658 - distortion_pass_through_loss: 0.0778 - rate_pass_through_loss: 13.0653 - val_loss: 31.8924 - val_distortion_loss: 0.0803 - val_rate_loss: 7.7989 - val_distortion_pass_through_loss: 0.0804 - val_rate_pass_through_loss: 7.7980
Epoch 13/15
469/469 [==============================] - 2s 5ms/step - loss: 35.5168 - distortion_loss: 0.0770 - rate_loss: 12.4052 - distortion_pass_through_loss: 0.0770 - rate_pass_through_loss: 12.4051 - val_loss: 31.6670 - val_distortion_loss: 0.0777 - val_rate_loss: 8.3560 - val_distortion_pass_through_loss: 0.0777 - val_rate_pass_through_loss: 8.3633
Epoch 14/15
469/469 [==============================] - 2s 5ms/step - loss: 34.7672 - distortion_loss: 0.0765 - rate_loss: 11.8313 - distortion_pass_through_loss: 0.0765 - rate_pass_through_loss: 11.8310 - val_loss: 31.5442 - val_distortion_loss: 0.0764 - val_rate_loss: 8.6302 - val_distortion_pass_through_loss: 0.0765 - val_rate_pass_through_loss: 8.6235
Epoch 15/15
469/469 [==============================] - 2s 5ms/step - loss: 34.1881 - distortion_loss: 0.0762 - rate_loss: 11.3369 - distortion_pass_through_loss: 0.0762 - rate_pass_through_loss: 11.3367 - val_loss: 31.3543 - val_distortion_loss: 0.0768 - val_rate_loss: 8.3237 - val_distortion_pass_through_loss: 0.0768 - val_rate_pass_through_loss: 8.3218

The strings begin to get much shorter now, on the order of one byte per digit. However, this comes at a cost. More digits are becoming unrecognizable.
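The three runs can be placed side by side to make the trade-off explicit. The sketch below simply tabulates the final-epoch validation rate and distortion copied from the three training logs in this notebook, and converts the rate to bytes:

```python
# (lmbda, validation rate in bits, validation distortion) at epoch 15,
# copied from the training logs for lmbda = 2000, 500, and 300.
results = [
    (2000, 41.2075, 0.0356),
    (500, 14.9463, 0.0601),
    (300, 8.3237, 0.0768),
]

for lmbda, rate, distortion in results:
    print(f"lmbda={lmbda:>4}  rate={rate:6.2f} bits (~{rate / 8:.1f} bytes)  "
          f"distortion={distortion:.4f}")
```

As \(\lambda\) decreases, the rate falls and the distortion rises, which traces out three points on the rate–distortion curve of this model.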

This demonstrates that the model is agnostic to human perceptions of error; it just measures the absolute deviation in terms of pixel values. To achieve a better perceived image quality, we would need to replace the pixel loss with a perceptual loss.