Fit a surrogate posterior to a target (unnormalized) log density. (deprecated arguments)
tfp.vi.fit_surrogate_posterior(
    target_log_prob_fn,
    surrogate_posterior,
    optimizer,
    num_steps,
    convergence_criterion=None,
    trace_fn=_trace_loss,
    variational_loss_fn=None,
    discrepancy_fn=tfp.vi.kl_reverse,
    sample_size=1,
    importance_sample_size=1,
    trainable_variables=None,
    jit_compile=None,
    seed=None,
    name='fit_surrogate_posterior'
)
The default behavior constructs and minimizes the negative variational evidence lower bound (ELBO), given by
q_samples = surrogate_posterior.sample(num_draws)
elbo_loss = -tf.reduce_mean(
    target_log_prob_fn(q_samples) - surrogate_posterior.log_prob(q_samples))
This corresponds to minimizing the 'reverse' Kullback-Leibler divergence KL[q||p] between the variational distribution and the unnormalized target_log_prob_fn, and defines a lower bound on the marginal log likelihood, log p(x) >= -elbo_loss.
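The bound is the standard Jensen's-inequality argument: writing the joint density of latents and data as p(z, x),

log p(x) = log E_{z~q}[ p(z, x) / q(z) ] >= E_{z~q}[ log p(z, x) - log q(z) ] = -elbo_loss,

with equality exactly when q matches the true posterior p(z | x).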
More generally, this function supports fitting variational distributions that minimize any Csiszar f-divergence.
target_log_prob_fn: Python callable that takes a set of Tensor arguments and returns a Tensor log-density of the (unnormalized) target distribution.
optimizer: Optimizer instance to use. This may be a TF1-style tf.train.Optimizer, a TF2-style tf.optimizers.Optimizer, or any Python object implementing optimizer.apply_gradients(grads_and_vars).
convergence_criterion: Optional instance of tfp.optimizer.convergence_criteria.ConvergenceCriterion used to decide when to stop early; if None, optimization runs for the full num_steps.
trace_fn: Python callable invoked at each optimization step to select which quantities are traced and returned; by default, only the loss is traced.
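As a rough sketch of the optional convergence_criterion argument (assuming the LossNotDecreasing criterion from tfp.optimizer.convergence_criteria; the toy target and surrogate names here are illustrative), optimization can be stopped early once the loss plateaus:

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

# Toy target and trainable surrogate, just to exercise the arguments.
target_log_prob = tfd.Normal(loc=2., scale=0.5).log_prob
q = tfp.experimental.util.make_trainable(tfd.Normal, name='q')

losses = tfp.vi.fit_surrogate_posterior(
    target_log_prob,
    surrogate_posterior=q,
    optimizer=tf.optimizers.Adam(learning_rate=0.1),
    num_steps=1000,
    # Stop early if the loss has not decreased (within `atol`) over a
    # trailing window of steps.
    convergence_criterion=(
        tfp.optimizer.convergence_criteria.LossNotDecreasing(atol=0.01)))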
Normal-Normal model. We'll first consider a simple model z ~ N(0, 1), x ~ N(z, 1), where we suppose we are interested in the posterior p(z | x=5):
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability import distributions as tfd

def log_prob(z, x):
  return tfd.Normal(0., 1.).log_prob(z) + tfd.Normal(z, 1.).log_prob(x)

conditioned_log_prob = lambda z: log_prob(z, x=5.)
The posterior is itself normal by conjugacy, and can be computed in closed form: N(loc=5/2., scale=1/sqrt(2)). But suppose we don't want to bother doing the math: we can use variational inference instead!
q_z = tfp.experimental.util.make_trainable(tfd.Normal, name='q_z')
losses = tfp.vi.fit_surrogate_posterior(
    conditioned_log_prob,
    surrogate_posterior=q_z,
    optimizer=tf.optimizers.Adam(learning_rate=0.1),
    num_steps=100)
print(q_z.mean(), q_z.stddev())  # => approximately [2.5, 1/sqrt(2)]
Note that we ensure positive scale by using a softplus transformation of the underlying variable, invoked via tfp.util.TransformedVariable. Deferring the transformation causes it to be applied upon evaluation of the distribution's methods, creating a gradient to the underlying variable. If we had simply specified the scale as an eagerly computed softplus of a tf.Variable, without the TransformedVariable, fitting would fail because calls to q.sample would never access the underlying variable. In general, transformations of trainable parameters must be deferred to runtime, either using tfp.util.DeferredTensor or by the callable mechanisms available in joint distribution classes (demonstrated below).
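A minimal sketch of the distinction (the variable and distribution names here are illustrative, not from the original example):

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
tfb = tfp.bijectors

# Deferred: the softplus is applied each time the distribution is evaluated,
# so gradients flow back to the underlying unconstrained variable.
q_deferred = tfd.Normal(
    loc=tf.Variable(0., name='loc'),
    scale=tfp.util.TransformedVariable(1., bijector=tfb.Softplus(), name='scale'))

# Not deferred: the softplus is evaluated once, up front. Calls to
# q_eager.sample() never read `raw_scale`, so it receives no gradient
# during fitting.
raw_scale = tf.Variable(0., name='raw_scale')
q_eager = tfd.Normal(loc=tf.Variable(0., name='loc'),
                     scale=tf.nn.softplus(raw_scale))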
Custom loss function. Suppose we prefer to fit the same model using the forward KL divergence KL[p||q]. We can pass a custom discrepancy function:
losses = tfp.vi.fit_surrogate_posterior(
    conditioned_log_prob,
    surrogate_posterior=q_z,
    optimizer=tf.optimizers.Adam(learning_rate=0.1),
    num_steps=100,
    discrepancy_fn=tfp.vi.kl_forward)
Note that in practice this may have substantially higher-variance gradients than the reverse KL.
Importance weighting. A surrogate posterior may be corrected by interpreting it as a proposal for an importance sampler. That is, one can use weighted samples from the surrogate to estimate expectations under the true posterior:
zs, q_log_prob = surrogate_posterior.experimental_sample_and_log_prob(
    num_samples)

# Naive expectation under the surrogate posterior.
expected_x = tf.reduce_mean(f(zs), axis=0)

# Importance-weighted estimate of the expectation under the true posterior.
self_normalized_log_weights = tf.nn.log_softmax(
    target_log_prob_fn(zs) - q_log_prob)
expected_x = tf.reduce_sum(
    tf.exp(self_normalized_log_weights) * f(zs),
    axis=0)
Any distribution may be used as a proposal, but it is often natural to consider surrogates that were themselves fit by optimizing an importance-weighted variational objective [Burda et al., 2016], which directly optimizes the surrogate's effectiveness as a proposal distribution. This may be specified by passing importance_sample_size > 1. The importance-weighted objective may favor different characteristics than the original objective. For example, effective proposals are generally overdispersed, whereas a surrogate optimizing reverse KL would otherwise tend to be underdispersed.
Although importance sampling is guaranteed to tighten the variational bound, some research has found that this does not necessarily improve the quality of deep generative models, because it also introduces gradient noise that can lead to a weaker training signal [Rainforth et al., 2018]. As always, evaluation is important to choose the approach that works best for a particular task.
When using an importance-weighted loss to fit a surrogate, it is also recommended to apply importance sampling when computing expectations under that surrogate.
# Fit `q` with an importance-weighted variational loss.
losses = tfp.vi.fit_surrogate_posterior(
    conditioned_log_prob,
    surrogate_posterior=q_z,
    importance_sample_size=10,
    optimizer=tf.optimizers.Adam(learning_rate=0.1),
    num_steps=200)

# Estimate posterior statistics with importance sampling.
zs, q_log_prob = q_z.experimental_sample_and_log_prob(1000)
self_normalized_log_weights = tf.nn.log_softmax(
    conditioned_log_prob(zs) - q_log_prob)
posterior_mean = tf.reduce_sum(
    tf.exp(self_normalized_log_weights) * zs,
    axis=0)
posterior_variance = tf.reduce_sum(
    tf.exp(self_normalized_log_weights) * (zs - posterior_mean)**2,
    axis=0)
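A standard diagnostic for how well the fitted surrogate works as a proposal (not part of the API above, just common practice) is the effective sample size of the self-normalized importance weights; continuing from the snippet above:

# Effective sample size of the 1000 importance weights computed above. Values
# near 1000 indicate the surrogate is an excellent proposal for this target;
# values near 1 indicate the estimate is dominated by a single sample.
ess = 1. / tf.reduce_sum(tf.exp(2 * self_normalized_log_weights))
print(ess)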
Inhomogeneous Poisson Process. For a more interesting example, let's consider a model with multiple latent variables as well as trainable parameters in the model itself. Given observed counts y from spatial locations X, consider an inhomogeneous Poisson process model

log_rates = GaussianProcess(index_points=X); y = Poisson(exp(log_rates))
in which the latent (log) rates are spatially correlated following a Gaussian
process. We'll fit a variational model to the latent rates while also
optimizing the GP kernel hyperparameters (largely for illustration; in
practice we might prefer to 'be Bayesian' about these parameters and include
them as latents in our model and variational posterior). First we define
the model, including trainable variables:
import numpy as np

# Toy 1D data.
index_points = np.array([-10., -7.2, -4., -0.1, 0.1, 4., 6.2, 9.]).reshape(
    [-1, 1]).astype(np.float32)
observed_counts = np.array(
    [100, 90, 60, 13, 18, 37, 55, 42]).astype(np.float32)

# Trainable GP hyperparameters.
kernel_log_amplitude = tf.Variable(0., name='kernel_log_amplitude')
kernel_log_lengthscale = tf.Variable(0., name='kernel_log_lengthscale')
observation_noise_log_scale = tf.Variable(
    0., name='observation_noise_log_scale')

# Generative model.
Root = tfd.JointDistributionCoroutine.Root
def model_fn():
  kernel = tfp.math.psd_kernels.ExponentiatedQuadratic(
      amplitude=tf.exp(kernel_log_amplitude),
      length_scale=tf.exp(kernel_log_lengthscale))
  latent_log_rates = yield Root(tfd.GaussianProcess(
      kernel,
      index_points=index_points,
      observation_noise_variance=tf.exp(observation_noise_log_scale),
      name='latent_log_rates'))
  y = yield tfd.Independent(
      tfd.Poisson(log_rate=latent_log_rates, name='y'),
      reinterpreted_batch_ndims=1)
model = tfd.JointDistributionCoroutine(model_fn)
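As a quick sanity check (a sketch, not part of the original example), we can draw from the prior and confirm the joint density evaluates on the observed counts:

# Draw a joint prior sample (latent log rates and counts) and score it.
prior_sample = model.sample(seed=42)
print(model.log_prob(prior_sample))

# Score the actual observations under an arbitrary latent configuration
# (all log rates equal to zero).
print(model.log_prob((tf.zeros_like(observed_counts), observed_counts)))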
Next we define a variational distribution. We incorporate the observations directly into the variational model using the 'trick' of representing them by a deterministic distribution (observe that the true posterior on an observed value is in fact a point mass at the observed value).
logit_locs = tf.Variable(tf.zeros(observed_counts.shape), name='logit_locs')
logit_softplus_scales = tf.Variable(tf.ones(observed_counts.shape) * -4,
                                    name='logit_softplus_scales')
def variational_model_fn():
  latent_rates = yield Root(tfd.Independent(
      tfd.Normal(loc=logit_locs, scale=tf.nn.softplus(logit_softplus_scales)),
      reinterpreted_batch_ndims=1))
  y = yield tfd.VectorDeterministic(observed_counts)
q = tfd.JointDistributionCoroutine(variational_model_fn)
Note that here we could apply transforms to variables without using DeferredTensor because the JointDistributionCoroutine argument is a function, i.e., executed "on demand." (The same is true when distribution-making functions are supplied to JointDistributionNamed.) That is, as long as variables are transformed within the callable, they will appear on the gradient tape when the distribution's methods, e.g. q.sample(), are invoked.
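A quick way to verify this (an illustrative check, not from the original example):

# Confirm that gradients flow to the underlying variables when the
# surrogate's methods are evaluated.
with tf.GradientTape() as tape:
  loss = -q.log_prob(q.sample(seed=42))
grads = tape.gradient(loss, [logit_locs, logit_softplus_scales])
assert all(g is not None for g in grads)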
Finally, we fit the variational posterior and model variables jointly: by not explicitly specifying trainable_variables, the optimization will automatically include all variables accessed. We'll use a custom trace_fn to see how the kernel amplitude and a set of sampled latent rates with fixed seed evolve during the course of the optimization:
losses, log_amplitude_path, sample_path = tfp.vi.fit_surrogate_posterior(
    target_log_prob_fn=lambda *args: model.log_prob(args),
    surrogate_posterior=q,
    optimizer=tf.optimizers.Adam(learning_rate=0.1),
    sample_size=1,
    num_steps=500,
    trace_fn=lambda loss, grads, vars: (loss,
                                        kernel_log_amplitude,
                                        q.sample(5, seed=42)))
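After fitting, the surrogate can be used like any other distribution; for example (an illustrative follow-up, not from the original docstring), to summarize the inferred rates and the learned kernel amplitude:

# Approximate posterior mean rate at each index point, estimated from
# samples of the fitted surrogate.
latent_log_rate_samples, _ = q.sample(1000, seed=42)
print(tf.reduce_mean(tf.exp(latent_log_rate_samples), axis=0))

# Fitted GP kernel amplitude (the variable was trained jointly with `q`).
print(tf.exp(kernel_log_amplitude))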
Christopher M. Bishop. Pattern Recognition and Machine Learning. Springer, 2006.
Yuri Burda, Roger Grosse, and Ruslan Salakhutdinov. Importance Weighted Autoencoders. In International Conference on Learning Representations, 2016. https://arxiv.org/abs/1509.00519
Tom Rainforth, Adam R. Kosiorek, Tuan Anh Le, Chris J. Maddison, Maximilian Igl, Frank Wood, and Yee Whye Teh. Tighter Variational Bounds are Not Necessarily Better. In International Conference on Machine Learning (ICML), 2018. https://arxiv.org/abs/1802.04537