tfp.experimental.mcmc.windowed_adaptive_nuts

Adapt and sample from a joint distribution using NUTS, conditioned on pins.

Step size is tuned using dual-averaging adaptation, and the kernel is preconditioned using a diagonal mass matrix, which is estimated over a series of expanding windows.

n_draws: int. Number of draws after adaptation.
joint_dist: tfd.JointDistribution. A joint distribution to sample from.
n_chains: int. Number of independent MCMC chains to run.
num_adaptation_steps: int. Number of draws used to adapt the step size and mass matrix.
current_state: Optional. Structure of tensors at which to initialize sampling. Should have the same shape and structure as model.experimental_pin(**pins).sample(n_chains).
init_step_size: Optional. Where to initialize the step size for the leapfrog integrator. The structure should broadcast with current_state. For example, if the initial state is

{'a': tf.zeros(n_chains),
'b': tf.zeros([n_chains, n_features])}

then any of 1., {'a': 1., 'b': 1.}, or {'a': tf.ones(n_chains), 'b': tf.ones([n_chains, n_features])} will work. Defaults to the dimension of the log density to the 0.25 power.

dual_averaging_kwargs: Optional dict. Keyword arguments to pass to tfp.mcmc.DualAveragingStepSizeAdaptation. By default, a target_accept_prob of 0.85 is set; the class defaults are used otherwise.
max_tree_depth: Maximum depth of the tree implicitly built by NUTS. The maximum number of leapfrog steps is bounded by 2**max_tree_depth, i.e. the number of nodes in a binary tree max_tree_depth nodes deep. The default setting of 10 allows up to 1024 leapfrog steps.
max_energy_diff: Scalar threshold of energy differences at each leapfrog step; divergent samples are defined as leapfrog steps that exceed this threshold. Defaults to 1000.
unrolled_leapfrog_steps: The number of leapfrog steps to unroll per tree expansion step. Applies a direct linear multiplier to the maximum trajectory length implied by max_tree_depth. Defaults to 1.
parallel_iterations: The number of iterations allowed to run in parallel. Must be a positive integer. See tf.while_loop for more details.
trace_fn: Optional callable. The trace function should accept the arguments (state, bijector, is_adapting, phmc_kernel_results), where state is an unconstrained, flattened float tensor, bijector is the tfb.Bijector used for unconstraining and flattening, is_adapting is a boolean marking whether the draw is from an adaptation step, and phmc_kernel_results is the UncalibratedPreconditionedHamiltonianMonteCarloKernelResults from the PreconditionedHamiltonianMonteCarlo kernel. Note that bijector.inverse(state) will provide access to the current draw in the untransformed space, using the structure of the provided joint_dist.
return_final_kernel_results: If True, the final kernel results are returned alongside the chain state and the trace specified by the trace_fn.
discard_tuning: bool. Whether to discard the adaptation (tuning) draws and traces from the returned values.
seed: Optional. A seed for reproducible sampling.
**pins: Used to condition the provided joint distribution; passed directly to joint_dist.experimental_pin(**pins).

A single structure of draws is returned if trace_fn is None and return_final_kernel_results is False. If there is a trace function, the return value is a tuple, with the trace second. If return_final_kernel_results is True, the return value is a tuple of length 3, with the final kernel results last. If discard_tuning is True, the tensors in draws and trace will have a leading dimension of length n_draws; otherwise, they will have length n_draws + num_adaptation_steps.