tfp.experimental.mcmc.Sharded

Shards a transition kernel across a named axis.

Inherits From: TransitionKernel

Ordinarily, one can produce independent Markov chains from a single kernel by providing a batch of states. When the kernel runs inside a map over named axes (for example, JAX's pmap, vmap, or xmap), however, it receives state without batch dimensions. To sample independently across the named axis, each position along that axis must use a different PRNG seed. This can be accomplished by folding the named axis index into the random seed. A Sharded kernel does exactly this, creating independent chains across a named axis.
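The seed-folding idea can be sketched in plain JAX, without TFP. This is a minimal sketch, assuming a `jax.vmap` over a named axis called "chain" (the axis name and the `step` function are hypothetical). It is not Sharded's implementation, only the same idea built from `jax.random.fold_in` and `jax.lax.axis_index`.

```python
import jax
import jax.numpy as jnp


def step(state, key, fold_axis_index):
    """Adds one standard-normal draw to an unbatched per-chain state (hypothetical helper)."""
    if fold_axis_index:
        # Mix this chain's position along the named axis "chain" into the key,
        # giving every chain its own PRNG stream.
        key = jax.random.fold_in(key, jax.lax.axis_index("chain"))
    return state + jax.random.normal(key)


num_chains = 4
states = jnp.zeros(num_chains)  # one scalar state per chain
key = jax.random.PRNGKey(0)     # a single key shared by all chains

# in_axes=(0, None): map over the states, broadcast the same key to every chain.
naive = jax.vmap(lambda s, k: step(s, k, False),
                 in_axes=(0, None), axis_name="chain")(states, key)
folded = jax.vmap(lambda s, k: step(s, k, True),
                  in_axes=(0, None), axis_name="chain")(states, key)

print(naive)   # four identical values: the chains are not independent
print(folded)  # different values: one PRNG stream per chain
```

Without the fold, the unbatched key produces a single draw that is broadcast to every chain, which is exactly the failure the Sharded kernel guards against.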

Args
inner_kernel A TransitionKernel to be sharded.
chain_axis_names A str or list of strs naming the axes across which independent Markov chains are sharded.
validate_args Python bool. When True, kernel parameters are checked for validity. When False, invalid inputs may silently render incorrect outputs.
name Python str name prefixed to Ops created by this class.

Attributes
chain_axis_names The named axes across which independent Markov chains are sharded.
experimental_shard_axis_names The shard axis names for members of the state.
inner_kernel The TransitionKernel being sharded.
is_calibrated Returns True if the Markov chain converges to the specified distribution.

TransitionKernels which are "uncalibrated" are often calibrated by composing them with the tfp.mcmc.MetropolisHastings TransitionKernel.

Methods

bootstrap_results

Returns an object with the same type as returned by one_step(...)[1].

Args
init_state Tensor or Python list of Tensors representing the initial state(s) of the Markov chain(s).

Returns
kernel_results A (possibly nested) tuple, namedtuple or list of Tensors representing internal calculations made within this function.

copy

Non-destructively creates a deep copy of the kernel.

Args
**override_parameter_kwargs Python str/value dictionary of initialization arguments to override with new values.

Returns
new_kernel TransitionKernel object of the same type as self, initialized with the union of self.parameters and override_parameter_kwargs, with any shared keys overridden by the value of override_parameter_kwargs, i.e., dict(self.parameters, **override_parameter_kwargs).
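The merge rule for new_kernel can be checked with plain dictionaries. In this sketch the values are hypothetical placeholders keyed by Sharded's constructor argument names; it illustrates only the dict(...) expression, not how copy builds the new kernel.

```python
# Hypothetical parameter values standing in for a kernel's self.parameters.
parameters = {
    "inner_kernel": "<inner TransitionKernel>",  # placeholder, not a real kernel
    "chain_axis_names": "chain",
    "validate_args": False,
    "name": None,
}
override_parameter_kwargs = {"chain_axis_names": "batch", "validate_args": True}

# Shared keys take the override's value; every other key is carried over.
new_parameters = dict(parameters, **override_parameter_kwargs)

print(new_parameters["chain_axis_names"])  # batch
print(new_parameters["validate_args"])     # True
print(parameters["chain_axis_names"])      # chain (original left untouched)
```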

experimental_with_shard_axes

Returns a copy of the kernel with the provided shard axis names.

Args
shard_axis_names A structure of strings indicating the shard axis names for each component of this kernel's state.

Returns
A copy of the current kernel with the shard axis information.

one_step

Takes one step of the TransitionKernel.

Must be overridden by subclasses.

Args
current_state Tensor or Python list of Tensors representing the current state(s) of the Markov chain(s).
previous_kernel_results A (possibly nested) tuple, namedtuple or list of Tensors representing internal calculations made within the previous call to this function (or as returned by bootstrap_results).
seed PRNG seed; see tfp.random.sanitize_seed for details.

Returns
next_state Tensor or Python list of Tensors representing the next state(s) of the Markov chain(s).
kernel_results A (possibly nested) tuple, namedtuple or list of Tensors representing internal calculations made within this function.