tfp.layers.KLDivergenceRegularizer

Class KLDivergenceRegularizer

Regularizer that adds a KL divergence penalty to the model loss.

When using Monte Carlo approximation (i.e., use_exact_kl=False), it is presumed that the input distribution's concretization (i.e., tf.convert_to_tensor(distribution)) corresponds to a random sample. To override this behavior, set test_points_fn.
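
For instance, the following sketch (with a hypothetical standard-normal prior) overrides the default by drawing explicit Monte Carlo samples:

import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers

# Hypothetical prior, for illustration. By default the Monte Carlo estimate
# treats tf.convert_to_tensor(distribution_a) as the random sample; passing
# test_points_fn overrides that, here drawing 16 explicit samples and
# averaging the estimate over them via test_points_reduce_axis=0.
prior = tfd.Normal(loc=0., scale=1.)
regularizer = tfpl.KLDivergenceRegularizer(
    prior,
    test_points_fn=lambda d: d.sample(16),
    test_points_reduce_axis=0)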

Example

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers
tfk = tf.keras
tfkl = tf.keras.layers

# Create a variational encoder and add a KL divergence penalty to the
# loss that encourages marginal coherence with a unit-MVN (the "prior").
input_shape = [28, 28, 1]
encoded_size = 2
num_train_samples = 60000  # e.g., the size of an MNIST-like training set
variational_encoder = tfk.Sequential([
    tfkl.InputLayer(input_shape=input_shape),
    tfkl.Flatten(),
    tfkl.Dense(10, activation='relu'),
    tfkl.Dense(tfpl.MultivariateNormalTriL.params_size(encoded_size)),
    tfpl.MultivariateNormalTriL(
        encoded_size,
        lambda s: s.sample(10),
        activity_regularizer=tfpl.KLDivergenceRegularizer(
           tfd.MultivariateNormalDiag(loc=tf.zeros(encoded_size)),
           weight=num_train_samples)),
])

__init__

__init__(
    distribution_b,
    use_exact_kl=False,
    test_points_reduce_axis=(),
    test_points_fn=tf.convert_to_tensor,
    weight=None
)

Initialize the KLDivergenceRegularizer.

Args:

  • distribution_b: Distribution instance corresponding to b as in KL[a, b]. The previous layer's output is presumed to be a Distribution instance and corresponds to a.
  • use_exact_kl: Python bool indicating if KL divergence should be calculated exactly via tfp.distributions.kl_divergence or via Monte Carlo approximation (see the sketch after this list). Default value: False.
  • test_points_reduce_axis: int vector or scalar representing dimensions over which to reduce_mean while calculating the Monte Carlo approximation of the KL divergence. As with all tf.reduce_* ops, None means reduce over all dimensions; () means reduce over none of them. Default value: () (i.e., no reduction).
  • test_points_fn: Python callable taking a Distribution instance and returning a Tensor used for random test points to approximate the KL divergence. Default value: tf.convert_to_tensor.
  • weight: Multiplier applied to the calculated KL divergence for each Keras batch member. Default value: None (i.e., do not weight each batch member).
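
A minimal sketch of constructing the regularizer in exact-KL mode; the Gaussian prior is illustrative, and exact mode requires that tfp.distributions.kl_divergence has a registered analytic rule for the pair of distributions:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers

# Exact KL works when an analytic KL[a, b] is registered for the pair,
# as it is for two multivariate normals; no test points are drawn.
prior = tfd.MultivariateNormalDiag(loc=tf.zeros(2))
regularizer = tfpl.KLDivergenceRegularizer(prior, use_exact_kl=True)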

Methods

__call__

__call__(distribution_a)

Computes the regularization penalty, i.e., the (optionally weighted) KL divergence KL[distribution_a, distribution_b].
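
A minimal sketch of invoking the regularizer outside a model; the unit-variance normals are illustrative, and with use_exact_kl=True the returned penalty is the analytic KL[a, b]:

import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers

# KL between two unit-variance normals one mean apart is 0.5 * 1**2 = 0.5.
regularizer = tfpl.KLDivergenceRegularizer(tfd.Normal(0., 1.), use_exact_kl=True)
penalty = regularizer(tfd.Normal(1., 1.))  # ==> 0.5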

from_config

from_config(
    cls,
    config
)