Regularizer that adds a KL divergence penalty to the model loss.

When using the Monte Carlo approximation (i.e., use_exact_kl=False), it is presumed that the input distribution's concretization (i.e., tf.convert_to_tensor(distribution)) corresponds to a random sample. To override this behavior, set test_points_fn.


import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers
tfk = tf.keras
tfkl = tf.keras.layers

# Create a variational encoder and add a KL Divergence penalty to the
# loss that encourages marginal coherence with a unit-MVN (the "prior").
input_shape = [28, 28, 1]
encoded_size = 2
num_train_samples = 60000  # e.g., the size of the training set
variational_encoder = tfk.Sequential([
    tfkl.InputLayer(input_shape=input_shape),
    tfkl.Flatten(),
    tfkl.Dense(10, activation='relu'),
    tfkl.Dense(tfpl.MultivariateNormalTriL.params_size(encoded_size)),
    tfpl.MultivariateNormalTriL(
        encoded_size,
        activity_regularizer=tfpl.KLDivergenceRegularizer(
            tfd.MultivariateNormalDiag(loc=tf.zeros(encoded_size)),
            weight=num_train_samples)),
])

Args:

distribution_b: Distribution instance corresponding to b as in KL[a, b]. The previous layer's output is presumed to be a Distribution instance and is a.
use_exact_kl: Python bool indicating if KL divergence should be calculated exactly via tfp.distributions.kl_divergence or via Monte Carlo approximation. Default value: False.
test_points_reduce_axis: int vector or scalar representing dimensions over which to reduce_mean while calculating the Monte Carlo approximation of the KL divergence. As with all tf.reduce_* ops, None means reduce over all dimensions; () means reduce over none of them. Default value: () (i.e., no reduction).
test_points_fn: Python callable taking a Distribution instance and returning a Tensor used for random test points to approximate the KL divergence. Default value: tf.convert_to_tensor.
weight: Multiplier applied to the calculated KL divergence for each Keras batch member. Default value: None (i.e., do not weight each batch member).
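
To make the Monte Carlo arguments concrete, here is a minimal sketch contrasting the approximate and exact settings (the two-dimensional prior is illustrative, mirroring the encoder example above):

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers

# The prior is b in KL[a, b]; the regularized layer's output supplies a.
prior = tfd.MultivariateNormalDiag(loc=tf.zeros(2))

# Monte Carlo variant: draw 10 test points from the layer's distribution
# and average the log-density differences over the sample axis (axis 0).
mc_regularizer = tfpl.KLDivergenceRegularizer(
    prior,
    use_exact_kl=False,
    test_points_fn=lambda d: d.sample(10),
    test_points_reduce_axis=0)

# Exact variant: valid only when an analytic KL is registered for the pair.
exact_regularizer = tfpl.KLDivergenceRegularizer(prior, use_exact_kl=True)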


Attributes:

name: Returns the name of this module as passed or determined in the constructor.

name_scope: Returns a tf.name_scope instance for this class.
non_trainable_variables: Sequence of non-trainable variables owned by this module and its submodules.
submodules: Sequence of all sub-modules.

Submodules are modules which are properties of this module, or found as properties of modules which are properties of this module (and so on).

a = tf.Module()
b = tf.Module()
c = tf.Module()
a.b = b
b.c = c
assert list(a.submodules) == [b, c]
assert list(b.submodules) == [c]
assert list(c.submodules) == []



trainable_variables: Sequence of trainable variables owned by this module and its submodules.


variables: Sequence of variables owned by this module and its submodules.
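
As a quick illustrative sketch (the module and variable names here are hypothetical), the three variable properties relate as follows:

import tensorflow as tf

class Linear(tf.Module):
  def __init__(self):
    super().__init__()
    self.w = tf.Variable(tf.ones([2, 3]), name='w')            # trainable
    self.step = tf.Variable(0, trainable=False, name='step')   # non-trainable

mod = Linear()
assert len(mod.trainable_variables) == 1      # just w
assert len(mod.non_trainable_variables) == 1  # just step
assert len(mod.variables) == 2                # both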




from_config

Creates a regularizer from its config.

This method is the reverse of get_config, capable of instantiating the same regularizer from the config dictionary.

This method is used by Keras model_to_estimator, saving and loading models to HDF5 formats, Keras model cloning, some visualization utilities, and exporting models to and from JSON.

config: A Python dictionary, typically the output of get_config.

Returns: A regularizer instance.


get_config

Returns the config of the regularizer.

A regularizer config is a Python dictionary (serializable) containing all configuration parameters of the regularizer. The same regularizer can be reinstantiated later (without any saved state) from this configuration.

This method is optional if you are just training and executing models, exporting to and from SavedModels, or using weight checkpoints.

This method is required for Keras model_to_estimator, saving and loading models to HDF5 formats, Keras model cloning, some visualization utilities, and exporting models to and from JSON.

Returns: A Python dictionary.
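
The round trip these two methods define can be sketched with a built-in regularizer that is known to implement get_config (tf.keras.regularizers.L2 here; whether a particular KLDivergenceRegularizer supports serialization should be verified separately):

import tensorflow as tf

reg = tf.keras.regularizers.L2(l2=0.01)
config = reg.get_config()                      # e.g., {'l2': 0.01}
restored = tf.keras.regularizers.L2.from_config(config)
assert restored.get_config() == config         # same configuration, no saved state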


with_name_scope

Decorator to automatically enter the module name scope.

class MyModule(tf.Module):
  @tf.Module.with_name_scope
  def __call__(self, x):
    if not hasattr(self, 'w'):
      self.w = tf.Variable(tf.random.normal([x.shape[1], 3]))
    return tf.matmul(x, self.w)

Using the above module would produce tf.Variables and tf.Tensors whose names included the module name:

mod = MyModule()
mod(tf.ones([1, 2]))
<tf.Tensor: shape=(1, 3), dtype=float32, numpy=...>
mod.w
<tf.Variable 'my_module/Variable:0' shape=(2, 3) dtype=float32, numpy=...>

method: The method to wrap.

Returns: The original method, wrapped such that it enters the module's name scope.



__call__

Compute a regularization penalty from an input tensor.
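
A minimal sketch of a direct call, assuming an analytic KL is registered for the two distributions involved (both multivariate normals here):

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers

prior = tfd.MultivariateNormalDiag(loc=tf.zeros(2))
regularizer = tfpl.KLDivergenceRegularizer(prior, use_exact_kl=True)

# The "input tensor" is the previous layer's output, which for this
# regularizer is expected to be a Distribution instance (a in KL[a, b]).
posterior = tfd.MultivariateNormalDiag(loc=tf.ones(2))
penalty = regularizer(posterior)  # scalar Tensor holding KL[posterior, prior]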