tfp.substrates.jax.mcmc.sample_annealed_importance_chain

Runs annealed importance sampling (AIS) to estimate normalizing constants.

This function uses an MCMC transition operator (e.g., Hamiltonian Monte Carlo) to sample from a series of distributions that slowly interpolates between an initial 'proposal' distribution:

exp(proposal_log_prob_fn(x) - proposal_log_normalizer)

and the target distribution:

exp(target_log_prob_fn(x) - target_log_normalizer),

accumulating importance weights along the way. The product of these importance weights gives an unbiased estimate of the ratio of the normalizing constants of the initial distribution and the target distribution:

E[exp(ais_weights)] = exp(target_log_normalizer - proposal_log_normalizer).
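To make the identity above concrete, here is a minimal pure-Python sketch of AIS on a one-dimensional toy problem. It is not the library's implementation. The linear annealing schedule, the convex combination of log densities, the random-walk Metropolis transition, and every name in it (`run_ais_chain`, `log_mean_exp`, `step_size`) are assumptions chosen for illustration.

```python
import math
import random


def proposal_log_prob(x):
    # Normalized standard normal, so proposal_log_normalizer = 0.
    return -0.5 * x * x - 0.5 * math.log(2.0 * math.pi)


def target_log_prob(x):
    # Unnormalized N(1, 0.5^2); its normalizer is 0.5 * sqrt(2 * pi).
    return -0.5 * ((x - 1.0) / 0.5) ** 2


def run_ais_chain(num_steps, rng, step_size=0.5):
    """Runs one annealed chain and returns (final_state, log_weight)."""
    x = rng.gauss(0.0, 1.0)  # Exact draw from the proposal.
    log_weight = 0.0
    for k in range(num_steps):
        beta = (k + 1) / num_steps  # Assumed linear schedule ending at 1.
        # Log ratio of consecutive annealed densities at the current state.
        log_weight += (target_log_prob(x) - proposal_log_prob(x)) / num_steps

        def annealed_log_prob(z):
            return ((1.0 - beta) * proposal_log_prob(z)
                    + beta * target_log_prob(z))

        # One random-walk Metropolis step targeting the annealed density.
        candidate = x + step_size * rng.gauss(0.0, 1.0)
        log_accept = annealed_log_prob(candidate) - annealed_log_prob(x)
        if rng.random() < math.exp(min(0.0, log_accept)):
            x = candidate
    return x, log_weight


def log_mean_exp(values):
    # Numerically stable log of the average of exp(values).
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values) / len(values))


rng = random.Random(0)
num_chains = 500
log_weights = [run_ais_chain(50, rng)[1] for _ in range(num_chains)]

# Estimate of target_log_normalizer - proposal_log_normalizer.
log_ratio_estimate = log_mean_exp(log_weights)
log_ratio_true = math.log(0.5 * math.sqrt(2.0 * math.pi))
```

Averaging is done on `exp(log_weight)` rather than on `log_weight` itself, because the unbiasedness statement applies to the weights, not to their logarithms.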

Args

num_steps Integer number of Markov chain updates to run. More steps cost more compute but give smoother annealing between the proposal q and the target p, which in turn means exponentially lower variance for the normalizing constant estimator.
proposal_log_prob_fn Python callable that returns the log density of the initial distribution.
target_log_prob_fn Python callable which takes an argument like current_state (or *current_state if it's a list) and returns its (possibly unnormalized) log-density under the target distribution.
current_state Tensor or Python list of Tensors representing the current state(s) of the Markov chain(s). The first r dimensions index independent chains, where r = tf.rank(target_log_prob_fn(*current_state)).
make_kernel_fn Python callable which returns a TransitionKernel-like object. Must take one argument, the target_log_prob_fn that the TransitionKernel should sample from. Note: sample_annealed_importance_chain does not pass the supplied target_log_prob_fn directly; it builds a new function that interpolates between proposal_log_prob_fn and target_log_prob_fn, and passes that interpolated function to make_kernel_fn.
parallel_iterations The number of iterations allowed to run in parallel. It must be a positive integer. See tf.while_loop for more details.
seed PRNG seed; see tfp.random.sanitize_seed for details.
name Python str name prefixed to Ops created by this function. Default value: None (i.e., 'sample_annealed_importance_chain').
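The make_kernel_fn contract described above can be sketched in plain Python. This is only an illustration: the convex combination with a linear schedule is an assumed form of the interpolation, and `make_interpolated_log_prob_fn` and `RecordingKernel` are hypothetical names, not part of the library.

```python
def make_interpolated_log_prob_fn(proposal_log_prob_fn, target_log_prob_fn, beta):
    """Returns a log density that moves from proposal (beta=0) to target (beta=1)."""
    def interpolated_log_prob_fn(x):
        return ((1.0 - beta) * proposal_log_prob_fn(x)
                + beta * target_log_prob_fn(x))
    return interpolated_log_prob_fn


class RecordingKernel:
    """Hypothetical stand-in for a TransitionKernel; it only stores its target."""

    def __init__(self, target_log_prob_fn):
        self.target_log_prob_fn = target_log_prob_fn


def make_kernel_fn(tlp_fn):
    # Takes exactly one argument: the log density the kernel should target.
    return RecordingKernel(target_log_prob_fn=tlp_fn)


def proposal_log_prob_fn(x):
    return -0.5 * x * x


def target_log_prob_fn(x):
    return -2.0 * x * x


num_steps = 4
kernels = [
    make_kernel_fn(make_interpolated_log_prob_fn(
        proposal_log_prob_fn, target_log_prob_fn, (k + 1) / num_steps))
    for k in range(num_steps)
]
# Log density seen by each kernel at x = 1; it drifts from proposal to target.
values = [kernel.target_log_prob_fn(1.0) for kernel in kernels]
```

The point of the factory is that the kernel never sees the raw target: each step receives a freshly built log density, so any kernel settings tuned for the final target (for example a step size) also have to work for the intermediate densities.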

Returns

next_state Tensor or Python list of Tensors representing the state(s) of the Markov chain(s) at the final iteration. Has same shape as input current_state.
ais_weights Tensor with the accumulated log importance weight(s), one per chain; exponentiate and average across chains to estimate the ratio of normalizing constants. Has shape matching target_log_prob_fn(current_state).
kernel_results collections.namedtuple of internal calculations used to advance the chain.
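As a rough illustration of how the returned values might be combined, the sketch below treats the weights as per-chain log weights and the final states as scalars. It computes the log normalizer ratio and, using self-normalized importance weighting (a standard technique that this page does not itself describe), a weighted mean of the final states. The helper name `summarize_ais` is hypothetical.

```python
import math


def summarize_ais(log_weights, states):
    """Returns (log normalizer ratio, normalized weights, weighted mean of states)."""
    m = max(log_weights)  # Subtract the max for numerical stability.
    scaled = [math.exp(lw - m) for lw in log_weights]
    total = sum(scaled)
    log_normalizer_ratio = m + math.log(total / len(log_weights))
    normalized = [s / total for s in scaled]
    weighted_mean = sum(w * x for w, x in zip(normalized, states))
    return log_normalizer_ratio, normalized, weighted_mean


# Two chains with weights 1 and 3 and final states 0 and 4.
log_ratio, weights, mean = summarize_ais(
    [math.log(1.0), math.log(3.0)], [0.0, 4.0])
```

When a few chains carry most of the normalized weight, the estimate is dominated by those chains, which is usually a sign that more annealing steps are needed.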

Examples

Estimate the normalizing constant of a log-gamma distribution.

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Run 100 AIS chains in parallel
num_chains = 100
dims = 20
dtype = np.float32

proposal = tfd.MultivariateNormalDiag(
   loc=tf.zeros([dims], dtype=dtype))

target = tfd.TransformedDistribution(
  distribution=tfd.Sample(
      tfd.Gamma(concentration=dtype(2), rate=dtype(3)),
      sample_shape=[dims]),
  bijector=tfp.bijectors.Invert(tfp.bijectors.Exp()))

chains_state, ais_weights, kernels_results = (
    tfp.mcmc.sample_annealed_importance_chain(
        num_steps=1000,
        proposal_log_prob_fn=proposal.log_prob,
        target_log_prob_fn=target.log_prob,
        current_state=proposal.sample(num_chains),
        make_kernel_fn=lambda tlp_fn: tfp.mcmc.HamiltonianMonteCarlo(
          target_log_prob_fn=tlp_fn,
          step_size=0.2,
          num_leapfrog_steps=2)))

log_estimated_normalizer = (tf.reduce_logsumexp(ais_weights)
                            - np.log(num_chains))
log_true_normalizer = tf.math.lgamma(2.) - 2. * tf.math.log(3.)

Estimate marginal likelihood of a Bayesian regression model.

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

def make_prior(dims, dtype):
  return tfd.MultivariateNormalDiag(
      loc=tf.zeros(dims, dtype))

def make_likelihood(weights, x):
  return tfd.MultivariateNormalDiag(
      # Contract the feature axis of both tensors: result is [num_chains, num_data].
      loc=tf.tensordot(weights, x, axes=[[-1], [-1]]))

# Run 100 AIS chains in parallel
num_chains = 100
dims = 10
dtype = np.float32

# Make training data.
x = np.random.randn(num_chains, dims).astype(dtype)
true_weights = np.random.randn(dims).astype(dtype)
y = (np.dot(x, true_weights) + np.random.randn(num_chains)).astype(dtype)

# Setup model.
prior = make_prior(dims, dtype)
def target_log_prob_fn(weights):
  return prior.log_prob(weights) + make_likelihood(weights, x).log_prob(y)

proposal = tfd.MultivariateNormalDiag(
    loc=tf.zeros(dims, dtype))

weight_samples, ais_weights, kernel_results = (
    tfp.mcmc.sample_annealed_importance_chain(
      num_steps=1000,
      proposal_log_prob_fn=proposal.log_prob,
      target_log_prob_fn=target_log_prob_fn,
      current_state=tf.zeros([num_chains, dims], dtype),
      make_kernel_fn=lambda tlp_fn: tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=tlp_fn,
        step_size=0.1,
        num_leapfrog_steps=2)))
log_normalizer_estimate = (tf.reduce_logsumexp(ais_weights)
                           - np.log(num_chains))