tfp.experimental.distributions.ImportanceResample

Models the distribution of finitely many importance-reweighted samples.

Inherits From: Distribution

This wrapper adapts a proposal distribution towards a target density using importance sampling. Given a proposal q, a target density p (which may be unnormalized), and an integer importance_sample_size, it models the result of the following sampling process:

  1. Draw importance_sample_size samples z[k] ~ q from the proposal.
  2. Compute an importance weight w[k] = p(z[k]) / q(z[k]) for each sample.
  3. Return a sample z[k*] selected with probability proportional to the importance weights, i.e., with k* ~ Categorical(probs=w/sum(w)).

In the limit where importance_sample_size -> inf, the result z[k*] of this procedure would be distributed according to the target density p. On the other hand, if importance_sample_size == 1, then the reweighting has no effect and the result z[k*] is simply a sample from q. Finite values of importance_sample_size describe distributions that are intermediate between p and q.
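
This procedure can also be written out directly. Below is a minimal sketch, assuming a scalar proposal and a target_log_prob_fn that accepts a batch of samples; the helper name importance_resample_one is purely illustrative and not part of the TFP API.

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

def importance_resample_one(proposal, target_log_prob_fn, importance_sample_size):
  # 1. Draw `importance_sample_size` proposals z[k] ~ q.
  z = proposal.sample(importance_sample_size)
  # 2. Compute log importance weights: log w[k] = log p(z[k]) - log q(z[k]).
  log_weights = target_log_prob_fn(z) - proposal.log_prob(z)
  # 3. Select an index k* with probability proportional to w[k]; return z[k*].
  k_star = tfd.Categorical(logits=log_weights).sample()
  return tf.gather(z, k_star)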

This distribution may also be understood as an explicit representation of the surrogate posterior that is implicitly assumed by importance-weighted variational objectives. [1, 2]
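
Concretely, for a model with observed data x and latent z, the importance-weighted bound of [1] that this surrogate makes explicit can be sketched (in the notation of [1]) as

\mathcal{L}_K = \mathbb{E}_{z_1, \ldots, z_K \sim q}\left[\log \frac{1}{K} \sum_{k=1}^{K} \frac{p(x, z_k)}{q(z_k)}\right] \le \log p(x),

where K plays the role of importance_sample_size: at K = 1 this is the standard ELBO, and as K grows the bound tightens towards log p(x) [1].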

Examples

This distribution can be used directly for posterior inference via importance sampling:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfed = tfp.experimental.distributions

def target_log_prob_fn(x):
  prior = tfd.Normal(loc=0., scale=1.).log_prob(x)
  # Multimodal likelihood.
  likelihood = tf.reduce_logsumexp(
    tfd.Normal(loc=x, scale=0.1).log_prob([-1., 1.]))
  return prior + likelihood

# Use importance sampling to infer an approximate posterior.
approximate_posterior = tfed.ImportanceResample(
  proposal_distribution=tfd.Normal(loc=0., scale=2.),
  target_log_prob_fn=target_log_prob_fn,
  importance_sample_size=100)

We can estimate posterior expectations directly using an importance-weighted sum of proposal samples:

# Directly compute expectations under the posterior via importance weights.
posterior_mean = approximate_posterior.self_normalized_expectation(
  lambda x: x, importance_sample_size=1000)
posterior_variance = approximate_posterior.self_normalized_expectation(
  lambda x: (x - posterior_mean)**2, importance_sample_size=1000)

Alternatively, the same expectations can be estimated from explicit (unweighted) samples. Note that sampling may be expensive, because it performs resampling internally: producing sample_size samples requires first proposing values of shape [sample_size, importance_sample_size] ([1000, 100] in the code below) and then resampling down to [sample_size], throwing most of the proposals away. For this reason, prefer calling self_normalized_expectation over naive sampling to compute expectations.

posterior_samples = approximate_posterior.sample(1000)
posterior_mean_inefficient = tf.reduce_mean(posterior_samples)
posterior_variance_inefficient = tf.math.reduce_variance(posterior_samples)

# Calling `self_normalized_expectation` allows for a much lower `sample_size`
# because it uses the full set of `importance_sample_size` proposal samples to
# approximate the expectation at each of the `sample_size` Monte Carlo
# evaluations. This is formalized in Eq. 9 of [3].
posterior_mean_efficient = approximate_posterior.self_normalized_expectation(
  lambda x: x, sample_size=10)
posterior_variance_efficient = (
  approximate_posterior.self_normalized_expectation(
    lambda x: (x - posterior_mean_efficient)**2, sample_size=10))

The posterior (log-)density cannot be computed directly, but may be stochastically approximated. The prob and log_prob methods accept arguments seed and sample_size to control the variance of the approximation.

# Plot the posterior density.
from matplotlib import pyplot as plt
xs = tf.linspace(-3., 3., 101)
probs = approximate_posterior.prob(xs, sample_size=10, seed=(42, 42))
plt.plot(xs, probs)
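
Repeated log_prob evaluations at the same point fluctuate less as sample_size grows, while the importance_sample_size fixed at construction controls which distribution is being approximated. A rough sketch of this (the evaluation point 0.5 and the seeds are arbitrary choices):

# Larger `sample_size` reduces the variance of the stochastic log-density
# approximation; it does not change the underlying distribution.
noisy_estimates = tf.stack(
  [approximate_posterior.log_prob(0.5, sample_size=1, seed=(i, 0))
   for i in range(20)])
smoother_estimates = tf.stack(
  [approximate_posterior.log_prob(0.5, sample_size=100, seed=(i, 0))
   for i in range(20)])
# Expect tf.math.reduce_std(smoother_estimates) < tf.math.reduce_std(noisy_estimates).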

Connections to importance-weighted variational inference

Optimizing an importance-weighted variational bound provides a natural approach to choosing a proposal distribution for importance sampling. Importance-weighted bounds are available directly in TFP via the importance_sample_size argument to tfp.vi.monte_carlo_variational_loss and tfp.vi.fit_surrogate_posterior. For example, we might improve on the example above by replacing the fixed proposal distribution with a learned proposal:

proposal_distribution = tfp.experimental.util.make_trainable(tfd.Normal)
importance_sample_size = 100
importance_weighted_losses = tfp.vi.fit_surrogate_posterior(
  target_log_prob_fn,
  surrogate_posterior=proposal_distribution,
  optimizer=tf.optimizers.Adam(0.1),
  num_steps=200,
  importance_sample_size=importance_sample_size)
approximate_posterior = tfed.ImportanceResample(
  proposal_distribution=proposal_distribution,
  target_log_prob_fn=target_log_prob_fn,
  importance_sample_size=importance_sample_size)

Note that although the importance-resampled approximate_posterior serves ultimately as the surrogate posterior, only the bare proposal distribution is passed as the surrogate_posterior argument to fit_surrogate_posterior. This is because the importance_sample_size argument tells fit_surrogate_posterior to compute an importance-weighted bound directly from the proposal distribution. Mathematically, it would be equivalent to omit the importance_sample_size argument and instead pass an ImportanceResample distribution as the surrogate posterior:

equivalent_but_less_efficient_losses = tfp.vi.fit_surrogate_posterior(
  target_log_prob_fn,
  surrogate_posterior=tfed.ImportanceResample(
    proposal_distribution=proposal_distribution,
    target_log_prob_fn=target_log_prob_fn,
    importance_sample_size=importance_sample_size),
  optimizer=tf.optimizers.Adam(0.1),
  num_steps=200)

but this approach is not recommended, because it performs redundant evaluations of the target_log_prob_fn compared to the direct bound shown above.
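
Once the proposal has been fit, the resulting approximate_posterior can be used for inference exactly as in the earlier example, e.g. (reusing the names defined above):

# Posterior expectations under the fitted, importance-resampled surrogate.
posterior_mean = approximate_posterior.self_normalized_expectation(
  lambda x: x, sample_size=10)
posterior_variance = approximate_posterior.self_normalized_expectation(
  lambda x: (x - posterior_mean)**2, sample_size=10)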

References

[1] Yuri Burda, Roger Grosse, and Ruslan Salakhutdinov. Importance Weighted Autoencoders. In International Conference on Learning Representations, 2016. https://arxiv.org/abs/1509.00519

[2] Chris Cremer, Quaid Morris, and David Duvenaud. Reinterpreting Importance-Weighted Autoencoders. In International Conference on Learning Representations, Workshop track, 2017. https://arxiv.org/abs/1704.02916

[3] Justin Domke and Daniel Sheldon. Importance Weighting and Variational Inference. In Neural Information Processing Systems (NIPS), 2018. https://arxiv.org/abs/1808.09034

Args

proposal_distribution: Instance of tfd.Distribution used to generate proposals. This may be a joint distribution.
target_log_prob_fn: Python callable representation of a (potentially unnormalized) target log-density. This should accept samples from the proposal, i.e., lp = target_log_prob_fn(proposal_distribution.sample()).
importance_sample_size: integer Tensor number of proposals used in the distribution of a single sample. Larger values better approximate the target distribution, at the cost of increased computation and memory usage.
sample_size: integer Tensor number of Monte Carlo samples used to reduce variance in stochastic methods such as log_prob, prob, and self_normalized_expectation. Note that increasing importance_sample_size leads to a more accurate approximation of the target distribution (reducing bias and variance), while increasing sample_size improves the precision of estimates under the intermediate distribution corresponding to a particular finite importance_sample_size (i.e., it reduces variance only and does not affect the sampling distribution). If unsure, it's generally safe to leave sample_size at its default value of 1 and focus on increasing importance_sample_size instead. Default value: 1.
stochastic_approximation_seed: optional PRNG key used in stochastic approximations for methods such as log_prob, prob, and self_normalized_expectation. This seed does not affect sampling. Default value: None.
validate_args: Python bool. Whether to validate input with asserts. If validate_args is False, and the inputs are invalid, correct behavior is not guaranteed. Default value: False.
name: Python str name for this distribution. If None, defaults to 'importance_resample'. Default value: None.

Attributes

allow_nan_stats: Python bool describing behavior when a stat is undefined.

Stats return +/- infinity when it makes sense. E.g., the variance of a Cauchy distribution is infinity. However, sometimes the statistic is undefined, e.g., if a distribution's pdf does not achieve a maximum within the support of the distribution, the mode is undefined. If the mean is undefined, then by definition the variance is undefined. E.g., the mean for Student's T for df = 1 is undefined (no clear way to say it is either + or - infinity), so the variance = E[(X - mean)**2] is also undefined.

batch_shape: Shape of a single sample from a single event index as a TensorShape.

May be partially defined or unknown.

The batch dimensions are indexes into independent, non-identical parameterizations of this distribution.

dtype: The DType of Tensors handled by this Distribution.
event_shape: Shape of a single sample from a single batch as a TensorShape.

May be partially defined or unknown.

experimental_shard_axis_names: The list or structure of lists of active shard axis names.
importance_sample_size

name: Name prepended to all ops created by this Distribution.
name_scope: Returns a tf.name_scope instance for this class.
non_trainable_variables: Sequence of non-trainable variables owned by this module and its submodules.

parameters: Dictionary of parameters used to instantiate this Distribution.
proposal_distribution

reparameterization_type: Describes how samples from the distribution are reparameterized.

Currently this is one of the static instances tfd.FULLY_REPARAMETERIZED or tfd.NOT_REPARAMETERIZED.

sample_size

stochastic_approximation_seed

submodules: Sequence of all sub-modules.

Submodules are modules which are properties of this module, or found as properties of modules which are properties of this module (and so on).

a = tf.Module()
b = tf.Module()
c = tf.Module()
a.b = b
b.c = c
assert list(a.submodules) == [b, c]
assert list(b.submodules) == [c]
assert list(c.submodules) == []

target_log_prob_fn

trainable_variables: Sequence of trainable variables owned by this module and its submodules.

validate_args: Python bool indicating possibly expensive checks are enabled.
variables: Sequence of variables owned by this module and its submodules.

Methods