tfp.substrates.jax.mcmc.CheckpointableStatesAndTrace

States and auxiliary trace of an MCMC chain.

The first dimension of all the Tensors in the all_states and trace attributes is the same and represents the chain length.

all_states: A Tensor or a nested collection of Tensors representing the MCMC chain state.
trace: A Tensor or a nested collection of Tensors representing the auxiliary values traced alongside the chain.
final_kernel_results: A Tensor or a nested collection of Tensors representing the final value of the auxiliary state of the TransitionKernel that generated this chain.