


SNAPER-HMC without step size adaptation.

Inherits From: TransitionKernel

This implements the SNAPER-HMC algorithm from [1], without the step size adaptation. This kernel learns a diagonal mass matrix and the trajectory length parameters of the Hamiltonian Monte Carlo (HMC) sampler using the Adaptive MCMC framework [2]. As with all adaptive MCMC algorithms, this kernel does not produce samples from the target distribution while adaptation is engaged, so be sure to set the num_adaptation_steps parameter to a value smaller than the number of burn-in steps.
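The burn-in advice above can be sketched in plain Python as a split of the run into phases. The helpers `check_schedule` and `phase_of` are hypothetical illustrations, not TFP APIs; the step counts are the ones used in the example on this page.

```python
def check_schedule(num_burnin_steps, num_adaptation_steps):
    """Raise if adaptation would still be running when samples are kept."""
    if num_adaptation_steps >= num_burnin_steps:
        raise ValueError(
            "num_adaptation_steps must be smaller than num_burnin_steps")


def phase_of(step, num_burnin_steps, num_adaptation_steps):
    """Label a 0-based step index (hypothetical helper, not a TFP API)."""
    if step < num_adaptation_steps:
        return "adapting"       # hyperparameters are still changing
    if step < num_burnin_steps:
        return "frozen_burnin"  # adaptation is off, samples still discarded
    return "sample"             # kept as draws from the target


num_burnin_steps = 1000
num_adaptation_steps = int(num_burnin_steps * 0.8)
num_results = 500
check_schedule(num_burnin_steps, num_adaptation_steps)
phases = [phase_of(s, num_burnin_steps, num_adaptation_steps)
          for s in range(num_burnin_steps + num_results)]
```

With these values, 800 steps adapt, the remaining 200 burn-in steps run with frozen hyperparameters, and only the final 500 steps are kept.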

This kernel uses the SNAPER criterion (see tfp.experimental.mcmc.snaper_criterion for details), which has a principal-component parameter. This kernel learns it using a batched version of Oja's algorithm with a learning rate of principal_component_ema_factor / step, where step is the iteration number.
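The batched Oja update can be illustrated with a simplified, pure-Python sketch. It is not TFP's implementation: it assumes the target mean is known (zero), draws fresh independent batches from a diagonal Gaussian instead of running MCMC chains, and treats each row of a batch as one chain. The names `oja_update` and `estimate_principal_component` are hypothetical. Each step forms the batch covariance-vector product C v, moves v along it with learning rate principal_component_ema_factor / step, and renormalizes.

```python
import math
import random


def normalize(v):
    """Scale v to unit Euclidean length."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def oja_update(v, batch, mean, learning_rate):
    """One batched Oja step: v <- normalize(v + learning_rate * C v).

    C is the empirical covariance of `batch` (one row per chain) around `mean`;
    C v is accumulated directly without forming C.
    """
    dim = len(v)
    cv = [0.0] * dim
    for x in batch:
        centered = [xi - mi for xi, mi in zip(x, mean)]
        proj = sum(c * vi for c, vi in zip(centered, v))
        for i in range(dim):
            cv[i] += centered[i] * proj / len(batch)
    return normalize([vi + learning_rate * ci for vi, ci in zip(v, cv)])


def estimate_principal_component(num_steps, num_chains, stddevs,
                                 principal_component_ema_factor, seed=0):
    """Run Oja's rule on fresh Gaussian batches standing in for chain states."""
    rng = random.Random(seed)
    v = normalize([1.0] * len(stddevs))
    mean = [0.0] * len(stddevs)  # Assumed known here; SNAPER also tracks a mean.
    for step in range(1, num_steps + 1):
        batch = [[rng.gauss(0.0, s) for s in stddevs] for _ in range(num_chains)]
        v = oja_update(v, batch, mean, principal_component_ema_factor / step)
    return v


# Axis 1 has the larger variance (9 vs 1), so v should align with [0, +/-1].
v = estimate_principal_component(num_steps=200, num_chains=16,
                                 stddevs=[1.0, 3.0],
                                 principal_component_ema_factor=4.0)
```

Because the learning rate decays like 1 / step, early steps move v aggressively toward the dominant direction while later steps mostly average out batch noise.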

The mass matrix is learned using a variant of Welford's algorithm combined with an exponential moving average, with a decay rate of step // state_ema_factor / (step // state_ema_factor + 1).
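The decay schedule and a batch moment update can be sketched as follows. This is an assumed, simplified reading of the description above, not TFP's code: step is taken to be 0-based, each step's batch is drawn independently from a 1-D Gaussian (one value per chain) rather than produced by HMC, and the batch is merged into the running mean and variance with the standard formula for combining two weighted groups (a Chan/Welford-style merge). `ema_decay`, `update_moments` and `run` are hypothetical names.

```python
import random


def ema_decay(step, state_ema_factor):
    """Decay rate step // f / (step // f + 1), assuming a 0-based step."""
    k = step // state_ema_factor
    return k / (k + 1)


def update_moments(mean, var, batch, decay):
    """Merge one cross-chain batch into running (mean, variance) estimates.

    The old estimate gets weight `decay` and the batch gets `1 - decay`; the
    variance merge includes the between-group term, as in Chan/Welford merging.
    """
    n = len(batch)
    batch_mean = sum(batch) / n
    batch_var = sum((x - batch_mean) ** 2 for x in batch) / n
    new_mean = decay * mean + (1.0 - decay) * batch_mean
    new_var = (decay * (var + (mean - new_mean) ** 2)
               + (1.0 - decay) * (batch_var + (batch_mean - new_mean) ** 2))
    return new_mean, new_var


def run(num_steps=400, num_chains=64, state_ema_factor=8, loc=2.0, scale=3.0,
        seed=0):
    """Track moments of i.i.d. N(loc, scale**2) batches standing in for chains."""
    rng = random.Random(seed)
    mean, var = 0.0, 0.0
    for step in range(num_steps):
        batch = [rng.gauss(loc, scale) for _ in range(num_chains)]
        mean, var = update_moments(mean, var, batch,
                                   ema_decay(step, state_ema_factor))
    return mean, var


mean, var = run()  # should be close to the true moments (2.0, 9.0)
```

For the first state_ema_factor steps the decay is 0, so the estimate is simply replaced by the current batch; after that the decay climbs toward 1 in increments. A larger state_ema_factor keeps step // state_ema_factor small for longer, so new batches carry more weight, which matches the "larger is faster adaptation" description of that argument.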

Learning the step size is a necessary component of a good HMC sampler, but it is not handled by this kernel. That adaptation can be provided by, for example, tfp.mcmc.SimpleStepSizeAdaptation or tfp.mcmc.DualAveragingStepSizeAdaptation.

To aid algorithm stability, the first few steps are taken with the number of leapfrog steps set to 1, which turns the algorithm into the Metropolis-Adjusted Langevin Algorithm (MALA). This is controlled by the num_mala_steps argument.
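Why one leapfrog step gives MALA can be seen in a small pure-Python HMC sketch on a 1-D standard normal target with unit mass. With a single leapfrog step the proposed position is x + eps * p + (eps**2 / 2) * grad_log_prob(x), which is the MALA proposal with step size eps**2 / 2, and the HMC accept test reduces to MALA's. The clamp on the number of leapfrog steps mimics the behaviour described above; the fixed `num_leapfrog_steps` stands in for the adapted trajectory length and is an assumption, as are the function names `hmc_step` and `sample`.

```python
import math
import random


def hmc_step(x, log_prob, grad_log_prob, step_size, num_leapfrog_steps, rng):
    """One HMC transition with unit mass on a scalar state."""
    p = rng.gauss(0.0, 1.0)
    x_new = x
    p_new = p + 0.5 * step_size * grad_log_prob(x_new)  # initial half step
    for i in range(num_leapfrog_steps):
        x_new = x_new + step_size * p_new
        if i < num_leapfrog_steps - 1:
            p_new = p_new + step_size * grad_log_prob(x_new)  # full steps
    p_new = p_new + 0.5 * step_size * grad_log_prob(x_new)  # final half step
    log_accept = ((log_prob(x_new) - 0.5 * p_new ** 2)
                  - (log_prob(x) - 0.5 * p ** 2))
    if rng.random() < math.exp(min(0.0, log_accept)):
        return x_new
    return x


def sample(num_steps, num_mala_steps, num_leapfrog_steps, step_size, seed=0):
    """Standard normal target; leapfrog steps clamped to 1 for early steps."""
    rng = random.Random(seed)

    def log_prob(x):
        return -0.5 * x * x

    def grad_log_prob(x):
        return -x

    x, samples = 0.0, []
    for step in range(num_steps):
        # MALA phase first, then full HMC trajectories.
        n = 1 if step < num_mala_steps else num_leapfrog_steps
        x = hmc_step(x, log_prob, grad_log_prob, step_size, n, rng)
        samples.append(x)
    return samples


samples = sample(num_steps=10000, num_mala_steps=100, num_leapfrog_steps=5,
                 step_size=0.3)
```

MALA moves are short and local, which makes them forgiving while the mass matrix and step size are still poorly tuned; longer trajectories are only switched on afterwards.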

Unlike some classical MCMC algorithms, this algorithm behaves best when the chains are initialized with very low variance. Initializing them all at one point is recommended.

SNAPER-HMC requires at least two chains to function.


Here we apply this kernel to a target with a known covariance structure and show that it recovers the principal component and the variances.

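import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.internal import unnest

tfd = tfp.distributions

num_dims = 8
num_burnin_steps = 1000
num_adaptation_steps = int(num_burnin_steps * 0.8)
num_results = 500
num_chains = 64
step_size = 1e-2
num_mala_steps = 100

# Random covariance with eigenvalues spanning exp(0)..exp(3); cast to float32
# so it matches the float32 `loc` and initial state below.
eigenvalues = np.exp(np.linspace(0., 3., num_dims))
q, r = np.linalg.qr(np.random.randn(num_dims, num_dims))
q *= np.sign(np.diag(r))
covariance = (q * eigenvalues).dot(q.T).astype(np.float32)

# The principal component is the eigenvector with the largest eigenvalue.
_, eigs = np.linalg.eigh(covariance)
principal_component = eigs[:, -1]

gaussian = tfd.MultivariateNormalTriL(
    loc=tf.zeros(num_dims),
    scale_tril=tf.linalg.cholesky(covariance),
)

kernel = tfp.experimental.mcmc.SNAPERHamiltonianMonteCarlo(
    gaussian.log_prob,
    step_size=step_size,
    num_adaptation_steps=num_adaptation_steps,
    num_mala_steps=num_mala_steps,
)
# SNAPER-HMC does not adapt the step size, so wrap it in a step size adapter.
kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(
    kernel, num_adaptation_steps=num_adaptation_steps)

def trace_fn(_, pkr):
  return {
      'principal_component':
          unnest.get_innermost(pkr, 'ema_principal_component'),
      'variance':
          unnest.get_innermost(pkr, 'ema_variance'),
  }

# Initialize all chains at a single point.
init_x = tf.zeros([num_chains, num_dims])

chain, trace = tfp.mcmc.sample_chain(
    num_results=num_results,
    num_burnin_steps=num_burnin_steps,
    current_state=init_x,
    kernel=kernel,
    trace_fn=trace_fn,
    seed=(0, 0),
)

# Close to `np.diag(covariance)`
trace['variance'][-1]
# Close to `principal_component`, up to a sign.
trace['principal_component'][-1]

# Compute sampler diagnostics.
tfp.mcmc.effective_sample_size(chain, cross_chain_dims=1)

# Compute downstream statistics.
tf.reduce_mean(chain, [0, 1])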


[1]: Sountsov, P. & Hoffman, M. (2021). Focusing on Difficult Directions for Learning HMC Trajectory Lengths.

[2]: Andrieu, C. & Thoms, J. (2008). A tutorial on adaptive MCMC. Statistics and Computing.

Args:
target_log_prob_fn: Python callable which takes an argument like current_state (or *current_state if it's a list) and returns its (possibly unnormalized) log-density under the target distribution.
step_size: Scalar float Tensor representing the step size for the leapfrog integrator.
num_adaptation_steps: Scalar int Tensor. Number of initial steps during which to adjust the hyperparameters.
num_mala_steps: Scalar int Tensor. Number of initial steps during which the number of leapfrog steps is clamped to 1, for stability.
max_leapfrog_steps: Scalar int Tensor. Clips the number of leapfrog steps to this value.
trajectory_length_adaptation_rate: Scalar float Tensor. How rapidly to adapt the trajectory length.
principal_component_ema_factor: Scalar int Tensor. Factor controlling the principal component adaptation. A larger value corresponds to faster adaptation.
state_ema_factor: Scalar int Tensor. Factor controlling the mass matrix adaptation. A larger value corresponds to faster adaptation.
experimental_shard_axis_names: A structure of string names indicating how members of the state are sharded.
experimental_reduce_chain_axis_names: A string or list of string names indicating which named axes to average cross-chain statistics over.
preconditioned_hamiltonian_monte_carlo_kwargs: Additional keyword arguments to pass to the PreconditionedHamiltonianMonteCarlo kernel.
gradient_based_trajectory_length_adaptation_kwargs: Additional keyword arguments to pass to the GradientBasedTrajectoryLengthAdaptation kernel.
validate_args: Python bool. When True, kernel parameters are checked for validity. When False, invalid inputs may silently render incorrect outputs.
name: Python str name prefixed to Ops created by this class. Default: 'snaper_hamiltonian_monte_carlo'.


Attributes:
experimental_shard_axis_names: The shard axis names for members of the state.

is_calibrated: Returns True if the Markov chain converges to the specified distribution.

TransitionKernels which are "uncalibrated" are often calibrated by composing them with the tfp.mcmc.MetropolisHastings TransitionKernel.

bootstrap_results(init_state)

Returns an object with the same type as returned by one_step(...)[1].

Args:
init_state: Tensor or Python list of Tensors representing the initial state(s) of the Markov chain(s).

Returns:
kernel_results: A (possibly nested) tuple, namedtuple or list of Tensors representing internal calculations made within this function.


copy(**override_parameter_kwargs)

Non-destructively creates a deep copy of the kernel.

Args:
**override_parameter_kwargs: Python string/value dictionary of initialization arguments to override with new values.

Returns:
new_kernel: TransitionKernel object of the same type as self, initialized with the union of self.parameters and override_parameter_kwargs, with any shared keys overridden by the value of override_parameter_kwargs, i.e., dict(self.parameters, **override_parameter_kwargs).
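The override rule is ordinary Python dictionary merging. The snippet below shows it on a plain dictionary whose keys mirror some of this kernel's constructor arguments; the values come from the example on this page, and the helper `merged_parameters` is hypothetical. On an actual SNAPERHamiltonianMonteCarlo instance the corresponding call would be something like `snaper_kernel.copy(step_size=0.05)`, where `snaper_kernel` is an illustrative variable name.

```python
def merged_parameters(parameters, **override_parameter_kwargs):
    """Union of `parameters` and the overrides; overrides win on shared keys."""
    return dict(parameters, **override_parameter_kwargs)


original = {"step_size": 1e-2, "num_adaptation_steps": 800, "num_mala_steps": 100}
updated = merged_parameters(original, step_size=0.05)
# `original` is left untouched, mirroring the non-destructive copy.
```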


experimental_with_shard_axes(shard_axis_names)

Returns a copy of the kernel with the provided shard axis names.

Args:
shard_axis_names: A structure of strings indicating the shard axis names for each component of this kernel's state.

Returns:
A copy of the current kernel with the shard axis information.


one_step(current_state, previous_kernel_results, seed=None)

Takes one step of the TransitionKernel.

Must be overridden by subclasses.

Args:
current_state: Tensor or Python list of Tensors representing the current state(s) of the Markov chain(s).
previous_kernel_results: A (possibly nested) tuple, namedtuple or list of Tensors representing internal calculations made within the previous call to this function (or as returned by bootstrap_results).
seed: PRNG seed; see tfp.random.sanitize_seed for details.

Returns:
next_state: Tensor or Python list of Tensors representing the next state(s) of the Markov chain(s).
kernel_results: A (possibly nested) tuple, namedtuple or list of Tensors representing internal calculations made within this function.