Adapt and sample from a joint distribution, conditioned on pins.

This uses Hamiltonian Monte Carlo to do the sampling. The step size is tuned with dual-averaging adaptation, and the kernel is preconditioned with a diagonal mass matrix, which is estimated over expanding windows of draws.
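To make the step-size tuning concrete, here is a minimal pure-Python sketch of the textbook dual-averaging recursion. It is an illustration only and is not the code in tfp.mcmc.DualAveragingStepSizeAdaptation, which may differ in details. The function name dual_averaging, the constants gamma, t0 and kappa, and the toy acceptance curve exp(-step_size) are all assumptions made for this example.

```python
import math


def dual_averaging(accept_prob_fn, init_step_size, num_steps,
                   target_accept_prob=0.75, gamma=0.05, t0=10.0, kappa=0.75):
    """Tune a step size so the acceptance probability approaches the target.

    accept_prob_fn is a hypothetical stand-in for running one HMC step and
    reading off its acceptance probability at the given step size.
    """
    mu = math.log(10.0 * init_step_size)  # point the iterates shrink toward
    h_bar = 0.0                           # running average of (target - accept)
    log_eps = math.log(init_step_size)    # current (noisy) log step size
    log_eps_bar = 0.0                     # smoothed log step size
    for m in range(1, num_steps + 1):
        accept_prob = accept_prob_fn(math.exp(log_eps))
        eta = 1.0 / (m + t0)
        h_bar = (1.0 - eta) * h_bar + eta * (target_accept_prob - accept_prob)
        log_eps = mu - math.sqrt(m) / gamma * h_bar
        weight = m ** (-kappa)
        log_eps_bar = weight * log_eps + (1.0 - weight) * log_eps_bar
    return math.exp(log_eps_bar)


# Toy acceptance curve (an assumption): larger steps are accepted less often.
tuned = dual_averaging(lambda eps: math.exp(-eps), init_step_size=1.0,
                       num_steps=1000)
```

In this toy setting the acceptance probability equals the target when the step size is -log(0.75), so the tuned value should settle near that point. The smoothed step size, not the last noisy iterate, is what would be kept once adaptation ends.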

n_draws int Number of draws after adaptation.
joint_dist tfd.JointDistribution A joint distribution to sample from.
num_leapfrog_steps int Number of leapfrog steps to use for the Hamiltonian Monte Carlo step.
n_chains int Number of independent chains to run MCMC with.
num_adaptation_steps int Number of draws used to adapt step size and mass matrix.
current_state Optional Structure of tensors at which to initialize sampling. Should have the same shape and structure as joint_dist.experimental_pin(**pins).sample(n_chains).
init_step_size Optional Where to initialize the step size for the leapfrog integrator. The structure should broadcast with current_state. For example, if the initial state is

{'a': tf.zeros(n_chains),
'b': tf.zeros([n_chains, n_features])}

then any of 1., {'a': 1., 'b': 1.}, or {'a': tf.ones(n_chains), 'b': tf.ones([n_chains, n_features])} will work. Defaults to the dimension of the log density to the 0.25 power.

dual_averaging_kwargs Optional dict Keyword arguments to pass to tfp.mcmc.DualAveragingStepSizeAdaptation. By default, a target_accept_prob of 0.75 is set, and the class defaults are used otherwise.
trace_fn Optional callable The trace function should accept the arguments (state, bijector, is_adapting, phmc_kernel_results), where the state is an unconstrained, flattened float tensor, bijector is the tfb.Bijector that is used for unconstraining and flattening, is_adapting is a boolean to mark whether the draw is from an adaptation step, and phmc_kernel_results is the UncalibratedPreconditionedHamiltonianMonteCarloKernelResults from the PreconditionedHamiltonianMonteCarlo kernel. Note that bijector.inverse(state) will provide access to the current draw in the untransformed space, using the structure of the provided joint_dist.
return_final_kernel_results If True, then the final kernel results are returned alongside the chain state and the trace specified by the trace_fn.
discard_tuning bool Whether to discard the tuning traces and draws, so that only post-adaptation values are returned.
seed Optional A seed for reproducible sampling.
**pins These are used to condition the provided joint distribution, and are passed directly to joint_dist.experimental_pin(**pins).
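As an illustration of the trace_fn argument described above, the following is a minimal sketch of a trace function with the documented four-argument signature. The name simple_trace_fn and the keys of the returned dict are hypothetical. The sketch relies only on the documented facts that is_adapting is a boolean and that bijector.inverse(state) returns the draw in the structure of joint_dist.

```python
def simple_trace_fn(state, bijector, is_adapting, phmc_kernel_results):
    """Record the adaptation flag and the draw in the joint_dist structure."""
    # The kernel results are ignored here; a real trace function might
    # pull diagnostics from them instead.
    del phmc_kernel_results
    return {
        'is_adapting': is_adapting,
        # Map the unconstrained, flattened state back to the model's structure.
        'draw': bijector.inverse(state),
    }
```

A function like this would be passed as trace_fn=simple_trace_fn. Tracing the full draw duplicates what is already returned as draws, so in practice a trace function would more likely record a few scalar diagnostics.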

A single structure of draws is returned if trace_fn is None and return_final_kernel_results is False. If there is a trace function, the return value is a tuple with the draws first and the trace second. If return_final_kernel_results is True, the return value is a tuple of length 3, with the final kernel results last. If discard_tuning is True, the tensors in draws and trace have length n_draws; otherwise, they have length n_draws + num_adaptation_steps.
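The return rules above can be summarized in a small helper. This is only a sketch of the documented behavior, not part of the library. The name describe_return and its has_trace_fn argument are hypothetical.

```python
def describe_return(n_draws, num_adaptation_steps, has_trace_fn,
                    return_final_kernel_results, discard_tuning):
    """Return (tuple_arity, leading_length) implied by the documented rules.

    An arity of 1 means a bare structure of draws rather than a tuple.
    """
    if return_final_kernel_results:
        arity = 3  # (draws, trace, final kernel results)
    elif has_trace_fn:
        arity = 2  # (draws, trace)
    else:
        arity = 1  # draws only
    # Tuning draws are kept only when discard_tuning is False.
    if discard_tuning:
        length = n_draws
    else:
        length = n_draws + num_adaptation_steps
    return arity, length
```

For example, with n_draws=1000, num_adaptation_steps=500, a trace function, and discard_tuning=False, the helper reports a 2-tuple whose tensors have leading length 1500.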