Runs one step of the No-U-Turn Sampler.
Inherits From: TransitionKernel
tfp.experimental.substrates.jax.mcmc.NoUTurnSampler(
target_log_prob_fn, step_size, max_tree_depth=10, max_energy_diff=1000.0,
unrolled_leapfrog_steps=1, parallel_iterations=10, seed=None, name=None
)
The No-U-Turn Sampler (NUTS) is an adaptive variant of the Hamiltonian Monte
Carlo (HMC) method for MCMC. NUTS adapts the distance traveled in response to
the curvature of the target density. Conceptually, one proposal consists of
reversibly evolving a trajectory through the sample space, continuing until
that trajectory turns back on itself (hence the name, 'No-U-Turn'). This class
implements one random NUTS step from a given current_state.
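The sketch below gives a rough illustration of the stopping rule. It is plain NumPy, not TFP code. It runs leapfrog steps forward from a single point under a Euclidean kinetic energy and stops at the first U-turn, using the criterion from Hoffman and Gelman (2011): the trajectory stops once the momentum at either end points back along the span between the two endpoints. It deliberately omits the recursive tree doubling, the backward extension, and the sampling of a point from the trajectory that a real NUTS step performs. All helper names are hypothetical.

```python
import numpy as np


def leapfrog(theta, p, grad_log_prob, step_size):
    """One leapfrog step for H(theta, p) = -log_prob(theta) + p.p / 2."""
    p = p + 0.5 * step_size * grad_log_prob(theta)   # half step in momentum
    theta = theta + step_size * p                    # full step in position
    p = p + 0.5 * step_size * grad_log_prob(theta)   # half step in momentum
    return theta, p


def is_u_turn(theta_minus, theta_plus, p_minus, p_plus):
    """Hoffman & Gelman (2011) criterion: stop once either end's momentum
    points back along the span between the two ends."""
    span = theta_plus - theta_minus
    return np.dot(span, p_minus) < 0 or np.dot(span, p_plus) < 0


def steps_until_u_turn(theta0, p0, grad_log_prob, step_size, max_steps=1024):
    """Extend forward only; the backward end stays fixed at (theta0, p0)."""
    theta, p = theta0, p0
    for n in range(1, max_steps + 1):
        theta, p = leapfrog(theta, p, grad_log_prob, step_size)
        if is_u_turn(theta0, theta, p0, p):
            return n
    return max_steps


# Standard 2-D normal: grad log p(theta) = -theta. Starting at (1, 0) with
# momentum (0, 1) traces roughly a circle, so the U-turn comes near t = pi.
n = steps_until_u_turn(np.array([1.0, 0.0]), np.array([0.0, 1.0]),
                       lambda th: -th, step_size=0.1)
print(n)
```

With step_size=0.1 the trajectory needs roughly pi / 0.1 steps before it starts heading back toward its origin. This shows how the path length adapts to the geometry of the target rather than being fixed in advance.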
Mathematical details and derivations can be found in
[Hoffman, Gelman (2011)][1] and [Betancourt (2018)][2].
The one_step function can update multiple chains in parallel. It assumes
that a prefix of leftmost dimensions of current_state index independent
chain states (which are therefore updated independently). The output of
target_log_prob_fn(*current_state) should sum log-probabilities across all
event dimensions. Slices along the rightmost dimensions may have different
target distributions; for example, current_state[0][0, ...] could have a
different target distribution from current_state[0][1, ...]. These
semantics are governed by target_log_prob_fn(*current_state). (The number of
independent chains is tf.size(target_log_prob_fn(*current_state)).)
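Here is a minimal NumPy sketch of this batching convention. It assumes a hypothetical target in which chain i places independent N(mu[i], 1) densities on a two-part state. target_log_prob_fn takes *current_state, sums over the event dimensions of each part, and returns one log-probability per chain, so the size of its output equals the number of chains.

```python
import numpy as np

NUM_CHAINS = 4
# Hypothetical per-chain means: chain i targets N(mu[i], 1) in every
# coordinate, so different chains have different target distributions.
mu = np.arange(NUM_CHAINS, dtype=np.float64)


def normal_log_prob(x, loc):
    """Elementwise log-density of N(loc, 1)."""
    return -0.5 * (x - loc) ** 2 - 0.5 * np.log(2.0 * np.pi)


def target_log_prob_fn(x, y):
    """x: [num_chains, 3], y: [num_chains]; the leftmost dim indexes chains."""
    lp_x = normal_log_prob(x, mu[:, None]).sum(axis=-1)  # sum over event dim
    lp_y = normal_log_prob(y, mu)                         # scalar event per chain
    return lp_x + lp_y                                    # shape [num_chains]


current_state = [np.zeros((NUM_CHAINS, 3)), np.zeros(NUM_CHAINS)]
lp = target_log_prob_fn(*current_state)
print(lp.shape, lp.size)  # (4,) 4  -> four independent chains
```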
References
[1]: Matthew D. Hoffman, Andrew Gelman. The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo. 2011. https://arxiv.org/pdf/1111.4246.pdf
[2]: Michael Betancourt. A Conceptual Introduction to Hamiltonian Monte Carlo. arXiv preprint arXiv:1701.02434, 2018. https://arxiv.org/abs/1701.02434
Args  

target_log_prob_fn

Python callable which takes an argument like
current_state (or *current_state if it's a list) and returns its
(possibly unnormalized) log-density under the target distribution.

step_size

Tensor or Python list of Tensors representing the step
size for the leapfrog integrator. Must broadcast with the shape of
current_state. Larger step sizes lead to faster progress, but
too-large step sizes make rejection exponentially more likely. When
possible, it's often helpful to match per-variable step sizes to the
standard deviations of the target distribution in each variable.

max_tree_depth

Maximum depth of the tree implicitly built by NUTS. The
maximum number of leapfrog steps is bounded by 2**max_tree_depth, i.e.,
the number of leaves in a binary tree max_tree_depth levels deep. The
default setting of 10 takes up to 1024 leapfrog steps.

max_energy_diff

Scalar threshold of energy differences at each leapfrog step; divergent samples are defined as leapfrog steps that exceed this threshold. Defaults to 1000.

unrolled_leapfrog_steps

The number of leapfrog steps to unroll per tree expansion step. Applies a direct linear multiplier to the maximum trajectory length implied by max_tree_depth. Defaults to 1.

parallel_iterations

The number of iterations allowed to run in parallel.
It must be a positive integer. See tf.while_loop for more details.

seed

Python integer to seed the random number generator. Deprecated; pass
seed to tfp.mcmc.sample_chain instead.

name

Python str name prefixed to Ops created by this function.
Default value: None (i.e., 'nuts_kernel').
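The plain-NumPy sketch below illustrates how three of these arguments interact:

- Per-part step sizes broadcast against a list-valued state in a leapfrog step.
- A step is flagged as divergent once the total energy (negative log-density plus kinetic energy) rises more than max_energy_diff above its starting value.
- The leapfrog budget is unrolled_leapfrog_steps * 2**max_tree_depth, following the descriptions above.

Measuring the energy difference against the starting point is an assumption made for illustration; TFP's exact bookkeeping may differ. All function names and the two-part target are hypothetical.

```python
import numpy as np


def leapfrog_parts(state, momentum, grad_fn, step_size):
    """One leapfrog step over a list of state parts. step_size is a list with
    one entry per part; each entry broadcasts against its part."""
    grads = grad_fn(state)
    p_half = [p + 0.5 * eps * g for p, eps, g in zip(momentum, step_size, grads)]
    new_state = [x + eps * p for x, eps, p in zip(state, step_size, p_half)]
    grads = grad_fn(new_state)
    new_p = [p + 0.5 * eps * g for p, eps, g in zip(p_half, step_size, grads)]
    return new_state, new_p


def energy(state, momentum, log_prob_fn):
    """Total energy: negative log-density plus Euclidean kinetic energy."""
    kinetic = sum(0.5 * np.sum(p ** 2) for p in momentum)
    return -log_prob_fn(state) + kinetic


def max_leapfrog_steps(max_tree_depth=10, unrolled_leapfrog_steps=1):
    """Upper bound on leapfrog steps per NUTS step, as described above."""
    return unrolled_leapfrog_steps * 2 ** max_tree_depth


# Hypothetical two-part target: independent normals with scales 1 and 10.
SCALES = [1.0, 10.0]


def log_prob_fn(state):
    return sum(-0.5 * np.sum((x / s) ** 2) for x, s in zip(state, SCALES))


def grad_fn(state):
    return [-x / s ** 2 for x, s in zip(state, SCALES)]


def first_divergence(step_size, max_energy_diff=1000.0, n_steps=20):
    """Return the first step whose energy rise exceeds max_energy_diff, else None."""
    state = [np.ones(2), np.array(5.0)]
    momentum = [np.zeros(2), np.array(0.0)]
    h0 = energy(state, momentum, log_prob_fn)
    for n in range(1, n_steps + 1):
        state, momentum = leapfrog_parts(state, momentum, grad_fn, step_size)
        if energy(state, momentum, log_prob_fn) - h0 > max_energy_diff:
            return n
    return None


print(max_leapfrog_steps())                            # 1024
print(max_leapfrog_steps(unrolled_leapfrog_steps=2))   # 2048
print(first_divergence([0.1, 1.0]))   # matched to the scales: stays stable
print(first_divergence([3.0, 1.0]))   # too large for the unit-scale part
```

With step sizes matched to the per-part scales (0.1 and 1.0 against scales 1 and 10), the integrator stays stable for all 20 steps. A step size of 3.0 on the unit-scale part makes the leapfrog map unstable, and the energy error exceeds the threshold within a couple of steps.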

Attributes  

is_calibrated

Returns True if Markov chain converges to specified distribution.

max_energy_diff


max_tree_depth


name


parallel_iterations


parameters


read_instruction


step_size


target_log_prob_fn


unrolled_leapfrog_steps


write_instruction

Methods
bootstrap_results
bootstrap_results(
init_state
)
Creates initial previous_kernel_results using a supplied state.
copy
copy(
**override_parameter_kwargs
)
Non-destructively creates a deep copy of the kernel.
Args  

**override_parameter_kwargs

Python string/value dictionary of
initialization arguments to override with new values.

Returns  

new_kernel

TransitionKernel object of the same type as self,
initialized with the union of self.parameters and
override_parameter_kwargs, with any shared keys overridden by the
value of override_parameter_kwargs, i.e.,
dict(self.parameters, **override_parameter_kwargs).
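The override semantics of copy amount to a plain dictionary merge. The hypothetical ToyKernel below mirrors only that contract. It is not the TFP class and makes no attempt at a deep copy. Keys passed to copy replace the stored values, and every other key is carried over unchanged.

```python
class ToyKernel:
    """Hypothetical stand-in that mirrors only the documented copy() contract."""

    def __init__(self, step_size, max_tree_depth=10, unrolled_leapfrog_steps=1):
        self._parameters = dict(step_size=step_size,
                                max_tree_depth=max_tree_depth,
                                unrolled_leapfrog_steps=unrolled_leapfrog_steps)

    @property
    def parameters(self):
        return dict(self._parameters)  # return a copy so callers can't mutate it

    def copy(self, **override_parameter_kwargs):
        # Shared keys take the override value; all other keys are carried over.
        merged = dict(self.parameters, **override_parameter_kwargs)
        return type(self)(**merged)


kernel = ToyKernel(step_size=0.1)
smaller = kernel.copy(step_size=0.05, max_tree_depth=8)
print(smaller.parameters)
# {'step_size': 0.05, 'max_tree_depth': 8, 'unrolled_leapfrog_steps': 1}
print(kernel.parameters["step_size"])  # 0.1 (the original is left untouched)
```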

one_step
one_step(
current_state, previous_kernel_results, seed=None
)
Takes one step of the TransitionKernel.
Must be overridden by subclasses.
Args  

current_state

Tensor or Python list of Tensors representing the
current state(s) of the Markov chain(s).

previous_kernel_results

A (possibly nested) tuple, namedtuple or
list of Tensors representing internal calculations made within the
previous call to this function (or as returned by bootstrap_results).

seed

Optional, a seed for reproducible sampling. 
Returns  

next_state

Tensor or Python list of Tensors representing the
next state(s) of the Markov chain(s).

kernel_results

A (possibly nested) tuple, namedtuple or list of
Tensors representing internal calculations made within this function.
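To show how bootstrap_results and one_step fit together, here is a hypothetical NumPy kernel that follows the same calling convention. It swaps in a simple random-walk Metropolis update in place of NUTS, and its KernelResults fields are invented for illustration. The point is the loop structure:

- bootstrap_results computes the initial previous_kernel_results.
- Each one_step call consumes those results (here, the cached target log-probability) and returns fresh ones alongside the next state.

```python
import collections

import numpy as np

# Hypothetical results container; TFP's NUTS results hold different fields.
KernelResults = collections.namedtuple(
    "KernelResults", ["target_log_prob", "is_accepted"])


class ToyRandomWalkKernel:
    """Random-walk Metropolis (not NUTS) with the same calling convention."""

    def __init__(self, target_log_prob_fn, scale):
        self.target_log_prob_fn = target_log_prob_fn
        self.scale = scale

    def bootstrap_results(self, init_state):
        lp = self.target_log_prob_fn(init_state)
        return KernelResults(lp, np.zeros(np.shape(lp), dtype=bool))

    def one_step(self, current_state, previous_kernel_results, seed=None):
        rng = np.random.default_rng(seed)
        proposal = current_state + self.scale * rng.standard_normal(
            np.shape(current_state))
        proposal_lp = self.target_log_prob_fn(proposal)
        # Reuse the cached log-prob from the previous call instead of recomputing.
        log_accept_ratio = proposal_lp - previous_kernel_results.target_log_prob
        accept = np.log(rng.uniform(size=np.shape(proposal_lp))) < log_accept_ratio
        # One accept decision per chain, broadcast over the event dimension.
        next_state = np.where(accept[..., None], proposal, current_state)
        next_lp = np.where(accept, proposal_lp,
                           previous_kernel_results.target_log_prob)
        return next_state, KernelResults(next_lp, accept)


def target_log_prob_fn(x):
    """Standard normal over the last dimension; one value per chain."""
    return -0.5 * np.sum(x ** 2, axis=-1)


kernel = ToyRandomWalkKernel(target_log_prob_fn, scale=0.5)
state = np.zeros((4, 2))                    # 4 independent chains in 2-D
results = kernel.bootstrap_results(state)
samples = []
for step in range(2000):
    state, results = kernel.one_step(state, results, seed=step)
    samples.append(state)
samples = np.stack(samples)                 # [num_steps, num_chains, dim]
print(samples.shape, samples.mean(), samples.var())
```

Here the leading dimension of 4 runs independent chains, matching the batching convention described earlier. Passing a distinct seed to each one_step call keeps the run reproducible.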
