Ordinarily, one can produce independent Markov chains from a single kernel by
providing a batch of states. When using named axes inside of a map (for
example, with JAX's pmap, vmap, or xmap), however, the kernel is provided
with a state that has no batch dimensions. To sample independently across
the named axis, the PRNG seed must differ along that axis. This
can be accomplished by folding the named axis index into the random seed.
A Sharded kernel does exactly this, creating independent chains across a
named axis.
A TransitionKernel to be sharded.
A str or list of strs that determine the named axes
that independent Markov chains will be sharded across.
Python bool. When True, kernel parameters are checked
for validity. When False, invalid inputs may silently render incorrect
outputs.
Python str name prefixed to Ops created by this class.
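The seed-folding idea, for either a single axis name or a list of names, can be sketched in plain JAX. This is a minimal illustration under stated assumptions, not the Sharded kernel's actual implementation. The helper names (fold_in_axis_indices, make_step, run_single, run_grid) and the axis names "chain", "x" and "y" are hypothetical. The sketch relies only on jax.vmap with axis_name, jax.lax.axis_index and jax.random.fold_in. Each mapped instance receives an unbatched state and the same shared key, mirroring the situation described above.

```python
import jax
import jax.numpy as jnp


def fold_in_axis_indices(key, axis_names):
    """Hypothetical helper: fold every named-axis index into `key`."""
    # Accept a single str or a list of strs, like the axis-name argument above.
    if isinstance(axis_names, str):
        axis_names = [axis_names]
    for name in axis_names:
        # axis_index is only valid inside a transform (here vmap) that binds
        # `name`; it returns this instance's position along that axis.
        key = jax.random.fold_in(key, jax.lax.axis_index(name))
    return key


def make_step(axis_names):
    """Hypothetical one-step "kernel" that adds Gaussian noise to the state."""
    def step(state, key):
        # `state` has no batch dimension; `key` is shared by all instances.
        if axis_names is not None:
            key = fold_in_axis_indices(key, axis_names)
        return state + jax.random.normal(key)
    return step


key = jax.random.PRNGKey(0)


def run_single(axis_names):
    # Four chains along one named axis "chain"; the key is not mapped.
    step = make_step(axis_names)
    return jax.vmap(step, in_axes=(0, None), axis_name="chain")(jnp.zeros(4), key)


def run_grid(axis_names):
    # A 2 x 3 grid of chains over two named axes, "x" (outer) and "y" (inner).
    step = make_step(axis_names)
    inner = jax.vmap(step, in_axes=(0, None), axis_name="y")
    outer = jax.vmap(inner, in_axes=(0, None), axis_name="x")
    return outer(jnp.zeros((2, 3)), key)


shared = run_single(None)          # no folding: every chain draws the same value
independent = run_single("chain")  # folded: each chain gets its own stream
grid = run_grid(["x", "y"])        # both indices folded: all 6 chains differ
y_only = run_grid("y")             # only "y" folded: rows repeat along "x"
```

The last case shows why every axis the chains are spread over has to be folded in: chains that differ only in their "x" position end up with identical seeds and therefore identical draws.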
The shard axis names for members of the state.
Returns True if the Markov chain converges to the specified distribution.
TransitionKernels which are "uncalibrated" are often calibrated by
composing them with the tfp.mcmc.MetropolisHastings TransitionKernel.
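The calibration idea can be illustrated with a small sketch in plain Python. It is not TFP's API: uncalibrated_random_walk, metropolis_hastings, run_chain and mean_and_var are hypothetical names. The sketch assumes a standard normal target and a symmetric Gaussian random-walk proposal. The inner kernel only proposes moves. The wrapper accepts or rejects each proposal so that the target distribution is left invariant.

```python
import math
import random


def target_log_prob(x):
    # Assumed target: standard normal, up to an additive constant.
    return -0.5 * x * x


def uncalibrated_random_walk(x, rng, scale=1.0):
    # Hypothetical "uncalibrated" kernel: always moves to x + Gaussian noise.
    # Run on its own, this is a plain random walk with no stationary distribution.
    return x + rng.gauss(0.0, scale)


def metropolis_hastings(x, rng):
    # Hypothetical calibrating wrapper: propose with the inner kernel,
    # then accept or reject the proposal.
    proposal = uncalibrated_random_walk(x, rng)
    # The Gaussian proposal is symmetric, so the proposal densities cancel
    # and the acceptance ratio is just the ratio of target densities.
    log_accept_ratio = target_log_prob(proposal) - target_log_prob(x)
    if rng.random() < math.exp(min(0.0, log_accept_ratio)):
        return proposal
    return x


def run_chain(kernel, num_steps, seed=0):
    rng = random.Random(seed)
    x, samples = 0.0, []
    for _ in range(num_steps):
        x = kernel(x, rng)
        samples.append(x)
    return samples


def mean_and_var(samples):
    m = sum(samples) / len(samples)
    return m, sum((s - m) ** 2 for s in samples) / len(samples)


calibrated_mean, calibrated_var = mean_and_var(
    run_chain(metropolis_hastings, 100_000))
uncalibrated_mean, uncalibrated_var = mean_and_var(
    run_chain(uncalibrated_random_walk, 100_000))
```

The calibrated chain's sample mean and variance settle near 0 and 1. The uncalibrated walk's sample variance keeps growing with the number of steps. Keeping the proposal separate from the accept/reject step lets the same correction calibrate many different proposal kernels.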
Non-destructively creates a deep copy of the kernel.
Python string/value dictionary of
initialization arguments to override with new values.
TransitionKernel object of the same type as self,
initialized with the union of self.parameters and
override_parameter_kwargs, with any shared keys overridden by the
value of override_parameter_kwargs, i.e.,