tfp.experimental.mcmc.GradientBasedTrajectoryLengthAdaptation

Use gradient ascent to adapt inner kernel's trajectory length.

Inherits From: TransitionKernel

This kernel optimizes the continuous trajectory length (aka integration time) parameter of Hamiltonian Monte Carlo. It does so by following the gradient of a criterion with respect to the trajectory length. The criterion is computed via criterion_fn with signature (previous_state, proposed_state, accept_prob, trajectory_length) -> criterion, where the returned criterion retains the batch dimensions implied by the first three inputs. See chees_criterion for an example.
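
For illustration, here is a minimal sketch of a custom criterion (the name squared_jump_distance_criterion is hypothetical, not a library function). It assumes a single, unnested state Tensor whose last dimension is the event dimension, and it absorbs any extra positional arguments that may be passed along with the first three:

import tensorflow as tf
import tensorflow_probability as tfp

def squared_jump_distance_criterion(previous_state, proposed_state,
                                    accept_prob, *unused_extra_args):
  # Expected squared jump distance, reduced over the event dimension so the
  # result keeps the per-chain batch shape of accept_prob.
  sq_jump = tf.reduce_sum(
      tf.square(proposed_state - previous_state), axis=-1)
  return accept_prob * sq_jump

kernel = tfp.experimental.mcmc.GradientBasedTrajectoryLengthAdaptation(
    hmc_kernel,  # assumed: an existing HamiltonianMonteCarlo kernel
    num_adaptation_steps=100,
    criterion_fn=squared_jump_distance_criterion)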

To avoid resonances, this kernel jitters the integration time between 0 and the learned trajectory length by default.

The initial trajectory length is extracted from the inner HamiltonianMonteCarlo kernel by multiplying the initial step size and initial number of leapfrog steps. This (and other algorithmic details) imply that the step size must be a scalar.
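
For example, with the settings used in the example below (step_size = 0.1, num_leapfrog_steps = 1), the initial trajectory length is 0.1 * 1 = 0.1.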

In general, adaptation prevents the chain from reaching a stationary distribution, so obtaining consistent samples requires num_adaptation_steps be set to a value [somewhat smaller][1] than the number of burnin steps. However, it may sometimes be helpful to set num_adaptation_steps to a larger value during development in order to inspect the behavior of the chain during adaptation.

Examples

This implements something similar to ChEES HMC from [2].

import tensorflow as tf
import tensorflow_probability as tfp
tfb = tfp.bijectors
tfd = tfp.distributions

target_log_prob_fn = tfd.JointDistributionSequential([
    tfd.Normal(0., 20.),
    tfd.HalfNormal(10.),
]).log_prob

num_burnin_steps = 1000
num_adaptation_steps = int(num_burnin_steps * 0.8)
num_results = 500
num_chains = 16
step_size = 0.1

kernel = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob_fn=target_log_prob_fn,
    step_size=step_size,
    num_leapfrog_steps=1,
)
kernel = tfp.experimental.mcmc.GradientBasedTrajectoryLengthAdaptation(
    kernel,
    num_adaptation_steps=num_adaptation_steps)
kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(
    kernel, num_adaptation_steps=num_adaptation_steps)
kernel = tfp.mcmc.TransformedTransitionKernel(
    kernel,
    [tfb.Identity(),
     tfb.Exp()])

def trace_fn(_, pkr):
  return (
      pkr.inner_results.inner_results.inner_results.accepted_results
      .step_size,
      pkr.inner_results.inner_results.max_trajectory_length,
      pkr.inner_results.inner_results.inner_results.log_accept_ratio,
  )

# The chain will be stepped for num_results + num_burnin_steps, adapting for
# the first num_adaptation_steps.
samples, [step_size, max_trajectory_length, log_accept_ratio] = (
    tfp.mcmc.sample_chain(
        num_results=num_results,
        num_burnin_steps=num_burnin_steps,
        current_state=[tf.zeros(num_chains),
                       tf.zeros(num_chains)],
        kernel=kernel,
        trace_fn=trace_fn,))

# ~0.75
accept_prob = tf.math.exp(tfp.math.reduce_logmeanexp(
    tf.minimum(log_accept_ratio, 0.)))

References

[1]: <http://andrewgelman.com/2017/12/15/burn-vs-warm-iterative-simulation-algorithms/#comment_627745>

[2]: Hoffman, M., Radul, A., & Sountsov, P. (2020). An Adaptive MCMC Scheme for Setting Trajectory Lengths in Hamiltonian Monte Carlo. <https://proceedings.mlr.press/v130/hoffman21a>

Args

inner_kernel TransitionKernel-like object.
num_adaptation_steps Scalar int Tensor number of initial steps during which to adjust the trajectory length. This may be greater than, less than, or equal to the number of burnin steps.
use_halton_sequence_jitter Python bool. Whether to use a Halton sequence for jittering the trajectory length. This makes the procedure more stable than sampling trajectory lengths from a uniform distribution.
adaptation_rate Floating point scalar Tensor. How rapidly to adapt the trajectory length.
jitter_amount Floating point scalar Tensor. How much to jitter the trajectory on the next step. The trajectory length is sampled from [(1 - jitter_amount) * max_trajectory_length, max_trajectory_length].
criterion_fn Callable with signature (previous_state, proposed_state, accept_prob) -> criterion. Computes the criterion value.
max_leapfrog_steps Int32 scalar Tensor. Clips the number of leapfrog steps to this value.
averaged_sq_grad_adaptation_rate Floating point scalar Tensor. How rapidly to adapt the running average squared gradient. This is 1 - beta_2 from Adam.
num_leapfrog_steps_getter_fn A callable with the signature (kernel_results) -> num_leapfrog_steps, where kernel_results are the results of the inner_kernel and num_leapfrog_steps is a floating point Tensor.
num_leapfrog_steps_setter_fn A callable with the signature (kernel_results, new_num_leapfrog_steps) -> new_kernel_results, where kernel_results are the results of the inner_kernel, new_num_leapfrog_steps is a scalar Tensor, and new_kernel_results are a copy of kernel_results with the number of leapfrog steps set.
step_size_getter_fn A callable with the signature (kernel_results) -> step_size, where kernel_results are the results of the inner_kernel and step_size is a floating point Tensor.
proposed_velocity_getter_fn A callable with the signature (kernel_results) -> proposed_velocity, where kernel_results are the results of the inner_kernel and proposed_velocity is a (possibly nested) floating point Tensor. Velocity is the derivative of the state with respect to trajectory length.
log_accept_prob_getter_fn A callable with the signature (kernel_results) -> log_accept_prob, where kernel_results are the results of the inner_kernel and log_accept_prob is a floating point Tensor. log_accept_prob has shape [C0, ..., Cb] with b > 0.
proposed_state_getter_fn A callable with the signature (kernel_results) -> proposed_state, where kernel_results are the results of the inner_kernel and proposed_state is a (possibly nested) floating point Tensor.
validate_args Python bool. When True kernel parameters are checked for validity. When False invalid inputs may silently render incorrect outputs.
experimental_shard_axis_names A structure of string names indicating how members of the state are sharded.
experimental_reduce_chain_axis_names A string or list of string names indicating how batches of chains are sharded.
name Python str name prefixed to Ops created by this class. Default: 'simple_step_size_adaptation'.
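
As a hedged sketch, the optional arguments above can be passed explicitly as keyword arguments. The values below are illustrative rather than recommended defaults, and hmc is assumed to be a HamiltonianMonteCarlo kernel like the one in the example:

kernel = tfp.experimental.mcmc.GradientBasedTrajectoryLengthAdaptation(
    hmc,
    num_adaptation_steps=800,
    use_halton_sequence_jitter=True,
    adaptation_rate=0.025,
    jitter_amount=1.,
    max_leapfrog_steps=1000,
    validate_args=True)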

Raises

ValueError If inner_kernel contains a TransformedTransitionKernel in its hierarchy. If you need to use the TransformedTransitionKernel, place it above this kernel in the hierarchy (see the example in the class docstring).

Attributes

averaged_sq_grad_adaptation_rate

experimental_reduce_chain_axis_names

experimental_shard_axis_names The shard axis names for members of the state.
inner_kernel

is_calibrated Returns True if Markov chain converges to specified distribution.

TransitionKernels which are "uncalibrated" are often calibrated by composing them with the tfp.mcmc.MetropolisHastings TransitionKernel.

max_leapfrog_steps

name

num_adaptation_steps

parameters Return dict of __init__ arguments and their values.
use_halton_sequence_jitter

validate_args

Methods

bootstrap_results

View source

Returns an object with the same type as returned by one_step(...)[1].

Args
init_state Tensor or Python list of Tensors representing the initial state(s) of the Markov chain(s).

Returns
kernel_results A (possibly nested) tuple, namedtuple or list of Tensors representing internal calculations made within this function.
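
A minimal usage sketch, assuming kernel and num_chains are defined as in the class example above (the starting values are arbitrary; the second component is positive because of the HalfNormal):

init_state = [tf.zeros(num_chains), tf.ones(num_chains)]
kernel_results = kernel.bootstrap_results(init_state)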

copy

View source

Non-destructively creates a deep copy of the kernel.

Args
**override_parameter_kwargs Python String/value dictionary of initialization arguments to override with new values.

Returns
new_kernel TransitionKernel object of same type as self, initialized with the union of self.parameters and override_parameter_kwargs, with any shared keys overridden by the value of override_parameter_kwargs, i.e., dict(self.parameters, **override_parameter_kwargs).
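
A hedged sketch, assuming traj_kernel is an existing GradientBasedTrajectoryLengthAdaptation instance (the variable name is hypothetical):

# Same kernel, but adapting for more steps; all other parameters are reused.
longer_kernel = traj_kernel.copy(num_adaptation_steps=2000)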

criterion_fn

View source

experimental_with_shard_axes

View source

Returns a copy of the kernel with the provided shard axis names.

Args
shard_axis_names a structure of strings indicating the shard axis names for each component of this kernel's state.

Returns
A copy of the current kernel with the shard axis information.

log_accept_prob_getter_fn

View source

num_leapfrog_steps_getter_fn

View source

num_leapfrog_steps_setter_fn

View source

one_step

View source

Takes one step of the TransitionKernel.

Must be overridden by subclasses.

Args
current_state Tensor or Python list of Tensors representing the current state(s) of the Markov chain(s).
previous_kernel_results A (possibly nested) tuple, namedtuple or list of Tensors representing internal calculations made within the previous call to this function (or as returned by bootstrap_results).
seed PRNG seed; see tfp.random.sanitize_seed for details.

Returns
next_state Tensor or Python list of Tensors representing the next state(s) of the Markov chain(s).
kernel_results A (possibly nested) tuple, namedtuple or list of Tensors representing internal calculations made within this function.
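
A minimal sketch of driving the kernel manually (tfp.mcmc.sample_chain runs this loop internally), assuming kernel and num_chains are defined as in the class example above; the starting values are arbitrary:

state = [tf.zeros(num_chains), tf.ones(num_chains)]
results = kernel.bootstrap_results(state)
seed = tfp.random.sanitize_seed(0)
for _ in range(10):
  # Split off a fresh seed for each step so results are reproducible.
  seed, step_seed = tfp.random.split_seed(seed)
  state, results = kernel.one_step(state, results, seed=step_seed)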

proposed_state_getter_fn

View source

proposed_velocity_getter_fn

View source

step_size_getter_fn

View source