
Runs one step of Hamiltonian Monte Carlo.

Inherits From: TransitionKernel

Hamiltonian Monte Carlo (HMC) is a Markov chain Monte Carlo (MCMC) algorithm that takes a series of gradient-informed steps to produce a Metropolis proposal. This class implements one random HMC step from a given current_state. Mathematical details and derivations can be found in [Neal (2011)][1].
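To make "gradient-informed steps" followed by a "Metropolis proposal" concrete, here is a minimal pure-Python sketch of one HMC transition for a scalar state. It is an illustration only, not this class's implementation: the names `hmc_step`, `target_log_prob` and `grad_target_log_prob` are hypothetical, the gradient is written by hand, and the step size and leapfrog count are arbitrary illustrative values.

```python
import math
import random


def target_log_prob(x):
  # Unnormalized log-density of a Gaussian with mean -0.5 and variance 0.5.
  return -x - x**2


def grad_target_log_prob(x):
  # Hand-written derivative of `target_log_prob`.
  return -1.0 - 2.0 * x


def hmc_step(x, step_size, num_leapfrog_steps, rng):
  """Runs one HMC transition; returns `(next_x, is_accepted)`."""
  momentum = rng.gauss(0.0, 1.0)
  initial_energy = -target_log_prob(x) + 0.5 * momentum**2

  # Leapfrog integration: a half momentum step, alternating full position and
  # momentum steps, then a final half momentum step.
  x_new, p_new = x, momentum
  p_new += 0.5 * step_size * grad_target_log_prob(x_new)
  for i in range(num_leapfrog_steps):
    x_new += step_size * p_new
    if i < num_leapfrog_steps - 1:
      p_new += step_size * grad_target_log_prob(x_new)
  p_new += 0.5 * step_size * grad_target_log_prob(x_new)

  # Metropolis correction: accept with probability min(1, exp(-energy change)).
  proposed_energy = -target_log_prob(x_new) + 0.5 * p_new**2
  log_accept_ratio = initial_energy - proposed_energy
  if rng.random() < math.exp(min(0.0, log_accept_ratio)):
    return x_new, True
  return x, False


rng = random.Random(0)
num_steps, num_warmup = 6000, 1000
x, samples, num_accepted = 1.0, [], 0
for _ in range(num_steps):
  x, accepted = hmc_step(x, step_size=0.5, num_leapfrog_steps=3, rng=rng)
  samples.append(x)
  num_accepted += accepted

samples = samples[num_warmup:]  # Discard warm-up draws.
sample_mean = sum(samples) / len(samples)
sample_stddev = math.sqrt(
    sum((s - sample_mean)**2 for s in samples) / len(samples))
acceptance_rate = num_accepted / num_steps
```

The sample mean and standard deviation should land near -0.5 and about 0.71, the moments of the Gaussian that this log-density describes.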

The one_step function can update multiple chains in parallel. It assumes that all leftmost dimensions of current_state index independent chain states (and are therefore updated independently). The output of target_log_prob_fn(*current_state) should sum log-probabilities across all event dimensions. Slices along the rightmost dimensions may have different target distributions; for example, current_state[0, :] could have a different target distribution from current_state[1, :]. These semantics are governed by target_log_prob_fn(*current_state). (The number of independent chains is tf.size(target_log_prob_fn(*current_state)).)
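The shape convention can be sketched with plain NumPy standing in for the tensor library. This is an assumption-laden illustration: `chain_locs` and this particular `target_log_prob_fn` are hypothetical, and `.size` and `.ndim` play the roles of `tf.size` and `tf.rank`. Each chain gets its own target here (a different mean per chain), and the log-probability is summed over the event dimension so that one value per chain remains.

```python
import numpy as np

# Hypothetical per-chain means: chain i targets a Gaussian centred at i.
chain_locs = np.array([0., 1., 2., 3.])[:, np.newaxis]  # shape [4, 1]


def target_log_prob_fn(x):
  # `x` has shape [num_chains, event_size]. Summing over the rightmost
  # (event) axis leaves exactly one log-probability per chain.
  return np.sum(-0.5 * (x - chain_locs)**2, axis=-1)


current_state = np.zeros([4, 3])  # 4 chains, each with a 3-dimensional state.
log_prob = target_log_prob_fn(current_state)

num_chains = log_prob.size      # Number of independent chains.
num_chain_dims = log_prob.ndim  # Leading dimensions that index chains.
```

Here `log_prob` has shape `[4]`, so the single leading dimension of `current_state` indexes four independent chains.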


Simple chain with warm-up.

In this example we sample from a standard univariate normal distribution using HMC with adaptive step size.

import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp


# Target distribution is proportional to: `exp(-x (1 + x))`.
def unnormalized_log_prob(x):
  return -x - x**2.

# Initialize the HMC transition kernel.
num_results = int(10e3)
num_burnin_steps = int(1e3)
adaptive_hmc = tfp.mcmc.SimpleStepSizeAdaptation(
    # Inner HMC kernel. `num_leapfrog_steps` and the initial `step_size` are
    # assumed starting values; the step size is adapted during burn-in.
    tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=unnormalized_log_prob,
        num_leapfrog_steps=3,
        step_size=1.),
    num_adaptation_steps=int(num_burnin_steps * 0.8))

# Run the chain (with burn-in).
@jax.jit
def run_chain(seed):
  samples, is_accepted = tfp.mcmc.sample_chain(
      num_results=num_results,
      num_burnin_steps=num_burnin_steps,
      current_state=1.,  # Assumed initial state.
      kernel=adaptive_hmc,
      trace_fn=lambda _, pkr: pkr.inner_results.is_accepted,
      seed=seed)  # The JAX substrate needs an explicit PRNG key.

  sample_mean = jnp.mean(samples)
  sample_stddev = jnp.std(samples)
  is_accepted = jnp.mean(is_accepted.astype(jnp.float32))
  return sample_mean, sample_stddev, is_accepted

sample_mean, sample_stddev, is_accepted = run_chain(jax.random.PRNGKey(0))

print('mean:{:.4f}  stddev:{:.4f}  acceptance:{:.4f}'.format(
    float(sample_mean), float(sample_stddev), float(is_accepted)))

Estimate parameters of a more complicated posterior.

In this example, we'll use Monte-Carlo EM to find best-fit parameters. See [Convergence of a stochastic approximation version of the EM algorithm][2] for more details.

More precisely, we use HMC to form a chain conditioned on the parameter sigma and the training data { (x[i], y[i]) : i=1...n }. We then take one gradient step of maximum likelihood to improve the sigma estimate, and repeat the process until convergence. (This procedure is a Robbins--Monro algorithm.)

The generative assumptions are:

  W ~ MVN(loc=0, scale=sigma * eye(dims))
  for i=1...num_samples:
      X[i] ~ MVN(loc=0, scale=eye(dims))
    eps[i] ~ Normal(loc=0, scale=1)
      Y[i] = X[i].T * W + eps[i]

We now implement a stochastic approximation of Expectation Maximization (SAEM) using tensorflow_probability intrinsics; see [Delyon et al. (1999)][2].

import jax
import jax.numpy as jnp
import numpy as np
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions

def make_training_data(num_samples, dims, sigma):
  dt = np.asarray(sigma).dtype
  x = np.random.randn(dims, num_samples).astype(dt)
  w = sigma * np.random.randn(1, dims).astype(dt)
  noise = np.random.randn(num_samples).astype(dt)
  y = w.dot(x) + noise
  return y[0], x, w[0]

def make_weights_prior(dims, log_sigma):
  return tfd.MultivariateNormalDiag(
      loc=jnp.zeros([dims], dtype=log_sigma.dtype),
      scale_diag=jnp.exp(log_sigma) * jnp.ones([dims], dtype=log_sigma.dtype))

def make_response_likelihood(w, x):
  if w.ndim == 1:
    y_bar = jnp.matmul(w[jnp.newaxis], x)[0]
  else:
    y_bar = jnp.matmul(w, x)
  return tfd.Normal(loc=y_bar, scale=jnp.ones_like(y_bar))  # [n]

# Setup assumptions.
dtype = np.float32
num_samples = 500
dims = 10

weights_prior_true_scale = np.array(0.3, dtype)
y, x, _ = make_training_data(
    num_samples, dims, weights_prior_true_scale)

learning_rate = 0.01  # Plain SGD, applied by hand in `mcem_iter`.

def unnormalized_posterior_log_prob(w, log_sigma):
  prior = make_weights_prior(dims, log_sigma)
  likelihood = make_response_likelihood(w, x)
  return (
      prior.log_prob(w) +
      jnp.sum(likelihood.log_prob(y), axis=-1))  # [m]

@jax.jit
def mcem_iter(weights_chain_start, step_size, log_sigma, seed):
  def trace_fn(_, pkr):
    return (
        pkr.inner_results.log_accept_ratio,
        pkr.inner_results.accepted_results.step_size)

  num_results = 2
  weights, (log_accept_ratio, step_size) = tfp.mcmc.sample_chain(
      num_results=num_results,
      num_burnin_steps=0,
      current_state=weights_chain_start,
      kernel=tfp.mcmc.SimpleStepSizeAdaptation(
          tfp.mcmc.HamiltonianMonteCarlo(
              target_log_prob_fn=lambda w: unnormalized_posterior_log_prob(
                  w, log_sigma),
              num_leapfrog_steps=2,  # Assumed value.
              step_size=step_size),
          # Adapt for the entirety of the trajectory.
          num_adaptation_steps=num_results),
      trace_fn=trace_fn,
      seed=seed)  # The JAX substrate needs an explicit PRNG key.

  # We do an optimization step to propagate `log_sigma` after two HMC
  # steps to propagate `weights`. The sampled weights are held fixed while
  # differentiating with respect to `log_sigma`.
  def loss_fn(log_sigma):
    return -jnp.mean(unnormalized_posterior_log_prob(weights, log_sigma))

  loss, grad = jax.value_and_grad(loss_fn)(log_sigma)
  log_sigma = log_sigma - learning_rate * grad

  avg_acceptance_ratio = jnp.exp(
      tfp.math.reduce_logmeanexp(jnp.minimum(log_accept_ratio, 0.)))

  weights_prior_estimated_scale = jnp.exp(log_sigma)
  return (log_sigma, weights_prior_estimated_scale, weights[-1], loss,
          step_size[-1], avg_acceptance_ratio)

num_iters = int(40)

weights_prior_estimated_scale_ = np.zeros(num_iters, dtype)
weights_ = np.zeros([num_iters + 1, dims], dtype)
loss_ = np.zeros([num_iters], dtype)
weights_[0] = np.random.randn(dims).astype(dtype)
step_size_ = 0.03
log_sigma_ = jnp.zeros([], dtype)
seeds = jax.random.split(jax.random.PRNGKey(0), num_iters)

for iter_ in range(num_iters):
  [
      log_sigma_,
      weights_prior_estimated_scale_[iter_],
      weights_[iter_ + 1],
      loss_[iter_],
      step_size_,
      avg_acceptance_ratio_,
  ] = mcem_iter(weights_[iter_], step_size_, log_sigma_, seeds[iter_])
  print(('iter:{:>2}  loss:{: 9.3f}  scale:{:.3f}  '
         'step_size:{:.4f}  avg_acceptance_ratio:{:.4f}').format(
             iter_, loss_[iter_], weights_prior_estimated_scale_[iter_],
             float(step_size_), float(avg_acceptance_ratio_)))

# Should converge to ~0.22.
import matplotlib.pyplot as plt
plt.plot(weights_prior_estimated_scale_)
plt.ylabel('weights_prior_estimated_scale')
plt.xlabel('iteration')

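The `avg_acceptance_ratio` in the example above is the mean of `min(1, exp(log_accept_ratio))`, computed in log space. The following standard-library sketch shows that calculation; the helper name `average_acceptance_probability` and the input values are made up for illustration, and the helper assumes at least one finite ratio.

```python
import math


def average_acceptance_probability(log_accept_ratios):
  """Mean of min(1, exp(r)), computed via a stable log-mean-exp."""
  # Clipping at zero in log space is the same as capping probabilities at 1.
  clipped = [min(r, 0.0) for r in log_accept_ratios]
  # Subtract the largest term before exponentiating to avoid underflow.
  max_log = max(clipped)
  mean_exp = sum(math.exp(c - max_log) for c in clipped) / len(clipped)
  return math.exp(max_log + math.log(mean_exp))


log_accept_ratios = [0.5, -0.1, -2.0]  # Illustrative values only.
avg = average_acceptance_probability(log_accept_ratios)

# The same quantity computed directly, for comparison.
direct = sum(
    min(1.0, math.exp(r)) for r in log_accept_ratios) / len(log_accept_ratios)
```

Both routes agree up to floating-point error; working in log space matters mainly when the ratios are very negative.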

[1]: Radford Neal. MCMC Using Hamiltonian Dynamics. Handbook of Markov Chain Monte Carlo, 2011.

[2]: Bernard Delyon, Marc Lavielle, Eric Moulines. Convergence of a stochastic approximation version of the EM algorithm, Ann. Statist. 27 (1999), no. 1, 94--128.

target_log_prob_fn: Python callable which takes an argument like current_state (or *current_state if it's a list) and returns its (possibly unnormalized) log-density under the target distribution.
step_size: Tensor or Python list of Tensors representing the step size for the leapfrog integrator. Must broadcast with the shape of current_state. Larger step sizes lead to faster progress, but too-large step sizes make rejection exponentially more likely. When possible, it's often helpful to match per-variable step sizes to the standard deviations of the target distribution in each variable.
num_leapfrog_steps: Integer number of steps to run the leapfrog integrator for. Total progress per HMC step is roughly proportional to step_size * num_leapfrog_steps.
state_gradients_are_stopped: Python bool indicating that the proposed new state should be run through tf.stop_gradient. This is particularly useful when combining optimization over samples from the HMC chain. Default value: False (i.e., do not apply stop_gradient).
step_size_update_fn: Python callable that takes the current step_size (typically a tf.Variable) and kernel_results (typically collections.namedtuple) and returns the updated step_size (Tensors). Default value: None (i.e., do not update step_size automatically).
store_parameters_in_results: If True, then step_size and num_leapfrog_steps are written to and read from eponymous fields in the kernel results objects returned from one_step and bootstrap_results. This allows wrapper kernels to adjust those parameters on the fly. This is incompatible with step_size_update_fn, which must be set to None.
experimental_shard_axis_names: A structure of string names indicating how members of the state are sharded.
name: Python str name prefixed to Ops created by this function. Default value: None (i.e., 'hmc_kernel').
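As a rough illustration of the step_size guidance above, the sketch below scales a base step size by per-variable standard deviations and checks that the result broadcasts against the state. Everything here is hypothetical: the standard deviations stand in for estimates from a pilot run, and `base_step_size` is an arbitrary value, not a recommended default.

```python
import numpy as np

# Hypothetical per-variable standard deviations of the target distribution,
# for example estimated from a short pilot run.
estimated_stddevs = np.array([0.1, 1.0, 10.0])
base_step_size = 0.5  # Illustrative value only.

# One step size per variable, proportional to that variable's scale.
step_size = base_step_size * estimated_stddevs  # shape [3]

# The step size must broadcast with the state: 4 chains, 3 variables each.
current_state = np.zeros([4, 3])
broadcast_step_size = np.broadcast_to(step_size, current_state.shape)

# Progress per HMC step scales roughly with step_size * num_leapfrog_steps.
num_leapfrog_steps = 10
approx_progress = step_size * num_leapfrog_steps
```

With a single scalar step size, the variable with standard deviation 0.1 would force a small step for all three variables; per-variable scaling avoids that.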

experimental_shard_axis_names: The shard axis names for members of the state.
is_calibrated: Returns True if the Markov chain converges to the specified distribution.

TransitionKernels which are "uncalibrated" are often calibrated by composing them with the tfp.mcmc.MetropolisHastings TransitionKernel.

num_leapfrog_steps: Returns the num_leapfrog_steps parameter.

If the store_parameters_in_results argument to the initializer was set to True, this only returns the value of num_leapfrog_steps placed in the kernel results by the bootstrap_results method. The actual num_leapfrog_steps in that situation is governed by the previous_kernel_results argument to the one_step method.

parameters: Returns a dict of __init__ arguments and their values.

step_size: Returns the step_size parameter.

If the store_parameters_in_results argument to the initializer was set to True, this only returns the value of step_size placed in the kernel results by the bootstrap_results method. The actual step size in that situation is governed by the previous_kernel_results argument to the one_step method.





bootstrap_results

Creates initial previous_kernel_results using a supplied state.


copy

Non-destructively creates a deep copy of the kernel.

**override_parameter_kwargs: Python String/value dictionary of initialization arguments to override with new values.

new_kernel: TransitionKernel object of the same type as self, initialized with the union of self.parameters and override_parameter_kwargs, with any shared keys overridden by the value of override_parameter_kwargs, i.e., dict(self.parameters, **override_parameter_kwargs).
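The merge rule quoted above is ordinary Python dict behaviour and can be checked without the library. The parameter names and values below are stand-ins for a kernel's `parameters`, chosen only for illustration.

```python
# Stand-in for `self.parameters`; the values are illustrative.
parameters = {'step_size': 0.1, 'num_leapfrog_steps': 3, 'name': None}
override_parameter_kwargs = {'step_size': 0.05}

# Shared keys take the override's value; all other keys are carried over.
new_parameters = dict(parameters, **override_parameter_kwargs)
```

The original `parameters` dict is left untouched, which matches the description of the copy as non-destructive.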


experimental_with_shard_axes

Returns a copy of the kernel with the provided shard axis names.

shard_axis_names: A structure of strings indicating the shard axis names for each component of this kernel's state.

A copy of the current kernel with the shard axis information.


one_step

Runs one iteration of Hamiltonian Monte Carlo.

current_state: Tensor or Python list of Tensors representing the current state(s) of the Markov chain(s). The first r dimensions index independent chains, where r = tf.rank(target_log_prob_fn(*current_state)).
previous_kernel_results: collections.namedtuple containing Tensors representing values from previous calls to this function (or from the bootstrap_results function).
seed: PRNG seed; see tfp.random.sanitize_seed for details.

next_state: Tensor or Python list of Tensors representing the state(s) of the Markov chain(s) after taking exactly one step. Has the same type and shape as current_state.
kernel_results: collections.namedtuple of internal calculations used to advance the chain.

ValueError: if step_size is neither a single value nor a list with the same length as current_state.