oryx.distributions.RegisterKL

Decorator that registers a function implementing the Kullback-Leibler (KL) divergence between two distribution classes.

Usage:

from oryx import distributions

@distributions.RegisterKL(distributions.Normal, distributions.Normal)
def _kl_normal_normal(norm_a, norm_b, name=None):
  # Compute and return KL(norm_a || norm_b).
  ...
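
The usage example leaves the body of the KL function unspecified. For two univariate normals the value it should return has a closed form: KL(N(mu_a, sigma_a^2) || N(mu_b, sigma_b^2)) = log(sigma_b / sigma_a) + (sigma_a^2 + (mu_a - mu_b)^2) / (2 sigma_b^2) - 1/2. The sketch below evaluates that formula in plain Python. It assumes a hypothetical `SimpleNormal` stand-in with `loc` and `scale` attributes, is not wired to oryx, and only illustrates what such a function might compute.

```python
import math
from dataclasses import dataclass


@dataclass
class SimpleNormal:
  """Hypothetical stand-in for a univariate normal distribution."""
  loc: float
  scale: float


def kl_normal_normal(norm_a, norm_b):
  """Closed-form KL(norm_a || norm_b) for two univariate normals."""
  var_a = norm_a.scale ** 2
  var_b = norm_b.scale ** 2
  return (math.log(norm_b.scale / norm_a.scale)
          + (var_a + (norm_a.loc - norm_b.loc) ** 2) / (2.0 * var_b)
          - 0.5)


print(kl_normal_normal(SimpleNormal(0.0, 1.0), SimpleNormal(0.0, 1.0)))  # 0.0
print(kl_normal_normal(SimpleNormal(0.0, 1.0), SimpleNormal(1.0, 1.0)))  # 0.5
```

Identical distributions give zero, and shifting the mean of the second normal by one standard deviation gives 0.5.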

Args
dist_cls_a: The class of the first argument of the KL divergence.
dist_cls_b: The class of the second argument of the KL divergence.

Methods

__call__(kl_fn)

Registers kl_fn as the KL divergence implementation for the (dist_cls_a, dist_cls_b) class pair.

Args
kl_fn: The function that computes the KL divergence between an instance of dist_cls_a and an instance of dist_cls_b.

Returns
kl_fn, unchanged, so the decorated function can still be called directly under its own name.

Raises
TypeError: if kl_fn is not callable.
ValueError: if a KL divergence function has already been registered for the given pair of argument classes.