Builds a structured surrogate posterior inspired by conjugate updating.
```python
tfp.experimental.vi.build_asvi_surrogate_posterior_stateless(
    prior,
    mean_field=False,
    initial_prior_weight=0.5,
    prior_substitution_rules=tfp.experimental.vi.ASVI_DEFAULT_PRIOR_SUBSTITUTION_RULES,
    surrogate_rules=tfp.experimental.vi.ASVI_DEFAULT_SURROGATE_RULES,
    name=None
)
```
ASVI, or Automatic Structured Variational Inference, was proposed by Ambrogioni et al. (2020) [1] as a method of automatically constructing a surrogate posterior with the same structure as the prior. It does this by reparameterizing the variational family of the surrogate posterior by structuring each parameter according to the equation
`prior_weight * prior_parameter + (1 - prior_weight) * mean_field_parameter`

In this equation, `prior_parameter` is a vector of prior parameters and `mean_field_parameter` is a vector of trainable parameters with the same domain as `prior_parameter`. `prior_weight` is a vector of learnable parameters with `0. <= prior_weight <= 1.`. When `prior_weight = 0`, the surrogate posterior is a mean-field surrogate, and when `prior_weight = 1`, the surrogate posterior is the prior. This convex combination, inspired by conjugacy in exponential families, thus allows the surrogate posterior to balance between the structure of the prior and the structure of a mean-field approximation.
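For intuition, here is a minimal sketch of that convex combination for the location parameter of a single `Normal` prior. The variable names are illustrative, not the library's internal ASVI parameterization:

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

prior = tfd.Normal(loc=2., scale=1.)

# Illustrative trainable ingredients: a mean-field location and a convex
# weight (kept in [0, 1] in practice, e.g. via a sigmoid transform).
mean_field_loc = tf.Variable(0.)
prior_weight = tf.Variable(0.5)

# The ASVI-style parameterization of the surrogate's location:
surrogate_loc = prior_weight * prior.loc + (1. - prior_weight) * mean_field_loc
surrogate = tfd.Normal(loc=surrogate_loc, scale=prior.scale)
# prior_weight = 1 recovers the prior's loc; prior_weight = 0 gives the
# mean-field loc.
```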
Returns | |
---|---|
`init_fn` | Python callable with signature `initial_parameters = init_fn(seed)`.
`apply_fn` | Python callable with signature `instance = apply_fn(*parameters)`.
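For example, given a `prior` distribution such as the one constructed in the examples below, the returned pair can be used as follows (a sketch based on the signatures above; the expected `seed` type depends on the backend, e.g. an integer seed under TensorFlow or a `jax.random.PRNGKey` under JAX):

```python
init_fn, apply_fn = tfp.experimental.vi.build_asvi_surrogate_posterior_stateless(prior)

initial_parameters = init_fn(seed=42)                # draw initial surrogate parameters
surrogate_posterior = apply_fn(*initial_parameters)  # materialize a distribution instance
```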
Examples
Consider a Brownian motion model expressed as a JointDistribution:
```python
import tensorflow_probability as tfp

tfd = tfp.distributions

prior_loc = 0.
innovation_noise = .1

def model_fn():
  new = yield tfd.Normal(loc=prior_loc, scale=innovation_noise)
  for i in range(4):
    new = yield tfd.Normal(loc=new, scale=innovation_noise)

prior = tfd.JointDistributionCoroutineAutoBatched(model_fn)
```
Let's use variational inference to approximate the posterior. We'll build a surrogate posterior distribution by feeding in the prior distribution. (This example uses the stateful sibling `build_asvi_surrogate_posterior`; a sketch of the stateless workflow follows at the end of this section.)

```python
surrogate_posterior = tfp.experimental.vi.build_asvi_surrogate_posterior(prior)
```
This creates a trainable joint distribution, defined by variables in `surrogate_posterior.trainable_variables`. We use `fit_surrogate_posterior` to fit this distribution by minimizing a divergence to the true posterior.
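The fitting step below assumes a `target_log_prob_fn` computing the unnormalized log density of the true posterior. As a self-contained stand-in, one can target the prior itself, in which case the optimum drives `prior_weight` toward 1:

```python
# Stand-in target: the prior's own log density. In a real problem this would
# be the prior log prob plus an observation log likelihood.
target_log_prob_fn = lambda *values: prior.log_prob(values)
```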
```python
import tensorflow as tf

losses = tfp.vi.fit_surrogate_posterior(
    target_log_prob_fn,
    surrogate_posterior=surrogate_posterior,
    num_steps=100,
    optimizer=tf.optimizers.Adam(0.1),
    sample_size=10)
```
```python
# After optimization, samples from the surrogate will approximate
# samples from the true posterior.
samples = surrogate_posterior.sample(100)
posterior_mean = [tf.reduce_mean(x) for x in samples]
posterior_std = [tf.math.reduce_std(x) for x in samples]
```
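Since this page documents the stateless variant, here is a corresponding sketch using the `init_fn`/`apply_fn` pair with `tfp.vi.fit_surrogate_posterior_stateless`. It assumes the JAX substrate, where seeds are PRNGKeys and the optimizer is a pure-functional one such as `optax.adam`; treat the exact argument names as assumptions to check against your installed version:

```python
import jax
import optax  # assumed: stateless fitting expects a pure-functional optimizer

init_fn, apply_fn = tfp.experimental.vi.build_asvi_surrogate_posterior_stateless(prior)

# Optimize the surrogate parameters functionally, without tf.Variables.
optimized_parameters, losses = tfp.vi.fit_surrogate_posterior_stateless(
    target_log_prob_fn,
    build_surrogate_posterior_fn=apply_fn,
    initial_parameters=init_fn(seed=jax.random.PRNGKey(42)),
    optimizer=optax.adam(0.1),
    num_steps=100,
    sample_size=10,
    seed=jax.random.PRNGKey(0))

# Materialize the fitted surrogate as a distribution instance.
surrogate_posterior = apply_fn(*optimized_parameters)
```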
References
[1]: Luca Ambrogioni, Kate Lin, Emily Fertig, Sharad Vikram, Max Hinne, Dave Moore, and Marcel van Gerven. Automatic structured variational inference. arXiv preprint arXiv:2002.00643, 2020. https://arxiv.org/abs/2002.00643