tfp.experimental.mcmc.windowed_adaptive_nuts

Adapt and sample from a joint distribution using NUTS, conditioned on pins.

Step size is tuned using dual-averaging adaptation, and the kernel is preconditioned with a diagonal mass matrix, which is estimated using expanding windows.
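
The expanding-window idea can be sketched in plain Python. This is a minimal sketch, not TFP's implementation. It assumes a Stan-style schedule: an initial fast window of 75 draws, then slow windows that start at 25 draws and double, then a final fast window of 50 draws. TFP's own window sizes may differ. Within each slow window, Welford's algorithm estimates the per-coordinate variance of the draws, and that estimate becomes the new diagonal of the inverse mass matrix. `window_schedule`, `welford_diag_variance`, and `adapt_diag_metric` are hypothetical helpers. `toy_sample` is a stand-in for a NUTS transition.

```python
import random


def window_schedule(num_adaptation_steps, init_fast=75, base_slow=25, final_fast=50):
    """Lengths of the slow (metric-estimation) windows; Stan-style, assumed."""
    remaining = num_adaptation_steps - init_fast - final_fast
    sizes, size = [], base_slow
    while remaining > 0:
        # If the next, doubled window would not fit, stretch this one to the end.
        if remaining < 3 * size:
            sizes.append(remaining)
            break
        sizes.append(size)
        remaining -= size
        size *= 2
    return sizes


def welford_diag_variance(draws):
    """Per-coordinate sample variance using Welford's single-pass update."""
    n = 0
    mean = [0.0] * len(draws[0])
    m2 = [0.0] * len(draws[0])
    for x in draws:
        n += 1
        for i, xi in enumerate(x):
            delta = xi - mean[i]
            mean[i] += delta / n
            m2[i] += delta * (xi - mean[i])
    return [v / max(n - 1, 1) for v in m2]


def adapt_diag_metric(sample_fn, num_adaptation_steps, dim, init_fast=75):
    """Estimate a diagonal inverse mass matrix over expanding windows."""
    inv_mass = [1.0] * dim  # start from the identity metric
    for _ in range(init_fast):  # initial fast window: step size only
        sample_fn(inv_mass)
    for size in window_schedule(num_adaptation_steps, init_fast=init_fast):
        window = [sample_fn(inv_mass) for _ in range(size)]
        # Each window starts a fresh estimate; earlier draws are forgotten.
        inv_mass = welford_diag_variance(window)
    # The final fast window (step size only) is omitted in this sketch.
    return inv_mass


rng = random.Random(0)
stds = [1.0, 10.0]
# Toy "sampler": independent Gaussian draws that ignore the current metric.
toy_sample = lambda inv_mass: [rng.gauss(0.0, s) for s in stds]
print(window_schedule(1000))  # [25, 50, 100, 200, 500]
print(adapt_diag_metric(toy_sample, 1000, dim=2))  # roughly [1, 100]
```

The windows grow because later draws come from a better-tuned sampler. Larger, later windows therefore give a more reliable variance estimate than one long window that mixes early, poorly tuned draws.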

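<!-- Tabular view -->
 <table class="responsive fixed orange">
<colgroup><col width="214px"><col></colgroup>
<tr><th colspan="2"><h2 class="add-link">Args</h2></th></tr>

<tr>
<td>
`n_draws`<a id="n_draws"></a>
</td>
<td>
int
Number of draws after adaptation.
</td>
</tr><tr>
<td>
`joint_dist`<a id="joint_dist"></a>
</td>
<td>
`tfd.JointDistribution`
A joint distribution to sample from.
</td>
</tr><tr>
<td>
`n_chains`<a id="n_chains"></a>
</td>
<td>
int or list of ints
Number of independent chains to run MCMC with.
</td>
</tr><tr>
<td>
`num_adaptation_steps`<a id="num_adaptation_steps"></a>
</td>
<td>
int
Number of draws used to adapt the step size and mass matrix.
</td>
</tr><tr>
<td>
`current_state`<a id="current_state"></a>
</td>
<td>
Optional
Structure of tensors at which to initialize sampling. Should have the
same shape and structure as
`joint_dist.experimental_pin(**pins).sample(n_chains)`.
</td>
</tr><tr>
<td>
`init_step_size`<a id="init_step_size"></a>
</td>
<td>
Optional
Where to initialize the step size for the leapfrog integrator. The
structure should broadcast with `current_state`. For example, if the
initial state is
```
{'a': tf.zeros(n_chains),
 'b': tf.zeros([n_chains, n_features])}
```
then any of `1.`, `{'a': 1., 'b': 1.}`, or
`{'a': tf.ones(n_chains), 'b': tf.ones([n_chains, n_features])}` will
work. Defaults to the dimension of the log density raised to the -0.25
power.
</td>
</tr><tr>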
<td>
`dual_averaging_kwargs`<a id="dual_averaging_kwargs"></a>
</td>
<td>
Optional dict
Keyword arguments to pass to <a href="../../../tfp/mcmc/DualAveragingStepSizeAdaptation"><code>tfp.mcmc.DualAveragingStepSizeAdaptation</code></a>.
By default, a `target_accept_prob` of 0.85 is set, acceptance
probabilities across chains are reduced using a harmonic mean, and the
class defaults are used otherwise.
</td>
</tr><tr>
<td>
`max_tree_depth`<a id="max_tree_depth"></a>
</td>
<td>
Maximum depth of the tree implicitly built by NUTS. The
maximum number of leapfrog steps is bounded by `2**max_tree_depth`, i.e.,
the number of leaves in a binary tree of depth `max_tree_depth`. The
default setting of 10 takes up to 1024 leapfrog steps.
</td>
</tr><tr>
<td>
`max_energy_diff`<a id="max_energy_diff"></a>
</td>
<td>
Scalar threshold on the energy difference at each leapfrog
step. Leapfrog steps whose energy difference exceeds this threshold are
flagged as divergent samples. Defaults to 1000.
</td>
</tr><tr>
<td>
`unrolled_leapfrog_steps`<a id="unrolled_leapfrog_steps"></a>
</td>
<td>
The number of leapfrog steps to unroll per tree
expansion step. Applies a direct linear multiplier to the maximum
trajectory length implied by `max_tree_depth`. Defaults to 1.
</td>
</tr><tr>
<td>
`parallel_iterations`<a id="parallel_iterations"></a>
</td>
<td>
The number of iterations allowed to run in parallel.
It must be a positive integer. See <a href="https://www.tensorflow.org/api_docs/python/tf/while_loop"><code>tf.while_loop</code></a> for more details.
</td>
</tr><tr>
<td>
`trace_fn`<a id="trace_fn"></a>
</td>
<td>
Optional callable
The trace function should accept the arguments
`(state, bijector, is_adapting, phmc_kernel_results)`, where the `state`
is an unconstrained, flattened float tensor, `bijector` is the
`tfb.Bijector` that is used for unconstraining and flattening,
`is_adapting` is a boolean to mark whether the draw is from an adaptation
step, and `phmc_kernel_results` is the
`UncalibratedPreconditionedHamiltonianMonteCarloKernelResults` from the
`PreconditionedHamiltonianMonteCarlo` kernel. Note that
`bijector.inverse(state)` will provide access to the current draw in the
untransformed space, using the structure of the provided `joint_dist`.
</td>
</tr><tr>
<td>
`return_final_kernel_results`<a id="return_final_kernel_results"></a>
</td>
<td>
If `True`, then the final kernel results are
returned alongside the chain state and the trace specified by the
`trace_fn`.
</td>
</tr><tr>
<td>
`discard_tuning`<a id="discard_tuning"></a>
</td>
<td>
bool
Whether to discard the draws and traces from the adaptation (tuning)
steps. If `True`, only the `n_draws` post-adaptation draws are returned.
</td>
</tr><tr>
<td>
`chain_axis_names`<a id="chain_axis_names"></a>
</td>
<td>
A `str` or list of `str`s indicating the named axes
by which multiple chains are sharded. See <a href="../../../tfp/experimental/mcmc/Sharded"><code>tfp.experimental.mcmc.Sharded</code></a>
for more context.
</td>
</tr><tr>
<td>
`seed`<a id="seed"></a>
</td>
<td>
PRNG seed; see <a href="../../../tfp/random/sanitize_seed"><code>tfp.random.sanitize_seed</code></a> for details.
</td>
</tr><tr>
<td>
`**pins`<a id="**pins"></a>
</td>
<td>
Keyword arguments used to condition the provided joint distribution;
they are passed directly to `joint_dist.experimental_pin(**pins)`.
</td>
</tr>
</table>
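
The step-size adaptation configured through `dual_averaging_kwargs` can be sketched in plain Python. The sketch follows the dual-averaging scheme from Hoffman and Gelman's NUTS paper, which `tfp.mcmc.DualAveragingStepSizeAdaptation` is based on. It assumes parameter names and defaults that mirror that class: `exploration_shrinkage=0.05`, `step_count_smoothing=10`, `decay_rate=0.75`, and a shrinkage target of `log(10 * init_step_size)`. As described above, `target_accept_prob` defaults to 0.85, and per-chain acceptance probabilities are reduced with a harmonic mean. `toy_accept` is a hypothetical stand-in for running one NUTS step per chain.

```python
import math


def harmonic_mean(probs):
    """Reduce per-chain acceptance probabilities; dominated by the smallest."""
    return len(probs) / sum(1.0 / max(p, 1e-12) for p in probs)


def dual_averaging(accept_fn, init_step_size, num_steps,
                   target_accept_prob=0.85, exploration_shrinkage=0.05,
                   step_count_smoothing=10.0, decay_rate=0.75):
    """Return the averaged step size after `num_steps` of dual averaging.

    `accept_fn(step_size)` returns one acceptance probability per chain.
    """
    shrinkage_target = math.log(10.0 * init_step_size)  # mu
    h_bar = 0.0  # running average of (target - accept)
    log_step = math.log(init_step_size)
    log_avg_step = 0.0
    for t in range(1, num_steps + 1):
        accept = harmonic_mean(accept_fn(math.exp(log_step)))
        weight = 1.0 / (t + step_count_smoothing)
        h_bar = (1.0 - weight) * h_bar + weight * (target_accept_prob - accept)
        # Too-low acceptance makes h_bar positive, which shrinks the step.
        log_step = shrinkage_target - math.sqrt(t) / exploration_shrinkage * h_bar
        # Average the iterates with a decaying weight t**-decay_rate.
        eta = t ** -decay_rate
        log_avg_step = eta * log_step + (1.0 - eta) * log_avg_step
    return math.exp(log_avg_step)


# Toy acceptance curves for four chains: larger steps accept less often.
scales = [0.8, 1.0, 1.2, 1.5]
toy_accept = lambda eps: [math.exp(-s * eps) for s in scales]
step = dual_averaging(toy_accept, init_step_size=1.0, num_steps=2000)
print(step, harmonic_mean(toy_accept(step)))  # mean acceptance near 0.85
```

The harmonic mean lets the slowest-mixing chain dominate, so the shared step size is not tuned only for the easy chains. For readability, this sketch works directly with acceptance probabilities.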



<!-- Tabular view -->
 <table class="responsive fixed orange">
<colgroup><col width="214px"><col></colgroup>
<tr><th colspan="2"><h2 class="add-link">Returns</h2></th></tr>
<tr class="alt">
<td colspan="2">
If `trace_fn` is `None` and `return_final_kernel_results` is `False`, a
single structure of draws is returned. If there is a trace function, the
return value is a tuple with the trace second. If
`return_final_kernel_results` is `True`, the return value is a tuple of
length 3 with the final kernel results last. If `discard_tuning` is
`True`, the tensors in `draws` and `trace` have length `n_draws`;
otherwise, they have length `n_draws + num_adaptation_steps`.
</td>
</tr>

</table>
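
Because the return layout depends on these flags, it can help to normalize the result to one shape. The helpers below are hypothetical and not part of TFP; they encode only the layout stated in this section. `drop_tuning` also assumes that draws are indexed along the leading axis, which this section does not state.

```python
def unpack_result(result, has_trace_fn, return_final_kernel_results):
    """Normalize the documented return layout to (draws, trace, final_results).

    Missing parts come back as None. Hypothetical helper, not a TFP API.
    """
    if return_final_kernel_results:
        draws, trace, final_results = result  # documented as a 3-tuple
        return draws, trace, final_results
    if has_trace_fn:
        draws, trace = result  # (draws, trace)
        return draws, trace, None
    return result, None, None  # a single structure of draws


def drop_tuning(draws, num_adaptation_steps):
    """Keep only post-adaptation draws when `discard_tuning=False`.

    Assumes draws are indexed along the leading axis.
    """
    return draws[num_adaptation_steps:]


# Stand-in values: 3 adaptation draws followed by 2 post-adaptation draws.
fake = (list(range(5)), ['trace'] * 5)
draws, trace, final = unpack_result(fake, has_trace_fn=True,
                                    return_final_kernel_results=False)
print(drop_tuning(draws, num_adaptation_steps=3))  # [3, 4]
```

For a nested structure of tensors, apply the slice to each leaf instead, for example with `tf.nest.map_structure(lambda x: x[num_adaptation_steps:], draws)`.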