
Monte-Carlo approximation of an f-Divergence variational loss.

Variational losses measure the divergence between an unnormalized target distribution p (provided via target_log_prob_fn) and a surrogate distribution q (provided as surrogate_posterior). When the target distribution is an unnormalized posterior from conditioning a model on data, minimizing the loss with respect to the parameters of surrogate_posterior performs approximate posterior inference.

This function defines losses of the form E_q[discrepancy_fn(log(u))], where u = p(z) / q(z) in the (default) case where importance_sample_size == 1, and u = mean([p(z[k]) / q(z[k]) for k in range(importance_sample_size)]) more generally. These losses are sometimes known as f-divergences [1, 2].

The default behavior (discrepancy_fn is the reverse KL discrepancy, lambda logu: -logu, and importance_sample_size == 1) computes an unbiased estimate of the negative of the standard evidence lower bound (ELBO) [3]. The bound may be tightened by setting importance_sample_size > 1 [4], and the variance of the estimate reduced by setting sample_size > 1. Other available discrepancies of interest include the forward KL[p||q], total variation distance, Amari alpha-divergences, and more.
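A minimal pure-Python sketch of the estimator described above, assuming a scalar latent z, a Gaussian surrogate q = Normal(q_loc, q_scale), and an unnormalized target p(z) = 3 * Normal(z; 0, 1) whose log normalizer log Z = log 3 is known. The helpers `normal_log_prob`, `log_mean_exp`, and `mc_variational_loss` are hypothetical names for illustration, not TFP functions. With q = Normal(0.5, 1), the ELBO is log Z - KL[q||Normal(0, 1)] = log Z - 0.125, so the default loss should land near -(log Z - 0.125), and raising importance_sample_size should move it toward -log Z.

```python
import math
import random


def normal_log_prob(z, loc, scale):
    """Log-density of Normal(loc, scale) at z."""
    return (-0.5 * ((z - loc) / scale) ** 2
            - math.log(scale) - 0.5 * math.log(2 * math.pi))


def log_mean_exp(xs):
    """Numerically stable log(mean(exp(xs)))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs) / len(xs))


def mc_variational_loss(target_log_prob_fn, q_loc, q_scale,
                        discrepancy_fn=lambda logu: -logu,
                        sample_size=1, importance_sample_size=1, seed=0):
    """Monte Carlo estimate of E_q[discrepancy_fn(log u)], q = Normal(q_loc, q_scale)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(sample_size):
        # log p(z_k) - log q(z_k) for importance_sample_size draws z_k ~ q.
        log_ws = []
        for _ in range(importance_sample_size):
            z = rng.gauss(q_loc, q_scale)
            log_ws.append(target_log_prob_fn(z) - normal_log_prob(z, q_loc, q_scale))
        # log u = log(mean_k p(z_k) / q(z_k)), evaluated in log space.
        total += discrepancy_fn(log_mean_exp(log_ws))
    return total / sample_size


log_z = math.log(3.0)  # log normalizer of the unnormalized target


def target_log_prob_fn(z):
    # Unnormalized target: p(z) = 3 * Normal(z; 0, 1).
    return log_z + normal_log_prob(z, 0.0, 1.0)


loss_k1 = mc_variational_loss(target_log_prob_fn, 0.5, 1.0, sample_size=20000)
loss_k20 = mc_variational_loss(target_log_prob_fn, 0.5, 1.0, sample_size=2000,
                               importance_sample_size=20)
print(loss_k1)   # close to -(log_z - 0.125), the negative ELBO
print(loss_k20)  # smaller: a tighter bound, closer to -log_z
```

The outer `sample_size` loop only averages independent copies of the same quantity, so it reduces variance; `importance_sample_size` changes the quantity being estimated, because the log is taken after averaging the importance weights.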

Args

target_log_prob_fn Python callable that takes a set of Tensor arguments and returns a Tensor log-density. Given q_sample = surrogate_posterior.sample(sample_size), this will be called as target_log_prob_fn(*q_sample) if q_sample is a list or a tuple, target_log_prob_fn(**q_sample) if q_sample is a dictionary, or target_log_prob_fn(q_sample) if q_sample is a Tensor. It should support batched evaluation, i.e., should return a result of shape [sample_size].
surrogate_posterior A tfp.distributions.Distribution instance defining a variational posterior (could be a tfd.JointDistribution). Crucially, the distribution's log_prob and (if reparameterizable) sample methods must directly invoke all ops that generate gradients to the underlying variables. One way to ensure this is to use tfp.util.TransformedVariable and/or tfp.util.DeferredTensor to represent any parameters defined as transformations of unconstrained variables, so that the transformations execute at runtime instead of at distribution creation.
sample_size Integer scalar number of Monte Carlo samples used to approximate the variational divergence. Larger values may stabilize the optimization, but at higher cost per step in time and memory. Default value: 1.
importance_sample_size Python int number of terms used to define an importance-weighted divergence. If importance_sample_size > 1, then the surrogate_posterior is optimized to function as an importance-sampling proposal distribution. In this case, it often makes sense to use importance sampling to approximate posterior expectations. Default value: 1.
discrepancy_fn Python callable representing a Csiszar f function in log-space. That is, discrepancy_fn(log(u)) = f(u), where f is convex in u. Default value: the reverse KL discrepancy, lambda logu: -logu.
use_reparameterization Python bool. When None (the default), it is automatically set to surrogate_posterior.reparameterization_type == tfd.FULLY_REPARAMETERIZED. When True, uses the standard Monte-Carlo average. When False, uses the score-gradient trick (see "Tricks: Reparameterization and Score-Gradient" below for details); in that case, consider using csiszar_vimco.
seed PRNG seed for surrogate_posterior.sample; see tfp.random.sanitize_seed for details.
name Python str name prefixed to Ops created by this function.
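The calling convention for target_log_prob_fn described above can be sketched as a small dispatch on the type of the surrogate sample. The helper `call_target_log_prob_fn` and the stand-in target `joint_log_prob` are hypothetical, and plain floats replace Tensors only to keep the sketch dependency-free; a real target must accept batched Tensors and return shape [sample_size].

```python
def call_target_log_prob_fn(target_log_prob_fn, q_sample):
    """Hypothetical dispatch mirroring the documented calling convention.

    Lists and tuples are unpacked positionally, dicts are unpacked as
    keyword arguments, and anything else is passed through unchanged.
    """
    if isinstance(q_sample, (list, tuple)):
        return target_log_prob_fn(*q_sample)
    if isinstance(q_sample, dict):
        return target_log_prob_fn(**q_sample)
    return target_log_prob_fn(q_sample)


def joint_log_prob(loc, scale):
    # Stand-in target with two named arguments; returns a plain float here.
    return -(loc ** 2) - (scale - 1.0) ** 2


print(call_target_log_prob_fn(joint_log_prob, [0.5, 2.0]))                   # -1.25
print(call_target_log_prob_fn(joint_log_prob, {"loc": 0.5, "scale": 2.0}))  # -1.25
print(call_target_log_prob_fn(lambda z: -z ** 2, 3.0))                      # -9.0
```

This is why the structure of surrogate_posterior's samples (for example, a dict-valued joint distribution) should match the argument names of target_log_prob_fn.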

Returns

monte_carlo_variational_loss float-like Tensor Monte Carlo approximation of the Csiszar f-Divergence.

Raises

ValueError if surrogate_posterior is not a reparameterized distribution and use_reparameterization = True. A distribution is said to be "reparameterized" when its samples are generated by transforming the samples of another distribution that does not depend on the first distribution's parameters. This property ensures the gradient with respect to parameters is valid.
TypeError if target_log_prob_fn is not a Python callable.
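The documented interaction between use_reparameterization and the ValueError above amounts to a small decision rule. The sketch below uses plain strings as stand-ins for tfd.FULLY_REPARAMETERIZED and tfd.NOT_REPARAMETERIZED, and the function name is hypothetical; it restates the documented behavior and is not TFP's source.

```python
# Plain-string stand-ins for tfd.FULLY_REPARAMETERIZED / tfd.NOT_REPARAMETERIZED.
FULLY_REPARAMETERIZED = "FULLY_REPARAMETERIZED"
NOT_REPARAMETERIZED = "NOT_REPARAMETERIZED"


def resolve_use_reparameterization(reparameterization_type,
                                   use_reparameterization=None):
    """Hypothetical helper applying the documented default and check."""
    is_reparameterized = reparameterization_type == FULLY_REPARAMETERIZED
    if use_reparameterization is None:
        # Default: use the reparameterized estimator exactly when q supports it.
        return is_reparameterized
    if use_reparameterization and not is_reparameterized:
        # Gradients through unreparameterized samples would be invalid.
        raise ValueError("use_reparameterization=True requires a fully "
                         "reparameterized surrogate_posterior.")
    return bool(use_reparameterization)


print(resolve_use_reparameterization(FULLY_REPARAMETERIZED))        # True
print(resolve_use_reparameterization(NOT_REPARAMETERIZED))          # False
print(resolve_use_reparameterization(FULLY_REPARAMETERIZED, False)) # False
```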

Csiszar f-divergences

A Csiszar function f is a convex function from R^+ (the positive reals) to R. The Csiszar f-Divergence is given by:

D_f[p(X), q(X)] := E_{q(X)}[ f( p(X) / q(X) ) ]
                ~= m**-1 sum_j^m f( p(x_j) / q(x_j) ),
                           where x_j ~iid q(X)

For example, f = lambda u: -log(u) recovers KL[q||p], while f = lambda u: u * log(u) recovers the forward KL[p||q]. These and other discrepancy functions are available in the same module as this function.
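Both the definition of D_f and the two KL identities can be checked on small discrete distributions, where the expectation under q is an exact finite sum. The distributions p and q below are arbitrary illustrative values, and the helper names are hypothetical. In log space (the form discrepancy_fn takes), the two functions correspond to lambda logu: -logu and lambda logu: exp(logu) * logu.

```python
import math
import random

p = [0.5, 0.3, 0.2]
q = [0.2, 0.5, 0.3]


def csiszar_divergence(f, p, q):
    """Exact D_f[p, q] = E_q[f(p(x) / q(x))] for discrete distributions."""
    return sum(qx * f(px / qx) for px, qx in zip(p, q))


def mc_csiszar_divergence(f, p, q, m, seed=0):
    """Monte Carlo estimate m**-1 * sum_j f(p(x_j) / q(x_j)), x_j ~iid q."""
    rng = random.Random(seed)
    xs = rng.choices(range(len(q)), weights=q, k=m)
    return sum(f(p[x] / q[x]) for x in xs) / m


def kl(a, b):
    """KL[a || b] for discrete distributions."""
    return sum(ax * math.log(ax / bx) for ax, bx in zip(a, b))


reverse = csiszar_divergence(lambda u: -math.log(u), p, q)     # equals KL[q || p]
forward = csiszar_divergence(lambda u: u * math.log(u), p, q)  # equals KL[p || q]
estimate = mc_csiszar_divergence(lambda u: -math.log(u), p, q, m=50000)
print(reverse, kl(q, p))
print(forward, kl(p, q))
print(estimate)  # close to reverse
```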

Tricks: Reparameterization and Score-Gradient

When q is "reparameterized", i.e., a diffeomorphic transformation of a parameterless distribution (e.g., Normal(Y; m, s) <=> Y = sX + m, X ~ Normal(0,1)), we can swap gradient and expectation, i.e., grad[S_n] = Avg{ grad[s_i] : i=1...n }, where S_n = Avg{ s_i : i=1...n } and s_i = f(x_i), x_i ~iid q(X).
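A minimal sketch of the reparameterization (pathwise) estimator for the Normal example above, assuming f(y) = y**2, so that E[f(Y)] = m**2 + s**2 and the exact gradients are (2m, 2s). Autodiff would normally supply the derivative; here f'(y) = 2y is passed in by hand, and `reparameterized_grads` is a hypothetical name.

```python
import random


def reparameterized_grads(f_prime, m, s, n, seed=0):
    """Pathwise estimates of d/dm and d/ds of E[f(Y)], Y = s * X + m, X ~ Normal(0, 1)."""
    rng = random.Random(seed)
    grad_m = grad_s = 0.0
    for _ in range(n):
        x = rng.gauss(0.0, 1.0)  # parameter-free noise
        y = s * x + m            # differentiable transform of the noise
        # Chain rule through y: dy/dm = 1 and dy/ds = x.
        grad_m += f_prime(y)
        grad_s += f_prime(y) * x
    # Average of per-sample gradients = gradient of the sample average.
    return grad_m / n, grad_s / n


gm, gs = reparameterized_grads(lambda y: 2.0 * y, m=1.5, s=0.5, n=100000)
print(gm, gs)  # close to (3.0, 1.0)
```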

However, if q is not reparameterized, TensorFlow's gradient will be incorrect, since the chain rule stops at samples of unreparameterized distributions. In this circumstance, using the Score-Gradient trick results in an unbiased gradient, i.e.,

grad[ E_q[f(X)] ]
= grad[ int dx q(x) f(x) ]
= int dx grad[ q(x) f(x) ]
= int dx [ q'(x) f(x) + q(x) f'(x) ]
= int dx q(x) [q'(x) / q(x) f(x) + f'(x) ]
= int dx q(x) grad[ f(x) q(x) / stop_grad[q(x)] ]
= E_q[ grad[ f(x) q(x) / stop_grad[q(x)] ] ]
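In the last line of the derivation, grad[ f(x) q(x) / stop_grad[q(x)] ] evaluates to f(x) grad[log q(x)] + f'(x), so when f does not depend on the parameters the estimator is the average of f(x) times the score. A minimal sketch, assuming q = Normal(m, 1), whose score is d/dm log q(x) = x - m, and f(x) = x**2, so the exact gradient of E_q[f(X)] = m**2 + 1 is 2m. The function name is hypothetical.

```python
import random


def score_gradient_m(f, m, n, seed=0):
    """Score-function estimate of d/dm E_{X ~ Normal(m, 1)}[f(X)].

    Averages f(x) * d/dm log q(x), where d/dm log Normal(x; m, 1) = x - m.
    f does not depend on m here, so the f'(x) term of the derivation is zero.
    """
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        x = rng.gauss(m, 1.0)  # no gradient flows through this sample
        total += f(x) * (x - m)
    return total / n


g = score_gradient_m(lambda x: x ** 2, m=1.5, n=200000)
print(g)  # close to 2 * m = 3.0
```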

When q.reparameterization_type == tfd.FULLY_REPARAMETERIZED, it is usually preferable to set use_reparameterization = True.
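One reason for this preference is variance. The sketch below compares per-sample estimates of d/dm E[X**2], X ~ Normal(m, 1), under both tricks using the same draws; both are unbiased for 2m, but at m = 1 the analytic per-sample variance is 4 for the pathwise estimator and 30 for the score estimator. This is a single toy case chosen for illustration, the function name is hypothetical, and the size of the gap is problem dependent.

```python
import random
import statistics


def per_sample_grads(m, n, seed=0):
    """Per-sample d/dm estimates of E_{X ~ Normal(m, 1)}[X**2] under both tricks."""
    rng = random.Random(seed)
    reparam, score = [], []
    for _ in range(n):
        eps = rng.gauss(0.0, 1.0)
        x = m + eps                  # the same draw feeds both estimators
        reparam.append(2.0 * x)      # pathwise: d/dm (m + eps)**2
        score.append(x ** 2 * eps)   # score: f(x) * (x - m)
    return reparam, score


reparam, score = per_sample_grads(m=1.0, n=50000)
print(statistics.mean(reparam), statistics.variance(reparam))  # mean near 2, variance near 4
print(statistics.mean(score), statistics.variance(score))      # mean near 2, variance near 30
```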

Example Application:

The Csiszar f-Divergence is a useful framework for variational inference. To see this, observe that,

f(p(x)) =  f( E_{q(Z | x)}[ p(x, Z) / q(Z | x) ] )
        <= E_{q(Z | x)}[ f( p(x, Z) / q(Z | x) ) ]
        := D_f[p(x, Z), q(Z | x)]

The inequality follows from the fact that the "perspective" of f, i.e., (s, t) |-> t f(s / t), is convex in (s, t) when s/t is in domain(f) and t is a positive real. Since the above framework includes the popular Evidence Lower BOund (ELBO) as a special case, i.e., f(u) = -log(u), we call this framework "Evidence Divergence Bound Optimization" (EDBO).
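The bound can be checked exactly for a discrete latent variable, where the expectation under q(Z | x) is a finite sum. The joint values p(x, z) below are arbitrary illustrative numbers at one fixed observation x, and `edbo` is a hypothetical helper name. For any convex f, f(p(x)) should not exceed D_f[p(x, Z), q(Z | x)], with equality when q(Z | x) is the exact posterior, because then p(x, Z) / q(Z | x) = p(x) is constant.

```python
import math

# Joint p(x, z) at a fixed observation x, for a discrete latent z in {0, 1, 2}.
p_xz = [0.10, 0.05, 0.15]
p_x = sum(p_xz)                        # marginal evidence p(x), about 0.3


def edbo(f, q):
    """D_f[p(x, Z), q(Z | x)] = E_q[f(p(x, Z) / q(Z | x))], computed exactly."""
    return sum(qz * f(pz / qz) for pz, qz in zip(p_xz, q))


neg_log = lambda u: -math.log(u)       # f(u) = -log(u): the negative ELBO case
squared = lambda u: (u - 1.0) ** 2     # another convex f, for comparison
posterior = [pz / p_x for pz in p_xz]  # exact p(Z | x)
uniform = [1 / 3, 1 / 3, 1 / 3]

print(neg_log(p_x), edbo(neg_log, uniform), edbo(neg_log, posterior))
print(squared(p_x), edbo(squared, uniform))
```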


[2]: Ali, Syed Mumtaz, and Samuel D. Silvey. "A general class of coefficients of divergence of one distribution from another." Journal of the Royal Statistical Society: Series B (Methodological) 28.1 (1966): 131-142.

[3]: Christopher M. Bishop. Pattern Recognition and Machine Learning. Springer, 2006.

[4]: Yuri Burda, Roger Grosse, and Ruslan Salakhutdinov. Importance Weighted Autoencoders. In International Conference on Learning Representations, 2016.