


Use gradient ascent to adapt inner kernel's trajectory length.

Inherits From: TransitionKernel

This kernel optimizes the continuous trajectory length (aka integration time) parameter of Hamiltonian Monte Carlo. It does so by following the gradient of a criterion with respect to the trajectory length. The criterion is computed via criterion_fn with signature (previous_state, proposed_state, accept_prob, trajectory_length) -> criterion, where the returned criterion retains the batch dimensions implied by the first three inputs. See chees_criterion for an example.
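The criterion_fn contract can be illustrated with a toy NumPy sketch. The function below is hypothetical (it is not chees_criterion); only its signature and batch-shape behavior mirror the contract described above:

```python
import numpy as np

# Hypothetical toy criterion: reward the squared jump distance, weighted
# by the acceptance probability. All inputs carry a leading batch (chain)
# dimension, and the returned criterion retains it.
def toy_criterion_fn(previous_state, proposed_state, accept_prob,
                     trajectory_length):
  sq_jump = np.sum((proposed_state - previous_state) ** 2, axis=-1)
  return accept_prob * sq_jump

previous_state = np.zeros((4, 2))   # 4 chains, 2-dimensional state
proposed_state = np.ones((4, 2))
accept_prob = np.full((4,), 0.7)
criterion = toy_criterion_fn(previous_state, proposed_state, accept_prob,
                             trajectory_length=1.0)
assert criterion.shape == (4,)      # batch dimensions retained
```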

To avoid resonances, this kernel jitters the integration time between 0 and the learned trajectory length by default.

The initial trajectory length is extracted from the inner HamiltonianMonteCarlo kernel by multiplying the initial step size and initial number of leapfrog steps. This (and other algorithmic details) imply that the step size must be a scalar.
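As a concrete sketch of that extraction and of the default jittering (plain Python; the numeric values are illustrative only, the kernel reads the real ones from the inner HamiltonianMonteCarlo kernel):

```python
import random

step_size = 0.1            # must be a scalar, per the note above
num_leapfrog_steps = 8
# Initial trajectory length = step size * number of leapfrog steps.
initial_trajectory_length = step_size * num_leapfrog_steps

# By default, the integration time actually used at each step is jittered
# between 0 and the (learned) trajectory length:
jittered = random.uniform(0., initial_trajectory_length)
assert 0. <= jittered <= initial_trajectory_length
```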

In general, adaptation prevents the chain from reaching a stationary distribution, so obtaining consistent samples requires num_adaptation_steps be set to a value [somewhat smaller][1] than the number of burnin steps. However, it may sometimes be helpful to set num_adaptation_steps to a larger value during development in order to inspect the behavior of the chain during adaptation.


This implements something similar to ChEES HMC from [2].

import tensorflow as tf
import tensorflow_probability as tfp
tfb = tfp.bijectors
tfd = tfp.distributions

target_log_prob_fn = tfd.JointDistributionSequential([
    tfd.Normal(0., 20.),
    # The remaining components of the original joint model were lost in
    # extraction; a single-component joint keeps the example runnable.
]).log_prob

num_burnin_steps = 1000
num_adaptation_steps = int(num_burnin_steps * 0.8)
num_results = 500
num_chains = 16
step_size = 0.1

kernel = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob_fn=target_log_prob_fn,
    step_size=step_size,
    num_leapfrog_steps=1)
kernel = tfp.experimental.mcmc.GradientBasedTrajectoryLengthAdaptation(
    kernel, num_adaptation_steps=num_adaptation_steps)
kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(
    kernel, num_adaptation_steps=num_adaptation_steps)
kernel = tfp.mcmc.TransformedTransitionKernel(
    kernel, [tfb.Identity()])  # placeholder; the original bijectors were elided

def trace_fn(_, pkr):
  # The indexing depths below match the kernel stack constructed above.
  return (
      pkr.inner_results.new_step_size,
      pkr.inner_results.inner_results.max_trajectory_length,
      pkr.inner_results.inner_results.inner_results.log_accept_ratio)

# The chain will be stepped for num_results + num_burnin_steps, adapting for
# the first num_adaptation_steps.
samples, [step_size, max_trajectory_length, log_accept_ratio] = (
    tfp.mcmc.sample_chain(
        num_results=num_results,
        num_burnin_steps=num_burnin_steps,
        current_state=[tf.zeros([num_chains])],
        kernel=kernel,
        trace_fn=trace_fn))

# ~0.75
accept_prob = tf.math.exp(tfp.math.reduce_logmeanexp(
    tf.minimum(log_accept_ratio, 0.)))
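The last two lines compute the average acceptance probability across the chain history. In plain NumPy terms (assuming reduce_logmeanexp computes the log of the mean of exponentials, which is its documented behavior):

```python
import numpy as np

# Example log acceptance ratios (illustrative values).
log_accept_ratio = np.array([-0.1, -0.5, 0.3, 0.0])
# Acceptance probabilities cap at 1, so clip the log ratio at 0 first.
clipped = np.minimum(log_accept_ratio, 0.)
# exp(logmeanexp(x)) is the mean of exp(x), computed in a numerically
# stable way by the TFP helper.
accept_prob = np.exp(np.log(np.mean(np.exp(clipped))))
```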


[1]: <burn-vs-warm-iterative-simulation-algorithms/#comment_627745>

[2]: Hoffman, M., Radul, A., & Sountsov, P. (2020). An Adaptive MCMC Scheme for Setting Trajectory Lengths in Hamiltonian Monte Carlo.

inner_kernel TransitionKernel-like object.
num_adaptation_steps Scalar int Tensor number of initial steps during which to adjust the trajectory length. This may be greater than, less than, or equal to the number of burnin steps.
use_halton_sequence_jitter Python bool. Whether to use a Halton sequence for jittering the trajectory length. This makes the procedure more stable than sampling trajectory lengths from a uniform distribution.
adaptation_rate Floating point scalar Tensor. How rapidly to adapt the trajectory length.
jitter_amount Floating point scalar Tensor. How much to jitter the trajectory on the next step. The trajectory length is sampled from [(1 - jitter_amount) * max_trajectory_length, max_trajectory_length].
criterion_fn Callable with signature (previous_state, proposed_state, accept_prob) -> criterion. Computes the criterion value.
max_leapfrog_steps Int32 scalar Tensor. Clips the number of leapfrog steps to this value.
averaged_sq_grad_adaptation_rate Floating point scalar Tensor. How rapidly to adapt the running average squared gradient. This is 1 - beta_2 from Adam.
num_leapfrog_steps_getter_fn A callable with the signature (kernel_results) -> num_leapfrog_steps where kernel_results are the results of the inner_kernel, and num_leapfrog_steps is a floating point Tensor.
num_leapfrog_steps_setter_fn A callable with the signature (kernel_results, new_num_leapfrog_steps) -> new_kernel_results where kernel_results are the results of the inner_kernel, new_num_leapfrog_steps is a scalar Tensor, and new_kernel_results are a copy of kernel_results with the number of leapfrog steps set.
step_size_getter_fn A callable with the signature (kernel_results) -> step_size where kernel_results are the results of the inner_kernel, and step_size is a floating point Tensor.
proposed_velocity_getter_fn A callable with the signature (kernel_results) -> proposed_velocity where kernel_results are the results of the inner_kernel, and proposed_velocity is a (possibly nested) floating point Tensor. Velocity is the derivative of state with respect to trajectory length.
log_accept_prob_getter_fn A callable with the signature (kernel_results) -> log_accept_prob where kernel_results are the results of the inner_kernel, and log_accept_prob is a floating point Tensor. log_accept_prob has shape [C0, ..., Cb] with b > 0.
proposed_state_getter_fn A callable with the signature (kernel_results) -> proposed_state where kernel_results are the results of the inner_kernel, and proposed_state is a (possibly nested) floating point Tensor.
validate_args Python bool. When True kernel parameters are checked for validity. When False invalid inputs may silently render incorrect outputs.
experimental_shard_axis_names A structure of string names indicating how members of the state are sharded.
experimental_reduce_chain_axis_names A string or list of string names indicating how batches of chains are sharded.
name Python str name prefixed to Ops created by this class. Default: 'simple_step_size_adaptation'.
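The interval described for jitter_amount can be sketched numerically (the values below are illustrative only, not library defaults):

```python
# Illustrative values only.
max_trajectory_length = 2.0
jitter_amount = 0.75
# Per the jitter_amount description, the next trajectory length is drawn
# from [(1 - jitter_amount) * max_trajectory_length, max_trajectory_length]:
lo = (1. - jitter_amount) * max_trajectory_length
hi = max_trajectory_length
assert (lo, hi) == (0.5, 2.0)
```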

ValueError If inner_kernel contains a TransformedTransitionKernel in its hierarchy. If you need to use the TransformedTransitionKernel, place it above this kernel in the hierarchy (see the example in the class docstring).



experimental_shard_axis_names The shard axis names for members of the state.

is_calibrated Returns True if Markov chain converges to specified distribution.

TransitionKernels which are "uncalibrated" are often calibrated by composing them with the tfp.mcmc.MetropolisHastings TransitionKernel.




parameters Return dict of __init__ arguments and their values.





Returns an object with the same type as returned by one_step(...)[1].

init_state Tensor or Python list of Tensors representing the initial state(s) of the Markov chain(s).

kernel_results A (possibly nested) tuple, namedtuple or list of Tensors representing internal calculations made within this function.



Non-destructively creates a deep copy of the kernel.

**override_parameter_kwargs Python String/value dictionary of initialization arguments to override with new values.

new_kernel TransitionKernel object of same type as self, initialized with the union of self.parameters and override_parameter_kwargs, with any shared keys overridden by the value of override_parameter_kwargs, i.e., dict(self.parameters, **override_parameter_kwargs).
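The documented merge semantics are plain dict merging; a minimal illustration (the parameter names below are examples only):

```python
# Example of dict(self.parameters, **override_parameter_kwargs):
parameters = {'num_adaptation_steps': 100, 'adaptation_rate': 0.025}
override_parameter_kwargs = {'adaptation_rate': 0.05}
# Shared keys take the overriding value, as in copy():
new_parameters = dict(parameters, **override_parameter_kwargs)
assert new_parameters == {'num_adaptation_steps': 100,
                          'adaptation_rate': 0.05}
```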



Returns a copy of the kernel with the provided shard axis names.

shard_axis_names a structure of strings indicating the shard axis names for each component of this kernel's state.

A copy of the current kernel with the shard axis information.

