tfp.experimental.mcmc.PreconditionedHamiltonianMonteCarlo

Hamiltonian Monte Carlo with a given momentum distribution.

Inherits From: HamiltonianMonteCarlo, TransitionKernel

See tfp.mcmc.HamiltonianMonteCarlo for details on HMC.

HMC produces samples much more efficiently if properly preconditioned. This can be done by choosing a momentum distribution with covariance equal to the inverse of the state's covariance.
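
To see why the inverse covariance is the right choice, consider the idealized case of a Gaussian target with covariance $\Sigma = L L^\top$ (for non-Gaussian targets, plugging in an estimated $\Sigma$ is a heuristic). With momentum $p \sim \mathcal{N}(0, M)$ the Hamiltonian is $H(q, p) = -\log \pi(q) + \tfrac12 p^\top M^{-1} p$. Setting $M = \Sigma^{-1}$ and changing variables to $q = L x$, $p = L^{-\top} r$:

$$
\begin{aligned}
-\log \pi(q) &= \tfrac12 q^\top \Sigma^{-1} q + \text{const} = \tfrac12 x^\top L^\top L^{-\top} L^{-1} L\, x + \text{const} = \tfrac12 x^\top x + \text{const},\\
\tfrac12 p^\top M^{-1} p &= \tfrac12 p^\top \Sigma\, p = \tfrac12 r^\top L^{-1} L L^\top L^{-\top} r = \tfrac12 r^\top r.
\end{aligned}
$$

So $H = \tfrac12 x^\top x + \tfrac12 r^\top r + \text{const}$ is isotropic: every direction has the same scale, and a single step_size suits all of them. The leapfrog updates transform the same way ($r = L^\top p$, $x = L^{-1} q$), so preconditioned HMC on $q$ behaves like plain HMC on the whitened $x$.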

Examples:

Simple chain with warm-up.

In this example we can use an estimate of the target covariance to sample efficiently with HMC.

import tensorflow as tf
import tensorflow_probability as tfp
tfed = tfp.experimental.distributions

# Suppose we have a target log prob fn, as well as an estimate of its
# covariance.
log_prob_fn = ...
cov_estimate = ...

# We want the mass matrix (the momentum covariance) to be the *inverse* of
# the covariance estimate. Equivalently, the momentum *precision* should equal
# the covariance estimate, so we pass it as `precision`, together with its
# lower-triangular Cholesky factor as `precision_factor`:
momentum_distribution = (
    tfed.MultivariateNormalPrecisionFactorLinearOperator(
        precision_factor=tf.linalg.LinearOperatorLowerTriangular(
            tf.linalg.cholesky(cov_estimate),
        ),
        precision=tf.linalg.LinearOperatorFullMatrix(cov_estimate),
    )
)

# Run standard HMC below
num_burnin_steps = 100
num_results = 1000
adaptive_hmc = tfp.mcmc.SimpleStepSizeAdaptation(
    tfp.experimental.mcmc.PreconditionedHamiltonianMonteCarlo(
      target_log_prob_fn=log_prob_fn,
      momentum_distribution=momentum_distribution,
      step_size=0.3,
      num_leapfrog_steps=10),
    num_adaptation_steps=int(num_burnin_steps * 0.8))

@tf.function
def run_chain_and_compute_ess():
  draws = tfp.mcmc.sample_chain(
      num_results,
      num_burnin_steps=num_burnin_steps,
      current_state=tf.zeros(3),  # 3 chains.
      kernel=adaptive_hmc,
      trace_fn=None)
  return tfp.mcmc.effective_sample_size(draws, cross_chain_dims=1)

run_chain_and_compute_ess()  # Something close to 3 x 1000.
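
The mass-matrix relationship in this example can be checked numerically. The following is a minimal NumPy sketch, not TFP code: the helper sample_momentum and the 2 x 2 cov_estimate are illustrative assumptions. It relies on the fact that if the precision is $P = L L^\top$, then $p = L^{-\top} z$ with standard-normal $z$ has covariance $P^{-1}$, so the empirical momentum covariance should approach the inverse of cov_estimate, i.e. the desired mass matrix.

```python
import numpy as np

def sample_momentum(precision_factor, num_samples, rng):
    """Hypothetical helper: draw momenta with precision L @ L.T."""
    d = precision_factor.shape[0]
    z = rng.standard_normal((d, num_samples))
    # If P = L @ L.T, then p = inv(L.T) @ z has covariance
    # inv(L.T) @ inv(L) = inv(P).
    return np.linalg.solve(precision_factor.T, z).T

# Illustrative covariance estimate (an assumption, not from the example).
cov_estimate = np.array([[2.0, 0.5],
                         [0.5, 1.0]])
# Lower-triangular factor with L @ L.T == cov_estimate, as in the example.
L = np.linalg.cholesky(cov_estimate)

rng = np.random.default_rng(0)
momenta = sample_momentum(L, 200_000, rng)
empirical_cov = np.cov(momenta, rowvar=False)
mass_matrix = np.linalg.inv(cov_estimate)  # approx [[0.57, -0.29], [-0.29, 1.14]]
print(np.round(empirical_cov, 2))          # close to mass_matrix
```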

Estimate parameters of a more complicated distribution.

This demonstrates using multiple state parts, and reshaping a tfde.MultivariateNormalPrecisionFactorLinearOperator to match a scalar or non-vector event shape (in this case, [2, 3, 4]). Each momentum standard deviation is the reciprocal of the corresponding target standard deviation.

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
tfde = tfp.experimental.distributions

mvn = tfd.JointDistributionSequential([
    tfd.Normal(0., 0.1),
    tfd.Normal(0., 10.),
    tfd.Independent(tfd.Normal(tf.fill([2, 3, 4], 3.), 10.),
                    reinterpreted_batch_ndims=3)])

# Each momentum std is the reciprocal of the matching target std.
reshape_to_scalar = tfp.bijectors.Reshape(event_shape_out=[])  # For 1-element MVNs.
reshape_to_234 = tfp.bijectors.Reshape(event_shape_out=[2, 3, 4])
momentum_distribution = tfd.JointDistributionSequential([
    tfd.Normal(0., 10.),
    tfd.Normal(0., 0.1),
    # A diagonal precision factor of 10 gives precision 100 (std 0.1) for
    # each of the 24 elements, reshaped to the [2, 3, 4] target event.
    # `loc` is omitted, so it defaults to zero.
    reshape_to_234(
        tfde.MultivariateNormalPrecisionFactorLinearOperator(
            precision_factor=tf.linalg.LinearOperatorDiag(
                tf.fill([24], 10.))))
])
num_burnin_steps = 100
num_results = 1000
hmc = tfp.experimental.mcmc.PreconditionedHamiltonianMonteCarlo(
  target_log_prob_fn=mvn.log_prob,
  momentum_distribution=momentum_distribution,
  step_size=0.3,
  num_leapfrog_steps=10)

@tf.function
def run_chain_and_compute_ess():
  draws = tfp.mcmc.sample_chain(
      num_results,
      num_burnin_steps=num_burnin_steps,
      current_state=mvn.sample(),
      kernel=hmc,
      trace_fn=None)
  return tfp.mcmc.effective_sample_size(draws)
run_chain_and_compute_ess()  # [1000, 1000, 1000 * tf.ones([2, 3, 4])]
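
For diagonal preconditioning like this, each momentum element's standard deviation is the reciprocal of the target's, so the precision factor's diagonal is simply the target standard deviation. A small NumPy sketch, independent of TFP (the helper name is hypothetical), checks this for the [2, 3, 4] part: it samples 24 independent momenta with precision factor diag(10) and reshapes them to the target's event shape.

```python
import numpy as np

# Target standard deviations of the [2, 3, 4] state part in the example above.
target_std = np.full([2, 3, 4], 10.0)

def sample_diag_momentum(target_std, num_samples, rng):
    """Hypothetical helper: momentum with precision factor diag(target_std)."""
    # Precision factor diagonal = target std, flattened to a 24-vector.
    factor = np.ravel(target_std)
    z = rng.standard_normal((num_samples, factor.size))
    # For a diagonal factor L, solving L^T p = z is elementwise division,
    # so each momentum element has std 1 / target_std.
    flat = z / factor
    # Reshape the flat event back to the target's event shape.
    return flat.reshape((num_samples,) + target_std.shape)

rng = np.random.default_rng(0)
momenta = sample_diag_momentum(target_std, 100_000, rng)
print(momenta.shape)                 # (100000, 2, 3, 4)
print(momenta.std(axis=0).round(2))  # every entry close to 0.1
```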

Args
target_log_prob_fn Python callable which takes an argument like current_state (or *current_state if it's a list) and returns its (possibly unnormalized) log-density under the target distribution.
step_size Tensor or Python list of Tensors representing the step size for the leapfrog integrator. Must broadcast with the shape of current_state. Larger step sizes lead to faster progress, but too-large step sizes make rejection exponentially more likely. When possible, it's often helpful to match per-variable step sizes to the standard deviations of the target distribution in each variable.
num_leapfrog_steps Integer number of steps to run the leapfrog integrator for. Total progress per HMC step is roughly proportional to step_size * num_leapfrog_steps.
momentum_distribution A tfp.distributions.Distribution instance to draw momentum from. Defaults to normal distributions with identity covariance.
state_gradients_are_stopped Python bool indicating that the proposed new state be run through tf.stop_gradient. This is particularly useful when combining optimization with samples from the HMC chain. Default value: False (i.e., do not apply stop_gradient).
step_size_update_fn Python callable taking the current step_size (typically a tf.Variable) and kernel_results (typically collections.namedtuple) and returning the updated step_size (Tensors). Default value: None (i.e., do not update step_size automatically).
store_parameters_in_results If True, then step_size, momentum_distribution, and num_leapfrog_steps are written to and read from eponymous fields in the kernel results objects returned from one_step and bootstrap_results. This allows wrapper kernels to adjust those parameters on the fly. If this is True, momentum_distribution must be a CompositeTensor (see tfp.experimental.auto_composite), and step_size_update_fn must be None, since the two options are incompatible.
experimental_shard_axis_names A structure of string names indicating how members of the state are sharded.
name Python str name prefixed to Ops created by this function. Default value: None (i.e., 'phmc_kernel').
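
To make the roles of step_size, num_leapfrog_steps, and momentum_distribution concrete, here is a minimal NumPy sketch of one preconditioned HMC transition, assuming a Gaussian momentum given by a precision factor L (precision P = L Lᵀ, mass matrix M = P⁻¹). It is an illustration under those assumptions, not TFP's implementation; the function name and arguments are hypothetical. The position update uses the velocity P p, and a Metropolis-Hastings test on the change in the Hamiltonian corrects the integrator's error.

```python
import numpy as np

def preconditioned_hmc_step(q, log_prob, grad_log_prob, precision_factor,
                            step_size, num_leapfrog_steps, rng):
    """Hypothetical sketch: one HMC transition with p ~ N(0, M), inv(M) = L @ L.T.

    Returns (next_state, accepted).
    """
    L = precision_factor
    P = L @ L.T  # momentum precision, i.e. the inverse mass matrix
    # Draw p0 ~ N(0, inv(P)) by solving L.T @ p0 = z with z ~ N(0, I).
    p0 = np.linalg.solve(L.T, rng.standard_normal(q.shape))

    def hamiltonian(q, p):
        # Potential energy -log_prob(q) plus kinetic energy 0.5 * p^T inv(M) p.
        return -log_prob(q) + 0.5 * p @ P @ p

    # Leapfrog integration; the velocity dq/dt = inv(M) @ p = P @ p.
    q_new = q.copy()
    p = p0 + 0.5 * step_size * grad_log_prob(q_new)  # initial half step
    for i in range(num_leapfrog_steps):
        q_new = q_new + step_size * (P @ p)
        if i < num_leapfrog_steps - 1:
            p = p + step_size * grad_log_prob(q_new)  # full momentum step
    p = p + 0.5 * step_size * grad_log_prob(q_new)  # final half step

    # Metropolis-Hastings correction: accept with prob min(1, exp(-dH)).
    log_accept_ratio = hamiltonian(q, p0) - hamiltonian(q_new, p)
    if np.log(rng.uniform()) < log_accept_ratio:
        return q_new, True
    return q, False
```

With the exact preconditioner on a Gaussian target, each transition rotates the whitened state by an angle close to step_size * num_leapfrog_steps: products near π/2 give nearly independent draws, while products near π give strongly anti-correlated ones.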

Attributes
experimental_shard_axis_names The shard axis names for members of the state.
is_calibrated Returns True if the Markov chain converges to the specified distribution.

TransitionKernels which are "uncalibrated" are often calibrated by composing them with the tfp.mcmc.MetropolisHastings TransitionKernel.

name

num_leapfrog_steps Returns the num_leapfrog_steps parameter.

If the store_parameters_in_results argument to the initializer was set to True, this only returns the value of num_leapfrog_steps placed in the kernel results by the bootstrap_results method. The actual num_leapfrog_steps in that situation is governed by the previous_kernel_results argument to the one_step method.

parameters Returns a dict of __init__ arguments and their values.
state_gradients_are_stopped

step_size Returns the step_size parameter.

If the store_parameters_in_results argument to the initializer was set to True, this only returns the value of step_size placed in the kernel results by the bootstrap_results method. The actual step size in that situation is governed by the previous_kernel_results argument to the one_step method.

target_log_prob_fn

Methods

bootstrap_results

Creates initial previous_kernel_results using a supplied state.

copy

Non-destructively creates a deep copy of the kernel.

Args
**override_parameter_kwargs Python String/value dictionary of initialization arguments to override with new values.

Returns
new_kernel TransitionKernel object of the same type as self, initialized with the union of self.parameters and override_parameter_kwargs, with any shared keys overridden by the value of override_parameter_kwargs, i.e., dict(self.parameters, **override_parameter_kwargs).
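
The merge rule is ordinary Python dict semantics. A minimal sketch with a hypothetical toy kernel class (not the TFP TransitionKernel base class) shows that overridden keys win and the original kernel's parameters are left untouched:

```python
class ToyKernel:
    """Hypothetical stand-in that stores its __init__ arguments."""

    def __init__(self, step_size, num_leapfrog_steps, name=None):
        self.parameters = dict(step_size=step_size,
                               num_leapfrog_steps=num_leapfrog_steps,
                               name=name)

    def copy(self, **override_parameter_kwargs):
        # Union of the stored parameters and the overrides; shared keys take
        # the override's value.
        parameters = dict(self.parameters, **override_parameter_kwargs)
        return type(self)(**parameters)

kernel = ToyKernel(step_size=0.3, num_leapfrog_steps=10)
smaller_steps = kernel.copy(step_size=0.1)
print(smaller_steps.parameters)  # {'step_size': 0.1, 'num_leapfrog_steps': 10, 'name': None}
print(kernel.parameters)         # {'step_size': 0.3, 'num_leapfrog_steps': 10, 'name': None}
```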

experimental_with_shard_axes

Returns a copy of the kernel with the provided shard axis names.

Args
shard_axis_names a structure of strings indicating the shard axis names for each component of this kernel's state.

Returns
A copy of the current kernel with the shard axis information.

one_step

Runs one iteration of Hamiltonian Monte Carlo.

Args
current_state Tensor or Python list of Tensors representing the current state(s) of the Markov chain(s). The first r dimensions index independent chains, where r = tf.rank(target_log_prob_fn(*current_state)).
previous_kernel_results collections.namedtuple containing Tensors representing values from previous calls to this function (or from the bootstrap_results function).
seed PRNG seed; see tfp.random.sanitize_seed for details.

Returns
next_state Tensor or Python list of Tensors representing the state(s) of the Markov chain(s) after taking exactly one step. Has same type and shape as current_state.
kernel_results collections.namedtuple of internal calculations used to advance the chain.

Raises
ValueError if step_size is neither a single value nor a list with the same length as current_state.
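
To illustrate the chain-indexing rule for current_state, the following NumPy sketch uses a hypothetical two-part target (not from TFP) whose log-density takes the state parts as separate arguments, mirroring target_log_prob_fn(*current_state). With 4 chains the log-density has shape [4], so r = 1 and the leading dimension of every state part indexes chains.

```python
import numpy as np

def target_log_prob_fn(alpha, beta):
    """Hypothetical unnormalized log-density over two state parts.

    alpha: shape [num_chains]; beta: shape [num_chains, 3].
    """
    # Independent standard-normal terms, summed over beta's event dimension
    # only, so the result keeps one value per chain.
    return -0.5 * alpha ** 2 - 0.5 * np.sum(beta ** 2, axis=-1)

num_chains = 4
current_state = [np.ones(num_chains), np.zeros((num_chains, 3))]
log_prob = target_log_prob_fn(*current_state)  # list state is unpacked
r = np.ndim(log_prob)  # number of leading dimensions that index chains
print(log_prob.shape, r)           # (4,) 1
print(current_state[1].shape[:r])  # (4,)
```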