Inherits From: `TransitionKernel`

This kernel optimizes the continuous trajectory length (aka integration time) parameter of Hamiltonian Monte Carlo. It does so by following the gradient of a criterion with respect to the trajectory length. The criterion is computed via `criterion_fn` with signature `(previous_state, proposed_state, accept_prob, trajectory_length) -> criterion`, where the returned value retains the batch dimensions implied by the first three inputs. See `chees_criterion` for an example.
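As a shape-level illustration of the `criterion_fn` contract, here is a hypothetical criterion computing expected squared jump distance in plain numpy. This is a simpler stand-in for `chees_criterion`, not the library's implementation; the function name and the 1-D-per-chain state layout are assumptions for the sketch.

```python
import numpy as np

def esjd_criterion(previous_state, proposed_state, accept_prob, trajectory_length):
    # Expected squared jump distance: the squared move, weighted by the
    # probability of accepting it. Chains are along axis 0; state dims follow.
    sq_jump = np.sum((proposed_state - previous_state) ** 2, axis=-1)
    # The returned criterion keeps the batch shape implied by the inputs.
    return accept_prob * sq_jump

previous = np.zeros((4, 2))   # 4 chains, 2-dimensional state
proposed = np.ones((4, 2))
accept = np.full(4, 0.5)
crit = esjd_criterion(previous, proposed, accept, trajectory_length=1.0)
print(crit.shape)  # (4,)
```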

To avoid resonances, this kernel jitters the integration time between 0 and the learned trajectory length by default.
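The effect of the jitter can be pictured with a short numpy sketch: each step's integration time is drawn from `[0, max_trajectory_length]`, here using a base-2 van der Corput (Halton) sequence, which covers the interval more evenly than uniform noise. The helper below is illustrative, not the kernel's internal code.

```python
import numpy as np

def van_der_corput(n, base=2):
    """First n terms of the base-`base` van der Corput low-discrepancy sequence."""
    seq = []
    for i in range(1, n + 1):
        x, denom = 0.0, 1.0
        while i > 0:
            denom *= base
            i, rem = divmod(i, base)
            x += rem / denom
        seq.append(x)
    return np.array(seq)

max_trajectory_length = 1.6
# One jittered integration time per step, evenly filling [0, max_trajectory_length].
jittered = van_der_corput(8) * max_trajectory_length
print(jittered)
```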

The initial trajectory length is extracted from the inner `HamiltonianMonteCarlo` kernel by multiplying the initial step size by the initial number of leapfrog steps. This (and other algorithmic details) implies that the step size must be a scalar.

In general, adaptation prevents the chain from reaching a stationary distribution, so obtaining consistent samples requires `num_adaptation_steps` be set to a value [somewhat smaller][1] than the number of burnin steps. However, it may sometimes be helpful to set `num_adaptation_steps` to a larger value during development in order to inspect the behavior of the chain during adaptation.

Examples

This implements something similar to ChEES HMC from [2].

```python
import tensorflow as tf
import tensorflow_probability as tfp
tfb = tfp.bijectors
tfd = tfp.distributions

target_log_prob_fn = tfd.JointDistributionSequential([
    tfd.Normal(0., 20.),
    tfd.HalfNormal(10.),
]).log_prob

num_burnin_steps = 1000
num_adaptation_steps = int(num_burnin_steps * 0.8)
num_results = 500
num_chains = 16
step_size = 0.1

kernel = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob_fn=target_log_prob_fn,
    step_size=step_size,
    num_leapfrog_steps=1,
)
kernel = tfp.experimental.mcmc.GradientBasedTrajectoryLengthAdaptation(
    kernel,
    num_adaptation_steps=num_adaptation_steps)
kernel = tfp.mcmc.TransformedTransitionKernel(
    kernel,
    [tfb.Identity(),
     tfb.Exp()])
kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(
    kernel,
    num_adaptation_steps=num_adaptation_steps)

def trace_fn(_, pkr):
  return (
      pkr.inner_results.inner_results.inner_results.accepted_results
      .step_size,
      pkr.inner_results.inner_results.max_trajectory_length,
      pkr.inner_results.inner_results.inner_results.log_accept_ratio,
  )

# The chain will be stepped for num_results + num_burnin_steps, adapting for
# the first num_adaptation_steps.
samples, [step_size, max_trajectory_length, log_accept_ratio] = (
    tfp.mcmc.sample_chain(
        num_results=num_results,
        num_burnin_steps=num_burnin_steps,
        current_state=[tf.zeros(num_chains),
                       tf.zeros(num_chains)],
        kernel=kernel,
        trace_fn=trace_fn))

# The average acceptance probability; ~0.75.
accept_prob = tf.math.exp(tfp.math.reduce_logmeanexp(
    tf.minimum(log_accept_ratio, 0.)))
```
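The final lines of the example estimate the average acceptance probability with a log-mean-exp: each `log_accept_ratio` is clipped at 0 (a probability cannot exceed 1), the values are averaged in probability space, and the result is mapped back. The same computation in plain numpy, with made-up ratios:

```python
import numpy as np

log_accept_ratio = np.log(np.array([0.5, 0.9, 1.8]))  # last ratio exceeds 1
clipped = np.minimum(log_accept_ratio, 0.)            # cap probabilities at 1
# log-mean-exp: exponentiate, average, take the log; then exp back out.
accept_prob = np.exp(np.log(np.mean(np.exp(clipped))))
print(accept_prob)  # mean of [0.5, 0.9, 1.0] = 0.8
```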

References

[1]: <http://andrewgelman.com/2017/12/15/burn-vs-warm-iterative-simulation-algorithms/#comment_627745>

[2]: Hoffman, M., Radul, A., & Sountsov, P. (2020). An Adaptive MCMC Scheme for Setting Trajectory Lengths in Hamiltonian Monte Carlo. <https://proceedings.mlr.press/v130/hoffman21a>

Args
`inner_kernel` `TransitionKernel`-like object.
`num_adaptation_steps` Scalar `int` `Tensor` number of initial steps during which to adjust the trajectory length. This may be greater than, less than, or equal to the number of burnin steps.
`use_halton_sequence_jitter` Python bool. Whether to use a Halton sequence for jittering the trajectory length. This makes the procedure more stable than sampling trajectory lengths from a uniform distribution.
`adaptation_rate` Floating point scalar `Tensor`. How rapidly to adapt the trajectory length.
`jitter_amount` Floating point scalar `Tensor`. How much to jitter the trajectory on the next step. The trajectory length is sampled from `[(1 - jitter_amount) * max_trajectory_length, max_trajectory_length]`.
`criterion_fn` Callable with `(previous_state, proposed_state, accept_prob) -> criterion`. Computes the criterion value.
`max_leapfrog_steps` Int32 scalar `Tensor`. Clips the number of leapfrog steps to this value.
`averaged_sq_grad_adaptation_rate` Floating point scalar `Tensor`. How rapidly to adapt the running average squared gradient. This is `1 - beta_2` from Adam.
`num_leapfrog_steps_getter_fn` A callable with the signature `(kernel_results) -> num_leapfrog_steps` where `kernel_results` are the results of the `inner_kernel`, and `num_leapfrog_steps` is a floating point `Tensor`.
`num_leapfrog_steps_setter_fn` A callable with the signature `(kernel_results, new_num_leapfrog_steps) -> new_kernel_results` where `kernel_results` are the results of the `inner_kernel`, `new_num_leapfrog_steps` is a scalar `Tensor`, and `new_kernel_results` are a copy of `kernel_results` with the number of leapfrog steps set.
`step_size_getter_fn` A callable with the signature `(kernel_results) -> step_size` where `kernel_results` are the results of the `inner_kernel`, and `step_size` is a floating point `Tensor`.
`proposed_velocity_getter_fn` A callable with the signature `(kernel_results) -> proposed_velocity` where `kernel_results` are the results of the `inner_kernel`, and `proposed_velocity` is a (possibly nested) floating point `Tensor`. Velocity is the derivative of state with respect to trajectory length.
`log_accept_prob_getter_fn` A callable with the signature `(kernel_results) -> log_accept_prob` where `kernel_results` are the results of the `inner_kernel`, and `log_accept_prob` is a floating point `Tensor`. `log_accept_prob` has shape `[C0, ..., Cb]` with `b > 0`.
`proposed_state_getter_fn` A callable with the signature `(kernel_results) -> proposed_state` where `kernel_results` are the results of the `inner_kernel`, and `proposed_state` is a (possibly nested) floating point `Tensor`.
`validate_args` Python `bool`. When `True` kernel parameters are checked for validity. When `False` invalid inputs may silently render incorrect outputs.
`experimental_shard_axis_names` A structure of string names indicating how members of the state are sharded.
`experimental_reduce_chain_axis_names` A string or list of string names indicating how batches of chains are sharded.
`name` Python `str` name prefixed to Ops created by this class. Default: 'simple_step_size_adaptation'.
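The getter/setter arguments let this kernel read and update fields of an arbitrary inner kernel's results structure. A minimal sketch of the expected call shapes, using a hypothetical namedtuple in place of real TFP kernel results:

```python
from collections import namedtuple

# Hypothetical stand-in for an inner kernel's results structure.
FakeResults = namedtuple('FakeResults', ['num_leapfrog_steps', 'step_size'])

def num_leapfrog_steps_getter_fn(kernel_results):
    # (kernel_results) -> num_leapfrog_steps
    return kernel_results.num_leapfrog_steps

def num_leapfrog_steps_setter_fn(kernel_results, new_num_leapfrog_steps):
    # (kernel_results, new_num_leapfrog_steps) -> new_kernel_results
    # Returns a copy with the number of leapfrog steps replaced; the
    # original results are left untouched.
    return kernel_results._replace(num_leapfrog_steps=new_num_leapfrog_steps)

results = FakeResults(num_leapfrog_steps=3, step_size=0.1)
updated = num_leapfrog_steps_setter_fn(results, 5)
print(num_leapfrog_steps_getter_fn(updated))  # 5
```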

Raises
`ValueError` If `inner_kernel` contains a `TransformedTransitionKernel` in its hierarchy. If you need to use the `TransformedTransitionKernel`, place it above this kernel in the hierarchy (see the example in the class docstring).

Attributes

`averaged_sq_grad_adaptation_rate`

`experimental_reduce_chain_axis_names`

`experimental_shard_axis_names` The shard axis names for members of the state.
`inner_kernel`

`is_calibrated` Returns `True` if Markov chain converges to specified distribution.

`TransitionKernel`s which are "uncalibrated" are often calibrated by composing them with the `tfp.mcmc.MetropolisHastings` `TransitionKernel`.

`max_leapfrog_steps`

`name`

`num_adaptation_steps`

`parameters` Return `dict` of `__init__` arguments and their values.
`use_halton_sequence_jitter`

`validate_args`

Methods

`bootstrap_results`

Returns an object with the same type as returned by `one_step(...)[1]`.

Args
`init_state` `Tensor` or Python `list` of `Tensor`s representing the initial state(s) of the Markov chain(s).

Returns
`kernel_results` A (possibly nested) `tuple`, `namedtuple` or `list` of `Tensor`s representing internal calculations made within this function.

`copy`

Non-destructively creates a deep copy of the kernel.

Args
`**override_parameter_kwargs` Python String/value `dictionary` of initialization arguments to override with new values.

Returns
`new_kernel` `TransitionKernel` object of same type as `self`, initialized with the union of `self.parameters` and `override_parameter_kwargs`, with any shared keys overridden by the value of `override_parameter_kwargs`, i.e., `dict(self.parameters, **override_parameter_kwargs)`.
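The override semantics are those of a plain dict merge. A small sketch, with illustrative parameter names:

```python
# Existing parameters, as `self.parameters` might hold them (names illustrative).
parameters = {'num_adaptation_steps': 100, 'adaptation_rate': 0.025}
override_parameter_kwargs = {'adaptation_rate': 0.05}

# Shared keys take the overriding values; all other keys are kept as-is.
merged = dict(parameters, **override_parameter_kwargs)
print(merged)  # {'num_adaptation_steps': 100, 'adaptation_rate': 0.05}
```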

`experimental_with_shard_axes`

Returns a copy of the kernel with the provided shard axis names.

Args
`shard_axis_names` a structure of strings indicating the shard axis names for each component of this kernel's state.

Returns
A copy of the current kernel with the shard axis information.

`one_step`

Takes one step of the TransitionKernel.

Must be overridden by subclasses.

Args
`current_state` `Tensor` or Python `list` of `Tensor`s representing the current state(s) of the Markov chain(s).
`previous_kernel_results` A (possibly nested) `tuple`, `namedtuple` or `list` of `Tensor`s representing internal calculations made within the previous call to this function (or as returned by `bootstrap_results`).
`seed` PRNG seed; see `tfp.random.sanitize_seed` for details.

Returns
`next_state` `Tensor` or Python `list` of `Tensor`s representing the next state(s) of the Markov chain(s).
`kernel_results` A (possibly nested) `tuple`, `namedtuple` or `list` of `Tensor`s representing internal calculations made within this function.

`step_size_getter_fn`
