tfp.substrates.jax.mcmc.UncalibratedLangevin

Runs one step of Uncalibrated Langevin discretized diffusion.

Inherits From: TransitionKernel

This class generates a Langevin proposal using the _euler_method function and also computes the helper UncalibratedLangevinKernelResults for the next iteration.

For more details on UncalibratedLangevin, see MetropolisAdjustedLangevinAlgorithm.
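
As a rough illustration of what an Euler-type Langevin proposal looks like, the sketch below implements the update for a scalar state in plain Python. It assumes unit (identity) volatility and a hand-written gradient. The names langevin_proposal and grad_standard_normal are hypothetical, chosen for this example only; the library's _euler_method may differ in detail.

```python
import math


def langevin_proposal(state, grad_log_prob, step_size, noise):
    """One Euler-type Langevin proposal for a scalar state (unit volatility assumed)."""
    # Deterministic drift: half a step along the gradient of the log-density.
    drift = 0.5 * step_size * grad_log_prob(state)
    # Diffusion term: standard normal noise scaled by sqrt(step_size).
    return state + drift + math.sqrt(step_size) * noise


def grad_standard_normal(x):
    """Gradient of log N(x; 0, 1), i.e. d/dx of -x**2 / 2."""
    return -x


# With the noise switched off, the proposal simply drifts toward the mode at 0.
proposal = langevin_proposal(2.0, grad_standard_normal, step_size=0.1, noise=0.0)
```

In practice the noise argument would be a fresh standard normal draw at every step; it is passed in explicitly here so the drift term can be inspected on its own.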

Args
target_log_prob_fn Python callable which takes an argument like current_state (or *current_state if it's a list) and returns its (possibly unnormalized) log-density under the target distribution.
step_size Tensor or Python list of Tensors representing the step size for the Euler discretization of the Langevin diffusion. Must broadcast with the shape of current_state. Larger step sizes lead to faster progress, but too-large step sizes make rejection exponentially more likely once the proposal is Metropolis-adjusted. When possible, it's often helpful to match per-variable step sizes to the standard deviations of the target distribution in each variable.
volatility_fn Python callable which takes an argument like current_state (or *current_state if it's a list) and returns the volatility value at current_state. Should return a Tensor or Python list of Tensors that must broadcast with the shape of current_state. Defaults to the identity function.
parallel_iterations The number of coordinates for which the gradients of the volatility matrix volatility_fn can be computed in parallel.
compute_acceptance Python bool indicating whether to compute the Metropolis log-acceptance ratio used to construct the MetropolisAdjustedLangevinAlgorithm kernel.
experimental_shard_axis_names A structure of string names indicating how members of the state are sharded.
name Python str name prefixed to Ops created by this function. Default value: None (i.e., 'mala_kernel').

Raises
ValueError if there isn't one step_size or a list with the same length as current_state.
TypeError if volatility_fn is not callable.
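
The step_size length rule behind the ValueError can be sketched in plain Python as follows. This is only an illustration of the rule as stated in this reference: broadcast_step_sizes is a hypothetical helper, not a library function, and treating a one-element list as a shared step size is an assumption of this sketch.

```python
def broadcast_step_sizes(step_size, num_state_parts):
    """Return one step size per state part, or raise ValueError on a mismatch."""
    # A single value (not a list or tuple) is shared by every state part.
    if not isinstance(step_size, (list, tuple)):
        return [step_size] * num_state_parts
    step_sizes = list(step_size)
    # Assumption of this sketch: a one-element list is also treated as shared.
    if len(step_sizes) == 1:
        return step_sizes * num_state_parts
    # Otherwise there must be exactly one step size per state part.
    if len(step_sizes) != num_state_parts:
        raise ValueError(
            "There should be exactly one step_size, or a list with the same "
            "length as current_state.")
    return step_sizes


shared = broadcast_step_sizes(0.1, num_state_parts=3)
per_part = broadcast_step_sizes([0.1, 0.2], num_state_parts=2)
```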

Attributes
compute_acceptance
experimental_shard_axis_names The shard axis names for members of the state.
is_calibrated Returns True if the Markov chain converges to the specified distribution. TransitionKernels which are "uncalibrated" are often calibrated by composing them with the tfp.mcmc.MetropolisHastings TransitionKernel.
name
parallel_iterations
parameters Return dict of __init__ arguments and their values.
step_size
target_log_prob_fn
volatility_fn
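
To make the is_calibrated note concrete, the sketch below shows how a Metropolis-Hastings accept/reject step can calibrate a Langevin proposal for a scalar state. It is written in plain Python under the assumption of unit volatility; log_acceptance_ratio and metropolis_hastings_step are hypothetical names for this illustration and do not reproduce the internals of tfp.mcmc.MetropolisHastings.

```python
import math


def log_normal_pdf(x, mean, variance):
    """Log-density of a scalar normal distribution."""
    return -0.5 * math.log(2.0 * math.pi * variance) - 0.5 * (x - mean) ** 2 / variance


def log_acceptance_ratio(current, proposed, log_prob, grad_log_prob, step_size):
    """Log Metropolis-Hastings ratio for a Langevin proposal (unit volatility assumed)."""

    def log_q(to_state, from_state):
        # The proposal is normal around the drifted state with variance step_size.
        mean = from_state + 0.5 * step_size * grad_log_prob(from_state)
        return log_normal_pdf(to_state, mean, step_size)

    return (log_prob(proposed) - log_prob(current)
            + log_q(current, proposed) - log_q(proposed, current))


def metropolis_hastings_step(current, proposed, log_ratio, log_uniform):
    """Accept the proposal when log(u) < log_ratio; otherwise keep the current state."""
    return proposed if log_uniform < log_ratio else current


def log_prob(x):
    """Unnormalized standard normal log-density."""
    return -0.5 * x * x


def grad_log_prob(x):
    return -x


ratio = log_acceptance_ratio(1.0, 0.8, log_prob, grad_log_prob, step_size=0.1)
next_state = metropolis_hastings_step(1.0, 0.8, ratio, log_uniform=math.log(0.5))
```

Without this correction the chain samples from a distribution that only approximates the target, which is why the kernel on its own is described as uncalibrated.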

Methods

bootstrap_results

Creates initial previous_kernel_results using a supplied state.

copy

Non-destructively creates a deep copy of the kernel.

Args
**override_parameter_kwargs Python String/value dictionary of initialization arguments to override with new values.

Returns
new_kernel TransitionKernel object of the same type as self, initialized with the union of self.parameters and override_parameter_kwargs, with any shared keys overridden by the value of override_parameter_kwargs, i.e., dict(self.parameters, **override_parameter_kwargs).
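
The override rule described for copy can be sketched with a toy class in plain Python. ToyKernel is a hypothetical stand-in used only to show the dict(self.parameters, **override_parameter_kwargs) pattern; it is not the library's implementation.

```python
class ToyKernel:
    """Hypothetical kernel that records its constructor arguments."""

    def __init__(self, step_size, name=None):
        self._parameters = dict(step_size=step_size, name=name)

    @property
    def parameters(self):
        return self._parameters

    def copy(self, **override_parameter_kwargs):
        # Shared keys take the override value; all other keys are kept as-is.
        parameters = dict(self.parameters, **override_parameter_kwargs)
        return type(self)(**parameters)


kernel = ToyKernel(step_size=0.1, name="toy")
faster = kernel.copy(step_size=0.5)  # New object; kernel itself is unchanged.
```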

experimental_with_shard_axes

Returns a copy of the kernel with the provided shard axis names.

Args
shard_axis_names a structure of strings indicating the shard axis names for each component of this kernel's state.

Returns
A copy of the current kernel with the shard axis information.

one_step

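Runs one iteration of the uncalibrated Langevin kernel, i.e. the proposal step of MALA without the Metropolis accept/reject correction.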

Args
current_state Tensor or Python list of Tensors representing the current state(s) of the Markov chain(s). The first r dimensions index independent chains, r = tf.rank(target_log_prob_fn(*current_state)).
previous_kernel_results collections.namedtuple containing Tensors representing values from previous calls to this function (or from the bootstrap_results function).
seed PRNG seed; see tfp.random.sanitize_seed for details.

Returns
next_state Tensor or Python list of Tensors representing the state(s) of the Markov chain(s) after taking exactly one step. Has the same type and shape as current_state.
kernel_results collections.namedtuple of internal calculations used to advance the chain.

Raises
ValueError if there isn't one step_size or a list with the same length as current_state or diffusion_drift.
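
The calling pattern for bootstrap_results and one_step (build the initial kernel results once, then thread them through successive steps) can be sketched with a toy scalar kernel in plain Python. ToyUncalibratedLangevin, ToyKernelResults, and the integer seed handling are assumptions made for this illustration; the real kernel operates on Tensors, uses the PRNG seeds described above, and returns a richer results structure.

```python
import collections
import math
import random

# Hypothetical results structure: values cached between steps.
ToyKernelResults = collections.namedtuple(
    "ToyKernelResults", ["target_log_prob", "grad_target_log_prob"])


class ToyUncalibratedLangevin:
    """Toy scalar kernel mimicking the bootstrap_results / one_step protocol."""

    def __init__(self, log_prob_and_grad_fn, step_size):
        self._fn = log_prob_and_grad_fn
        self.step_size = step_size

    def bootstrap_results(self, init_state):
        # Evaluate the target once so the first step has something to reuse.
        return ToyKernelResults(*self._fn(init_state))

    def one_step(self, current_state, previous_kernel_results, seed):
        noise = random.Random(seed).gauss(0.0, 1.0)
        # Reuse the gradient cached in the previous results (unit volatility assumed).
        drift = 0.5 * self.step_size * previous_kernel_results.grad_target_log_prob
        next_state = current_state + drift + math.sqrt(self.step_size) * noise
        return next_state, ToyKernelResults(*self._fn(next_state))


def standard_normal_log_prob_and_grad(x):
    """Unnormalized log-density of N(0, 1) and its gradient."""
    return -0.5 * x * x, -x


kernel = ToyUncalibratedLangevin(standard_normal_log_prob_and_grad, step_size=0.1)
state = 1.0
results = kernel.bootstrap_results(state)
for step in range(5):
    # Each step gets its own seed and the results returned by the previous step.
    state, results = kernel.one_step(state, results, seed=step)
```

Every proposal is kept, since no accept/reject step is applied; this is the sense in which the kernel is uncalibrated.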