

Adapt and sample from a joint distribution using NUTS, conditioned on pins.

<pre class="devsite-click-to-copy prettyprint lang-py tfo-signature-link"> <code>tfp.experimental.mcmc.windowed_adaptive_nuts(
    n_draws,
    joint_dist,
    *,
    n_chains=64,
    num_adaptation_steps=500,
    current_state=None,
    init_step_size=None,
    dual_averaging_kwargs=None,
    max_tree_depth=10,
    max_energy_diff=500.0,
    unrolled_leapfrog_steps=1,
    parallel_iterations=10,
    trace_fn=_default_nuts_trace_fn,
    return_final_kernel_results=False,
    discard_tuning=True,
    chain_axis_names=None,
    seed=None,
    **pins
)
</code></pre>

Step size is tuned using dual-averaging adaptation, and the kernel is conditioned using a diagonal mass matrix, which is estimated using expanding windows.

<!-- Tabular view -->
<table class="responsive fixed orange">
<colgroup><col width="214px"><col></colgroup>
<tr><th colspan="2"><h2 class="add-link">Args</h2></th></tr>
<tr> <td> `n_draws` </td> <td> int. Number of draws after adaptation. </td> </tr>
<tr> <td> `joint_dist` </td> <td> `tfd.JointDistribution`. A joint distribution to sample from. </td> </tr>
<tr> <td> `n_chains` </td> <td> int or list of ints. Number of independent chains to run MCMC with. </td> </tr>
<tr> <td> `num_adaptation_steps` </td> <td> int. Number of draws used to adapt the step size and mass matrix. </td> </tr>
<tr> <td> `current_state` </td> <td> Optional. Structure of tensors at which to initialize sampling. Should have the same shape and structure as `model.experimental_pin(**pins).sample(n_chains)`. </td> </tr>
<tr> <td> `init_step_size` </td> <td> Optional. Where to initialize the step size for the leapfrog integrator. The structure should broadcast with `current_state`. For example, if the initial state is

`{'a': tf.zeros(n_chains), 'b': tf.zeros([n_chains, n_features])}`, then any of `1.`, `{'a': 1., 'b': 1.}`, or `{'a': tf.ones(n_chains), 'b': tf.ones([n_chains, n_features])}` will work. Defaults to the dimension of the log density to the 0.25 power. </td> </tr>
<tr> <td> `dual_averaging_kwargs` </td> <td> Optional dict. Keyword arguments to pass to <a href="../../../tfp/mcmc/DualAveragingStepSizeAdaptation"><code>tfp.mcmc.DualAveragingStepSizeAdaptation</code></a>. By default, a `target_accept_prob` of 0.85 is set, acceptance probabilities across chains are reduced using a harmonic mean, and the class defaults are used otherwise. </td> </tr>
<tr> <td> `max_tree_depth` </td> <td> Maximum depth of the tree implicitly built by NUTS. The maximum number of leapfrog steps is bounded by `2**max_tree_depth`, i.e. the number of nodes in a binary tree `max_tree_depth` nodes deep. The default setting of 10 takes up to 1024 leapfrog steps. </td> </tr>
<tr> <td> `max_energy_diff` </td> <td> Scalar threshold of energy differences at each leapfrog; divergent samples are defined as leapfrog steps that exceed this threshold. Defaults to 500. </td> </tr>
<tr> <td> `unrolled_leapfrog_steps` </td> <td> The number of leapfrog steps to unroll per tree expansion step. Applies a direct linear multiplier to the maximum trajectory length implied by `max_tree_depth`. Defaults to 1. </td> </tr>
<tr> <td> `parallel_iterations` </td> <td> The number of iterations allowed to run in parallel. It must be a positive integer. See <code>tf.while_loop</code> for more details.
</td> </tr>
<tr> <td> `trace_fn` </td> <td> Optional callable. The trace function should accept the arguments `(state, bijector, is_adapting, phmc_kernel_results)`, where `state` is an unconstrained, flattened float tensor, `bijector` is the `tfb.Bijector` that is used for unconstraining and flattening, `is_adapting` is a boolean marking whether the draw is from an adaptation step, and `phmc_kernel_results` is the `UncalibratedPreconditionedHamiltonianMonteCarloKernelResults` from the `PreconditionedHamiltonianMonteCarlo` kernel. Note that `bijector.inverse(state)` will provide access to the current draw in the untransformed space, using the structure of the provided `joint_dist`. </td> </tr>
<tr> <td> `return_final_kernel_results` </td> <td> If `True`, then the final kernel results are returned alongside the chain state and the trace specified by the `trace_fn`. </td> </tr>
<tr> <td> `discard_tuning` </td> <td> bool. Whether to discard the adaptation-phase traces and draws. </td> </tr>
<tr> <td> `chain_axis_names` </td> <td> A `str` or list of `str`s indicating the named axes by which multiple chains are sharded. See <a href="../../../tfp/experimental/mcmc/Sharded"><code>tfp.experimental.mcmc.Sharded</code></a> for more context. </td> </tr>
<tr> <td> `seed` </td> <td> PRNG seed; see <a href="../../../tfp/random/sanitize_seed"><code>tfp.random.sanitize_seed</code></a> for details. </td> </tr>
<tr> <td> `pins` </td> <td> These are used to condition the provided joint distribution, and are passed directly to `joint_dist.experimental_pin(**pins)`. </td> </tr>
</table>
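The harmonic-mean reduction of acceptance probabilities mentioned under `dual_averaging_kwargs` can be sketched in plain Python (the helper `harmonic_mean` below is illustrative, not part of the TFP API):

```python
# Harmonic mean of per-chain acceptance probabilities: a single poorly
# accepting chain pulls the reduced value down sharply, so step-size
# adaptation stays conservative until every chain accepts well.
def harmonic_mean(probs):
    return len(probs) / sum(1.0 / p for p in probs)

# Three healthy chains and one struggling chain.
accept_probs = [0.9, 0.85, 0.95, 0.1]
print(round(harmonic_mean(accept_probs), 3))  # roughly 0.3, vs. 0.7 arithmetic mean
```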

<!-- Tabular view -->
<table class="responsive fixed orange">
<colgroup><col width="214px"><col></colgroup>
<tr><th colspan="2"><h2 class="add-link">Returns</h2></th></tr>
<tr class="alt"> <td colspan="2"> A single structure of draws is returned if `trace_fn` is `None` and `return_final_kernel_results` is `False`. If there is a trace function, the return value is a tuple, with the trace second. If `return_final_kernel_results` is `True`, the return value is a tuple of length 3, with the final kernel results returned last. If `discard_tuning` is `True`, the tensors in the draws and trace will have length `n_draws`; otherwise, they will have length `n_draws + num_adaptation_steps`. </td> </tr>
</table>