tfp.vi.monte_carlo_variational_loss

Monte-Carlo approximation of an f-Divergence variational loss.

tfp.vi.monte_carlo_variational_loss(
    target_log_prob_fn,
    surrogate_posterior,
    sample_size=1,
    discrepancy_fn=tfp.vi.kl_reverse,
    use_reparametrization=None,
    seed=None,
    name=None
)

Variational losses measure the divergence between an unnormalized target distribution p (provided via target_log_prob_fn) and a surrogate distribution q (provided as surrogate_posterior). When the target distribution is an unnormalized posterior from conditioning a model on data, minimizing the loss with respect to the parameters of surrogate_posterior performs approximate posterior inference.

This function defines divergences of the form E_q[discrepancy_fn(log p(z) - log q(z))], sometimes known as f-divergences [1, 2]. In the special case discrepancy_fn(logu) == -logu (the default tfp.vi.kl_reverse), this is the reverse Kullback-Leibler divergence KL[q||p], whose negation applied to an unnormalized p is the widely used evidence lower bound (ELBO) [3]. Other cases of interest available under tfp.vi include the forward KL[p||q] (given by tfp.vi.kl_forward(logu) == exp(logu) * logu), total variation distance, Amari alpha-divergences, and more.
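As a plain-Python sketch (not the library's implementation), the two discrepancy functions named above can be written directly from their stated log-space formulas. The names kl_reverse_logspace and kl_forward_logspace are hypothetical stand-ins that operate on Python floats rather than Tensors.

```python
import math

def kl_reverse_logspace(logu):
    """f(u) = -log(u), written as a function of logu = log(u)."""
    return -logu

def kl_forward_logspace(logu):
    """f(u) = u * log(u), written as a function of logu = log(u)."""
    return math.exp(logu) * logu

# Both agree with the corresponding f(u) evaluated at u = exp(logu).
u = 2.5
logu = math.log(u)
reverse_value = kl_reverse_logspace(logu)  # same as -log(u)
forward_value = kl_forward_logspace(logu)  # same as u * log(u)
```

Working in log-space means the caller only ever supplies log p(z) - log q(z), which avoids forming the ratio p(z) / q(z) explicitly.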

Args:

  • target_log_prob_fn: Python callable that takes a set of Tensor arguments and returns a Tensor log-density. Given q_sample = surrogate_posterior.sample(sample_size), this will be called as target_log_prob_fn(*q_sample) if q_sample is a list or a tuple, target_log_prob_fn(**q_sample) if q_sample is a dictionary, or target_log_prob_fn(q_sample) if q_sample is a Tensor. It should support batched evaluation, i.e., should return a result of shape [sample_size].
  • surrogate_posterior: A tfp.distributions.Distribution instance defining a variational posterior (could be a tfd.JointDistribution). Crucially, the distribution's log_prob and (if reparameterizeable) sample methods must directly invoke all ops that generate gradients to the underlying variables. One way to ensure this is to use tfp.util.DeferredTensor to represent any parameters defined as transformations of unconstrained variables, so that the transformations execute at runtime instead of at distribution creation.
  • sample_size: Integer scalar number of Monte Carlo samples used to approximate the variational divergence. Larger values may stabilize the optimization, but at higher cost per step in time and memory. Default value: 1.
  • discrepancy_fn: Python callable representing a Csiszar f function in log-space. That is, discrepancy_fn(log(u)) = f(u), where f is convex in u. Default value: tfp.vi.kl_reverse.
  • use_reparametrization: Python bool. When None (the default), automatically set to: surrogate_posterior.reparameterization_type == tfd.FULLY_REPARAMETERIZED. When True, uses the standard Monte-Carlo average. When False, uses the score-gradient trick (see "Tricks: Reparameterization and Score-Gradient" below for details); in that case, consider using csiszar_vimco.
  • seed: Python int seed for surrogate_posterior.sample.
  • name: Python str name prefixed to Ops created by this function.
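The calling convention described for target_log_prob_fn can be sketched in plain Python. The helper call_target_log_prob and the toy log-densities below are hypothetical and only illustrate the list/tuple, dictionary, and single-value cases; the library's own dispatch code may differ.

```python
def call_target_log_prob(target_log_prob_fn, q_sample):
    """Hypothetical helper mirroring the dispatch rule described above."""
    if isinstance(q_sample, (list, tuple)):
        return target_log_prob_fn(*q_sample)   # positional arguments
    if isinstance(q_sample, dict):
        return target_log_prob_fn(**q_sample)  # keyword arguments
    return target_log_prob_fn(q_sample)        # single value

# Toy unnormalized log-densities on plain floats (stand-ins for Tensors).
def toy_log_prob_pair(a, b):
    return -0.5 * (a * a + b * b)

def toy_log_prob_single(z):
    return -0.5 * z * z

from_list = call_target_log_prob(toy_log_prob_pair, [1.0, 2.0])
from_dict = call_target_log_prob(toy_log_prob_pair, {"a": 1.0, "b": 2.0})
from_value = call_target_log_prob(toy_log_prob_single, 2.0)
```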

Returns:

  • monte_carlo_variational_loss: float-like Tensor Monte Carlo approximation of the Csiszar f-Divergence.

Raises:

  • ValueError: if surrogate_posterior is not a reparameterized distribution and use_reparametrization = True. A distribution is said to be "reparameterized" when its samples are generated by transforming the samples of another distribution that does not depend on the first distribution's parameters. This property ensures the gradient with respect to parameters is valid.
  • TypeError: if target_log_prob_fn is not a Python callable.

Csiszar f-divergences

A Csiszar function f is a convex function from R^+ (the positive reals) to R. The Csiszar f-Divergence is given by:

D_f[p(X), q(X)] := E_{q(X)}[ f( p(X) / q(X) ) ]
                ~= m**-1 sum_j^m f( p(x_j) / q(x_j) ),
                           where x_j ~iid q(X)

For example, f = lambda u: -log(u) recovers KL[q||p], while f = lambda u: u * log(u) recovers the forward KL[p||q]. These and other functions are available in tfp.vi.
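The Monte Carlo approximation above can be checked on a case with a known answer. The sketch below uses only the Python standard library and assumes q = Normal(0, 1) and p = Normal(1, 1), for which both KL[q||p] and KL[p||q] equal 0.5. The function names are hypothetical, and this is not how the library computes the loss.

```python
import math
import random

def normal_log_prob(x, loc, scale):
    """Log-density of Normal(loc, scale) at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)

def monte_carlo_f_divergence(f_logspace, log_p, log_q, sample_q, num_samples, rng):
    """Average f(p(x_j) / q(x_j)) over x_j ~ q, with f given in log-space."""
    total = 0.0
    for _ in range(num_samples):
        x = sample_q(rng)
        total += f_logspace(log_p(x) - log_q(x))
    return total / num_samples

def log_p(x):
    return normal_log_prob(x, 1.0, 1.0)

def log_q(x):
    return normal_log_prob(x, 0.0, 1.0)

def sample_q(rng):
    return rng.gauss(0.0, 1.0)

rng = random.Random(0)  # fixed seed so the run is repeatable
num_samples = 100000

# f(u) = -log(u) estimates KL[q||p]; f(u) = u * log(u) estimates KL[p||q].
reverse_kl_estimate = monte_carlo_f_divergence(
    lambda logu: -logu, log_p, log_q, sample_q, num_samples, rng)
forward_kl_estimate = monte_carlo_f_divergence(
    lambda logu: math.exp(logu) * logu, log_p, log_q, sample_q, num_samples, rng)
```

In this setup the forward-KL estimate is noisier than the reverse-KL estimate, because the weight p(x) / q(x) varies widely under samples from q.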

Tricks: Reparameterization and Score-Gradient

When q is "reparameterized", i.e., a diffeomorphic transformation of a parameterless distribution (e.g., Normal(Y; m, s) <=> Y = sX + m, X ~ Normal(0,1)), we can swap gradient and expectation, i.e., grad[Avg{ s_i : i=1...n }] = Avg{ grad[s_i] : i=1...n } where S_n=Avg{s_i} and s_i = f(x_i), x_i ~iid q(X).

However, if q is not reparameterized, TensorFlow's gradient will be incorrect since the chain-rule stops at samples of unreparameterized distributions. In this circumstance using the Score-Gradient trick results in an unbiased gradient, i.e.,

grad[ E_q[f(X)] ]
= grad[ int dx q(x) f(x) ]
= int dx grad[ q(x) f(x) ]
= int dx [ q'(x) f(x) + q(x) f'(x) ]
= int dx q(x) [q'(x) / q(x) f(x) + f'(x) ]
= int dx q(x) grad[ f(x) q(x) / stop_grad[q(x)] ]
= E_q[ grad[ f(x) q(x) / stop_grad[q(x)] ] ]

When q.reparameterization_type == tfd.FULLY_REPARAMETERIZED, it is usually preferable to set use_reparametrization = True.
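The two estimators can be compared on a toy problem. The sketch below uses only the standard library and is not the library's implementation. It assumes q = Normal(m, 1) and f(x) = x**2, so E_q[f(X)] = m**2 + 1 and the exact gradient with respect to m is 2 * m. Because this f does not depend on m, the f'(x) term in the derivation vanishes and the score-gradient estimator reduces to f(x) * d/dm log q(x) = f(x) * (x - m). All function names are hypothetical.

```python
import random

def f(x):
    return x * x

def f_prime(x):
    return 2.0 * x

def reparameterized_gradient(m, num_samples, rng):
    """Average of d/dm f(m + eps), with eps ~ Normal(0, 1) independent of m."""
    total = 0.0
    for _ in range(num_samples):
        eps = rng.gauss(0.0, 1.0)
        total += f_prime(m + eps)  # chain rule: d(m + eps)/dm = 1
    return total / num_samples

def score_gradient(m, num_samples, rng):
    """Average of f(x) * d/dm log q(x; m), with x ~ Normal(m, 1)."""
    total = 0.0
    for _ in range(num_samples):
        x = rng.gauss(m, 1.0)
        total += f(x) * (x - m)  # d/dm log Normal(x; m, 1) = x - m
    return total / num_samples

rng = random.Random(0)  # fixed seed so the run is repeatable
m = 1.5
exact_gradient = 2.0 * m
reparam_estimate = reparameterized_gradient(m, 200000, rng)
score_estimate = score_gradient(m, 200000, rng)
```

Both estimates are unbiased for the same gradient. In this toy case the score-gradient estimate has noticeably higher variance, which is consistent with the advice to prefer reparameterization when it is available.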

Example Application:

The Csiszar f-Divergence is a useful framework for variational inference. Observe that:

f(p(x)) =  f( E_{q(Z | x)}[ p(x, Z) / q(Z | x) ] )
        <= E_{q(Z | x)}[ f( p(x, Z) / q(Z | x) ) ]
        := D_f[p(x, Z), q(Z | x)]

The inequality follows from the fact that the "perspective" of f, i.e., (s, t) |-> t f(s / t), is convex in (s, t) when s/t is in domain(f) and t is a positive real. Since the above framework includes the popular Evidence Lower BOund (ELBO) as a special case, i.e., f(u) = -log(u), we call this framework "Evidence Divergence Bound Optimization" (EDBO).
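The bound can be verified exactly on a small discrete example. The sketch below assumes a latent Z taking three values, with made-up joint probabilities p(x, z) at one fixed x, and uses f(u) = -log(u), so the left-hand side is the negative log evidence. All numbers and names are hypothetical.

```python
import math

# Hypothetical joint probabilities p(x, z) at a fixed x, for z in {0, 1, 2}.
joint = [0.10, 0.25, 0.05]
evidence = sum(joint)                        # p(x), summing out z
posterior = [pz / evidence for pz in joint]  # p(z | x)

def f(u):
    return -math.log(u)

def divergence_bound(q):
    """E_q[ f( p(x, Z) / q(Z | x) ) ] for a discrete q over z."""
    return sum(qz * f(pz / qz) for pz, qz in zip(joint, q))

uniform_q = [1.0 / 3.0] * 3

lhs = f(evidence)                              # f(p(x)) = -log p(x)
bound_uniform = divergence_bound(uniform_q)    # strictly above lhs
bound_posterior = divergence_bound(posterior)  # matches lhs up to rounding
```

The bound is loose for the uniform q and becomes tight when q equals the exact posterior, which is why minimizing it over q performs approximate posterior inference.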

References:

[2]: Ali, Syed Mumtaz, and Samuel D. Silvey. "A general class of coefficients of divergence of one distribution from another." Journal of the Royal Statistical Society: Series B (Methodological) 28.1 (1966): 131-142.

[3]: Bishop, Christopher M. Pattern Recognition and Machine Learning. Springer, 2006.