TFP Probabilistic Layers: Variational Auto Encoder

View on TensorFlow.org Run in Google Colab View source on GitHub Download notebook

In this example we show how to fit a Variational Autoencoder using TFP's "probabilistic layers."

Dependencies & Prerequisites

Import

Make things Fast!

Before we dive in, let's make sure we're using a GPU for this demo.

To do this, select "Runtime" -> "Change runtime type" -> "Hardware accelerator" -> "GPU".

The following snippet will verify that we have access to a GPU.

# Ask TensorFlow for the GPU device name once and report whether it is the
# expected first GPU; any other value is reported as "not found".
gpu_device = tf.test.gpu_device_name()
if gpu_device == '/device:GPU:0':
  print(f'SUCCESS: Found GPU: {gpu_device}')
else:
  print('WARNING: GPU device not found.')
SUCCESS: Found GPU: /device:GPU:0

Load Dataset

# Fetch MNIST together with its metadata object. as_supervised=False is
# passed explicitly; the preprocessing below reads each sample's 'image' entry.
datasets, datasets_info = tfds.load(
    name='mnist', with_info=True, as_supervised=False)

def _preprocess(sample):
  """Turn one dataset sample into a randomly binarized (input, target) pair.

  Args:
    sample: Mapping with an 'image' entry holding the raw pixel values.

  Returns:
    A tuple `(image, image)` containing the same boolean tensor twice.
  """
  # Scale to unit interval.
  pixels = tf.cast(sample['image'], tf.float32) / 255.
  # Randomly binarize: a pixel is True when its scaled intensity is below a
  # fresh uniform draw, i.e. with probability (1 - intensity).
  noise = tf.random.uniform(tf.shape(pixels))
  binarized = pixels < noise
  return binarized, binarized

# Training pipeline. Individual examples are shuffled *before* batching so
# that batch composition changes from epoch to epoch (shuffling after
# `batch` would only reorder fixed batches). `prefetch` is kept as the last
# stage so that preparing the next batch overlaps with the training step.
train_dataset = (datasets['train']
                 .map(_preprocess)
                 .shuffle(int(10e3))
                 .batch(256)
                 .prefetch(tf.data.AUTOTUNE))
# Evaluation pipeline: no shuffling, only batching and prefetching.
eval_dataset = (datasets['test']
                .map(_preprocess)
                .batch(256)
                .prefetch(tf.data.AUTOTUNE))

Note that _preprocess() above returns image, image rather than just image because Keras is set up for discriminative models with an (example, label) input format, i.e. \(p_\theta(y|x)\). Since the goal of the VAE is to recover the input x from x itself (i.e. \(p_\theta(x|x)\)), the data pair is (example, example).

VAE Code Golf

Specify model.

# Shape of a single image, taken from the dataset metadata.
input_shape = datasets_info.features['image'].shape
# Dimensionality of the latent code.
encoded_size = 16
# Base number of convolution filters; the networks use multiples of it.
base_depth = 32

# Prior over the latent code: zero-mean, unit-scale Normal components whose
# last batch dimension is reinterpreted as a single event dimension.
_latent_normal = tfd.Normal(loc=tf.zeros(encoded_size), scale=1)
prior = tfd.Independent(_latent_normal, reinterpreted_batch_ndims=1)
# Settings shared by every 'same'-padded, leaky-ReLU convolution below.
_encoder_conv_args = dict(padding='same', activation=tf.nn.leaky_relu)

# Encoder: maps an image to a distribution over latent codes.
encoder = tfk.Sequential([
    tfkl.InputLayer(input_shape=input_shape),
    # Cast the input to float32 and shift it by -0.5.
    tfkl.Lambda(lambda x: tf.cast(x, tf.float32) - 0.5),
    tfkl.Conv2D(filters=base_depth, kernel_size=5, strides=1,
                **_encoder_conv_args),
    tfkl.Conv2D(filters=base_depth, kernel_size=5, strides=2,
                **_encoder_conv_args),
    tfkl.Conv2D(filters=2 * base_depth, kernel_size=5, strides=1,
                **_encoder_conv_args),
    tfkl.Conv2D(filters=2 * base_depth, kernel_size=5, strides=2,
                **_encoder_conv_args),
    # The only 'valid'-padded convolution in the encoder.
    tfkl.Conv2D(filters=4 * encoded_size, kernel_size=7, strides=1,
                padding='valid', activation=tf.nn.leaky_relu),
    tfkl.Flatten(),
    # Emit exactly as many values as the distribution layer below needs.
    tfkl.Dense(
        units=tfpl.MultivariateNormalTriL.params_size(encoded_size),
        activation=None),
    # Output a multivariate normal over the latent code; the KL-divergence
    # regularizer against `prior` is attached as an activity regularizer.
    tfpl.MultivariateNormalTriL(
        encoded_size,
        activity_regularizer=tfpl.KLDivergenceRegularizer(prior)),
])
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py:158: calling LinearOperator.__init__ (from tensorflow.python.ops.linalg.linear_operator) with graph_parents is deprecated and will be removed in a future version.
Instructions for updating:
Do not pass `graph_parents`.  They will  no longer be used.
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py:158: calling LinearOperator.__init__ (from tensorflow.python.ops.linalg.linear_operator) with graph_parents is deprecated and will be removed in a future version.
Instructions for updating:
Do not pass `graph_parents`.  They will  no longer be used.
# Settings shared by every 'same'-padded, leaky-ReLU transposed convolution.
_decoder_deconv_args = dict(padding='same', activation=tf.nn.leaky_relu)

# Decoder: maps a latent code to a distribution over images.
decoder = tfk.Sequential([
    tfkl.InputLayer(input_shape=[encoded_size]),
    # Treat the latent vector as a 1x1 feature map with encoded_size channels.
    tfkl.Reshape([1, 1, encoded_size]),
    # The only 'valid'-padded transposed convolution in the decoder.
    tfkl.Conv2DTranspose(filters=2 * base_depth, kernel_size=7, strides=1,
                         padding='valid', activation=tf.nn.leaky_relu),
    tfkl.Conv2DTranspose(filters=2 * base_depth, kernel_size=5, strides=1,
                         **_decoder_deconv_args),
    tfkl.Conv2DTranspose(filters=2 * base_depth, kernel_size=5, strides=2,
                         **_decoder_deconv_args),
    tfkl.Conv2DTranspose(filters=base_depth, kernel_size=5, strides=1,
                         **_decoder_deconv_args),
    tfkl.Conv2DTranspose(filters=base_depth, kernel_size=5, strides=2,
                         **_decoder_deconv_args),
    tfkl.Conv2DTranspose(filters=base_depth, kernel_size=5, strides=1,
                         **_decoder_deconv_args),
    # Single-channel projection with no activation.
    tfkl.Conv2D(filters=1, kernel_size=5, strides=1,
                padding='same', activation=None),
    tfkl.Flatten(),
    # Wrap the flattened activations in an independent Bernoulli distribution
    # with event shape `input_shape`.
    tfpl.IndependentBernoulli(input_shape, tfd.Bernoulli.logits),
])
# Chain the two halves: the encoder's first output is fed to the decoder, and
# the combined model maps the encoder's inputs to the decoder's output.
_encoder_output = encoder.outputs[0]
vae = tfk.Model(inputs=encoder.inputs, outputs=decoder(_encoder_output))

Do inference.

def negloglik(x, rv_x):
  """Return the negative log-probability of `x` under `rv_x`.

  Args:
    x: The observed value.
    rv_x: An object exposing `log_prob(x)`.

  Returns:
    `-rv_x.log_prob(x)`.
  """
  return -rv_x.log_prob(x)

# Build the optimizer from the same Keras alias (`tfk`) that constructed the
# models, so the optimizer and the model come from one Keras namespace and
# no separately imported `tf_keras` name is required.
vae.compile(optimizer=tfk.optimizers.Adam(learning_rate=1e-3),
            loss=negloglik)

# Train for 15 epochs, evaluating on eval_dataset after each one. The value
# returned by fit() is not needed here, so it is discarded.
_ = vae.fit(train_dataset,
            epochs=15,
            validation_data=eval_dataset)
Epoch 1/15
235/235 [==============================] - 14s 61ms/step - loss: 206.5541 - val_loss: 163.1924
Epoch 2/15
235/235 [==============================] - 14s 59ms/step - loss: 151.1891 - val_loss: 143.6748
Epoch 3/15
235/235 [==============================] - 14s 58ms/step - loss: 141.3275 - val_loss: 137.9188
Epoch 4/15
235/235 [==============================] - 14s 58ms/step - loss: 136.7453 - val_loss: 133.2726
Epoch 5/15
235/235 [==============================] - 14s 58ms/step - loss: 132.3803 - val_loss: 131.8343
Epoch 6/15
235/235 [==============================] - 14s 58ms/step - loss: 129.2451 - val_loss: 127.1935
Epoch 7/15
235/235 [==============================] - 14s 59ms/step - loss: 126.0975 - val_loss: 123.6789
Epoch 8/15
235/235 [==============================] - 14s 58ms/step - loss: 124.0565 - val_loss: 122.5058
Epoch 9/15
235/235 [==============================] - 14s 58ms/step - loss: 122.9974 - val_loss: 121.9544
Epoch 10/15
235/235 [==============================] - 14s 58ms/step - loss: 121.7349 - val_loss: 120.8735
Epoch 11/15
235/235 [==============================] - 14s 58ms/step - loss: 121.0856 - val_loss: 120.1340
Epoch 12/15
235/235 [==============================] - 14s 58ms/step - loss: 120.2232 - val_loss: 121.3554
Epoch 13/15
235/235 [==============================] - 14s 58ms/step - loss: 119.8123 - val_loss: 119.2351
Epoch 14/15
235/235 [==============================] - 14s 58ms/step - loss: 119.2685 - val_loss: 118.2133
Epoch 15/15
235/235 [==============================] - 14s 59ms/step - loss: 118.8895 - val_loss: 119.4771

Look Ma, No ~~Hands~~Tensors!

# We'll just examine the first ten digits of the first evaluation batch.
# (eval_dataset is not shuffled, so this is not a random selection.)
first_eval_batch = next(iter(eval_dataset))
x = first_eval_batch[0][:10]
# Calling the model is expected to yield a distribution, checked below.
xhat = vae(x)
assert isinstance(xhat, tfd.Distribution)

Image Plot Util

print('Originals:')
display_imgs(x)

# Show three views of the decoded distribution. Each statistic is computed
# only when its turn comes, immediately after its heading is printed.
for heading, statistic in (('Decoded Random Samples:', xhat.sample),
                           ('Decoded Modes:', xhat.mode),
                           ('Decoded Means:', xhat.mean)):
  print()
  print(heading)
  display_imgs(statistic())
Originals:

png

Decoded Random Samples:

png

Decoded Modes:

png

Decoded Means:

png

# Now, let's generate ten never-before-seen digits by decoding ten latent
# codes drawn from the prior.
z = prior.sample(10)
xtilde = decoder(z)
assert isinstance(xtilde, tfd.Distribution)

print('Randomly Generated Samples:')
display_imgs(xtilde.sample())

# The remaining two views follow the same pattern, each preceded by a
# blank line and its heading.
for heading, statistic in (('Randomly Generated Modes:', xtilde.mode),
                           ('Randomly Generated Means:', xtilde.mean)):
  print()
  print(heading)
  display_imgs(statistic())
Randomly Generated Samples:

png

Randomly Generated Modes:

png

Randomly Generated Means:

png