# tfp.mcmc.sample_annealed_importance_chain

Runs annealed importance sampling (AIS) to estimate normalizing constants.

This function uses an MCMC transition operator (e.g., Hamiltonian Monte Carlo) to sample from a series of distributions that slowly interpolates between an initial 'proposal' distribution:

`exp(proposal_log_prob_fn(x) - proposal_log_normalizer)`

and the target distribution:

`exp(target_log_prob_fn(x) - target_log_normalizer)`,

accumulating importance weights along the way. The product of these importance weights gives an unbiased estimate of the ratio of the normalizing constants of the initial distribution and the target distribution:

`E[exp(ais_weights)] = exp(target_log_normalizer - proposal_log_normalizer)`.

`num_steps` Integer number of Markov chain updates to run. More iterations means more expense, but smoother annealing between q and p, which in turn means exponentially lower variance for the normalizing constant estimator.
`proposal_log_prob_fn` Python callable that returns the log density of the initial distribution.
`target_log_prob_fn` Python callable which takes an argument like `current_state` (or `*current_state` if it's a list) and returns its (possibly unnormalized) log-density under the target distribution.
`current_state` `Tensor` or Python `list` of `Tensor`s representing the current state(s) of the Markov chain(s). The first `r` dimensions index independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`.
`make_kernel_fn` Python `callable` which returns a `TransitionKernel`-like object. Must take one argument representing the `TransitionKernel`'s `target_log_prob_fn`. The `target_log_prob_fn` argument represents the `TransitionKernel`'s target log distribution. Note: `sample_annealed_importance_chain` creates a new `target_log_prob_fn` which is an interpolation between the supplied `target_log_prob_fn` and `proposal_log_prob_fn`; it is this interpolated function which is used as an argument to `make_kernel_fn`.
`parallel_iterations` The number of iterations allowed to run in parallel. It must be a positive integer. See `tf.while_loop` for more details.
`seed` Optional, a seed for reproducible sampling.
`name` Python `str` name prefixed to Ops created by this function. Default value: `None` (i.e., 'sample_annealed_importance_chain').

`next_state` `Tensor` or Python list of `Tensor`s representing the state(s) of the Markov chain(s) at the final iteration. Has same shape as input `current_state`.
`ais_weights` Tensor with the estimated weight(s). Has shape matching `target_log_prob_fn(current_state)`.
`kernel_results` `collections.namedtuple` of internal calculations used to advance the chain.

#### Examples

##### Estimate the normalizing constant of a log-gamma distribution.
``````tfd = tfp.distributions

# Run 100 AIS chains in parallel
num_chains = 100
dims = 20
dtype = np.float32

proposal = tfd.MultivariateNormalDiag(
loc=tf.zeros([dims], dtype=dtype))

target = tfd.TransformedDistribution(
distribution=tfd.Gamma(concentration=dtype(2),
rate=dtype(3)),
bijector=tfp.bijectors.Invert(tfp.bijectors.Exp()),
event_shape=[dims])

chains_state, ais_weights, kernels_results = (
tfp.mcmc.sample_annealed_importance_chain(
num_steps=1000,
proposal_log_prob_fn=proposal.log_prob,
target_log_prob_fn=target.log_prob,
current_state=proposal.sample(num_chains),
make_kernel_fn=lambda tlp_fn: tfp.mcmc.HamiltonianMonteCarlo(
target_log_prob_fn=tlp_fn,
step_size=0.2,
num_leapfrog_steps=2)))

log_estimated_normalizer = (tf.reduce_logsumexp(ais_weights)

- np.log(num_chains))
log_true_normalizer = tf.lgamma(2.) - 2. * tf.log(3.)
``````
##### Estimate marginal likelihood of a Bayesian regression model.
``````tfd = tfp.distributions

def make_prior(dims, dtype):
return tfd.MultivariateNormalDiag(
loc=tf.zeros(dims, dtype))

def make_likelihood(weights, x):
return tfd.MultivariateNormalDiag(
loc=tf.tensordot(weights, x, axes=[, [-1]]))

# Run 100 AIS chains in parallel
num_chains = 100
dims = 10
dtype = np.float32

# Make training data.
x = np.random.randn(num_chains, dims).astype(dtype)
true_weights = np.random.randn(dims).astype(dtype)
y = np.dot(x, true_weights) + np.random.randn(num_chains)

# Setup model.
prior = make_prior(dims, dtype)
def target_log_prob_fn(weights):
return prior.log_prob(weights) + make_likelihood(weights, x).log_prob(y)

proposal = tfd.MultivariateNormalDiag(
loc=tf.zeros(dims, dtype))

weight_samples, ais_weights, kernel_results = (
tfp.mcmc.sample_annealed_importance_chain(
num_steps=1000,
proposal_log_prob_fn=proposal.log_prob,
target_log_prob_fn=target_log_prob_fn
current_state=tf.zeros([num_chains, dims], dtype),
make_kernel_fn=lambda tlp_fn: tfp.mcmc.HamiltonianMonteCarlo(
target_log_prob_fn=tlp_fn,
step_size=0.1,
num_leapfrog_steps=2)))
log_normalizer_estimate = (tf.reduce_logsumexp(ais_weights)

- np.log(num_chains))
``````