Decorator to register a KL divergence implementation function.
tfp.substrates.jax.distributions.RegisterKL(
    dist_cls_a, dist_cls_b
)
Usage:

@distributions.RegisterKL(distributions.Normal, distributions.Normal)
def _kl_normal_normal(norm_a, norm_b):
  # Return KL(norm_a || norm_b)
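
A fuller, runnable sketch follows. The ClippedNormal subclass and the _kl_clipped_normal function are illustrative names, not part of TFP; the subclass exists only so that the registered class pair is not already taken by TFP's built-in Normal-vs-Normal KL.

from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions


class ClippedNormal(tfd.Normal):
  """Hypothetical Normal subclass, used only so the class pair is unregistered."""


@tfd.RegisterKL(ClippedNormal, ClippedNormal)
def _kl_clipped_normal(a, b, name=None):
  # `kl_divergence` may forward a `name` argument to the registered function,
  # so accept it here. Delegate to the analytic Normal-vs-Normal KL in TFP.
  return tfd.kl_divergence(tfd.Normal(a.loc, a.scale),
                           tfd.Normal(b.loc, b.scale))


p = ClippedNormal(loc=0., scale=1.)
q = ClippedNormal(loc=1., scale=2.)
print(tfd.kl_divergence(p, q))  # Dispatches to _kl_clipped_normal.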
Args
  dist_cls_a: the class of the first argument of the KL divergence.
  dist_cls_b: the class of the second argument of the KL divergence.
Methods
__call__
__call__(
    kl_fn
)
Perform the KL registration.
Args
  kl_fn: The function to use for the KL divergence.
Raises
  TypeError: if kl_fn is not a callable.
  ValueError: if a KL divergence function has already been registered for the given argument classes.
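
Because a RegisterKL instance is itself the decorator, __call__ can also be invoked directly on an existing function. The sketch below uses the illustrative names WrappedNormal and _wrapped_kl (not part of TFP) and shows the ValueError raised on a duplicate registration.

from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions


class WrappedNormal(tfd.Normal):
  """Hypothetical subclass, used only so the class pair starts unregistered."""


def _wrapped_kl(a, b, name=None):
  return tfd.kl_divergence(tfd.Normal(a.loc, a.scale),
                           tfd.Normal(b.loc, b.scale))


# Equivalent to decorating _wrapped_kl with @tfd.RegisterKL(...).
tfd.RegisterKL(WrappedNormal, WrappedNormal)(_wrapped_kl)

# A second registration for the same class pair raises ValueError.
try:
  tfd.RegisterKL(WrappedNormal, WrappedNormal)(_wrapped_kl)
except ValueError as e:
  print(e)

# Passing a non-callable raises TypeError:
# tfd.RegisterKL(WrappedNormal, WrappedNormal)("not a function")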