tfp.experimental.mcmc.DiagonalMassMatrixAdaptation

Adapts the inner kernel's momentum_distribution to estimated variance.

Inherits From: TransitionKernel

This kernel uses an online variance estimate to adjust a diagonal covariance matrix for each of the state parts. More specifically, the momentum_distribution of the innermost kernel is set to a diagonal multivariate normal distribution whose variance is the inverse of the online estimate. In the context of Hamiltonian Monte Carlo, the covariance of the momentum is often called the "mass matrix", so this amounts to a diagonal mass matrix equal to the inverse of the estimated variance of the state.
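The following NumPy sketch illustrates the idea above for a single state part of fixed dimension. `RunningDiagonalVariance` and `momentum_variance_from_estimate` are hypothetical helpers written only for illustration. This is not the TFP implementation, which relies on tfp.experimental.stats.RunningVariance and TFP distributions. The running estimate uses Welford's online update for each coordinate, and the momentum variance (the diagonal of the mass matrix) is the elementwise inverse of that estimate.

```python
import numpy as np


class RunningDiagonalVariance:
    """Hypothetical per-coordinate online variance (Welford's algorithm)."""

    def __init__(self, dim):
        self.num_samples = 0
        self.mean = np.zeros(dim)
        self.m2 = np.zeros(dim)  # running sum of squared deviations from the mean

    def update(self, x):
        x = np.asarray(x, dtype=float)
        self.num_samples += 1
        delta = x - self.mean
        self.mean += delta / self.num_samples
        # Welford's update multiplies the old deviation by the new one.
        self.m2 += delta * (x - self.mean)

    def variance(self):
        # Population (ddof=0) variance of the draws seen so far.
        return self.m2 / self.num_samples


def momentum_variance_from_estimate(state_variance):
    """Diagonal of the momentum covariance, i.e. of the diagonal mass matrix."""
    return 1.0 / state_variance


rng = np.random.default_rng(0)
true_scales = np.array([0.1, 1.0, 10.0])
draws = rng.normal(size=(2000, 3)) * true_scales  # stand-in for chain states

running = RunningDiagonalVariance(dim=3)
for x in draws:
    running.update(x)

state_variance = running.variance()  # close to true_scales**2
momentum_variance = momentum_variance_from_estimate(state_variance)
print(state_variance)
print(momentum_variance)
```

Coordinates with a large estimated variance get a small momentum variance. Because the leapfrog position update applies the inverse mass matrix to the momentum, position moves along those coordinates end up scaled to the target's spread.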

This preconditioning scheme works well when the target distribution's covariance is diagonally dominant, and may give reasonable results even when the number of draws is less than the dimension. In particular, it should generally do better than no preconditioning, which implicitly uses an identity mass matrix.
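The NumPy illustration below shows why a diagonal estimate remains usable with few draws. It assumes 5 draws in 10 dimensions. The full sample covariance is singular, with rank at most num_draws - 1, so it cannot be inverted into a full mass matrix. Each per-coordinate variance, by contrast, is positive and invertible. The example addresses only the draws-versus-dimension point and does not reproduce TFP's code path.

```python
import numpy as np

rng = np.random.default_rng(1)
num_draws, dim = 5, 10
scales = np.linspace(0.5, 5.0, dim)
draws = rng.normal(size=(num_draws, dim)) * scales

# Full sample covariance: shape (dim, dim), but rank at most num_draws - 1.
full_cov = np.cov(draws, rowvar=False)
full_rank = np.linalg.matrix_rank(full_cov)

# Diagonal estimate: one variance per coordinate, each strictly positive.
diag_var = draws.var(axis=0, ddof=1)
diag_mass = 1.0 / diag_var  # finite, so a usable diagonal mass matrix

print(full_rank)             # at most 4, less than dim = 10
print(np.all(diag_var > 0))  # True
```

The diagonal estimate equals the diagonal of the singular full estimate. Discarding the off-diagonal terms is what keeps it invertible.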

Note that this kernel does not implement a calibrated sampler; rather, it is intended to be used as one step of an iterative adaptation process. It should not be used when drawing actual samples.

Args
inner_kernel TransitionKernel-like object.
initial_running_variance tfp.experimental.stats.RunningVariance-like object, or list of them, for a batch of momentum distributions. These use update on the state to maintain an estimate of the variance, and so must have a structure compatible with the state space.
num_estimation_steps An optional scalar int Tensor giving the number of initial steps during which to adjust the running variance. This may be greater than, less than, or equal to the number of burn-in steps. If this argument is None, the mass matrix is updated at each one_step call. Otherwise, the mass matrix is updated when the current step equals num_estimation_steps.
momentum_distribution_setter_fn A callable with the signature (kernel_results, new_momentum_distribution) -> new_kernel_results where kernel_results are the results of the inner_kernel, new_momentum_distribution is a CompositeTensor or a nested collection of CompositeTensors, and new_kernel_results are a possibly-modified copy of kernel_results. The default, hmc_like_momentum_distribution_setter_fn, presumes HMC-style kernel_results, and sets the momentum_distribution only under the accepted_results field.
momentum_distribution_getter_fn A callable with the signature kernel_results -> momentum_distribution where kernel_results are the results of the inner_kernel and momentum_distribution is a CompositeTensor or a nested collection of CompositeTensors. The default, hmc_like_momentum_distribution_getter_fn, presumes HMC-style kernel_results, and gets the momentum_distribution only under the accepted_results field.
validate_args Python bool. When True, kernel parameters are checked for validity. When False, invalid inputs may silently render incorrect outputs.
experimental_shard_axis_names An optional structure of string names indicating how members of the state are sharded.
name Python str name prefixed to Ops created by this class. Default: 'diagonal_mass_matrix_adaptation'.
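The following pure-Python sketch illustrates the getter/setter contract described above. It assumes HMC-style results in which the momentum distribution lives under accepted_results. Several pieces are hypothetical: the namedtuples, the dict standing in for a CompositeTensor, and the names example_getter_fn and example_setter_fn. They mirror only the documented behavior of the default hmc_like_momentum_distribution_getter_fn and hmc_like_momentum_distribution_setter_fn, not their source.

```python
from collections import namedtuple

# Hypothetical, simplified stand-ins for HMC-style kernel results. Real TFP
# results carry many more fields; only the nesting matters here.
InnerResults = namedtuple("InnerResults", ["target_log_prob", "momentum_distribution"])
OuterResults = namedtuple("OuterResults", ["is_accepted", "accepted_results", "proposed_results"])


def example_getter_fn(kernel_results):
    # Signature: kernel_results -> momentum_distribution.
    # Reads the momentum distribution from `accepted_results` only.
    return kernel_results.accepted_results.momentum_distribution


def example_setter_fn(kernel_results, new_momentum_distribution):
    # Signature: (kernel_results, new_momentum_distribution) -> new_kernel_results.
    # Namedtuples are immutable, so `_replace` builds a modified copy and
    # leaves `kernel_results` itself unchanged.
    new_accepted = kernel_results.accepted_results._replace(
        momentum_distribution=new_momentum_distribution)
    return kernel_results._replace(accepted_results=new_accepted)


old_dist = {"variance": (1.0, 1.0)}  # placeholder for a CompositeTensor
new_dist = {"variance": (100.0, 0.01)}
results = OuterResults(
    is_accepted=True,
    accepted_results=InnerResults(-1.5, old_dist),
    proposed_results=InnerResults(-2.0, old_dist),
)

updated = example_setter_fn(results, new_dist)
print(example_getter_fn(updated))                      # {'variance': (100.0, 0.01)}
print(updated.proposed_results.momentum_distribution)  # {'variance': (1.0, 1.0)}
```

Suppose an inner kernel nests its momentum distribution differently. In that case, you would write analogous functions with the same signatures for its real result structure and pass them as momentum_distribution_getter_fn and momentum_distribution_setter_fn.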

Attributes
experimental_shard_axis_names The shard axis names for members of the state.
initial_running_variance

inner_kernel

is_calibrated Returns True if the Markov chain converges to the specified distribution.

TransitionKernels which are "uncalibrated" are often calibrated by composing them with the tfp.mcmc.MetropolisHastings TransitionKernel.

name

num_estimation_steps

parameters Return dict of __init__ arguments and their values.

Methods

bootstrap_results

Returns an object with the same type as returned by one_step(...)[1].

Args
init_state Tensor or Python list of Tensors representing the initial state(s) of the Markov chain(s).

Returns
kernel_results A (possibly nested) tuple, namedtuple or list of Tensors representing internal calculations made within this function.

copy

Non-destructively creates a deep copy of the kernel.

Args
**override_parameter_kwargs Python String/value dictionary of initialization arguments to override with new values.

Returns
new_kernel TransitionKernel object of the same type as self, initialized with the union of self.parameters and override_parameter_kwargs, with any shared keys overridden by the value of override_parameter_kwargs, i.e., dict(self.parameters, **override_parameter_kwargs).
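The merge rule above follows ordinary Python dict semantics. The sketch below uses hypothetical parameter values. Overridden keys take the new values, every other key carries over, and the original dict stays unmodified. With a real kernel, the corresponding call would be kernel.copy(num_estimation_steps=500, validate_args=True).

```python
# Hypothetical parameter values. A real kernel's `parameters` dict holds its
# actual __init__ arguments (inner_kernel, num_estimation_steps, ...).
parameters = {
    "inner_kernel": "inner-kernel-placeholder",
    "num_estimation_steps": 100,
    "validate_args": False,
    "name": None,
}
override_parameter_kwargs = {"num_estimation_steps": 500, "validate_args": True}

# Keys in the overrides win; all other keys are kept unchanged.
new_parameters = dict(parameters, **override_parameter_kwargs)
print(new_parameters["num_estimation_steps"])  # 500
print(new_parameters["inner_kernel"])          # inner-kernel-placeholder
```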

experimental_with_shard_axes

Returns a copy of the kernel with the provided shard axis names.

Args
shard_axis_names a structure of strings indicating the shard axis names for each component of this kernel's state.

Returns
A copy of the current kernel with the shard axis information.

momentum_distribution_getter_fn

momentum_distribution_setter_fn

one_step

Takes one step of the TransitionKernel.

Must be overridden by subclasses.

Args
current_state Tensor or Python list of Tensors representing the current state(s) of the Markov chain(s).
previous_kernel_results A (possibly nested) tuple, namedtuple or list of Tensors representing internal calculations made within the previous call to this function (or as returned by bootstrap_results).
seed PRNG seed; see tfp.random.sanitize_seed for details.

Returns
next_state Tensor or Python list of Tensors representing the next state(s) of the Markov chain(s).
kernel_results A (possibly nested) tuple, namedtuple or list of Tensors representing internal calculations made within this function.
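According to the num_estimation_steps argument description, what one_step does to the adaptation state depends on the step count. The sketch below encodes that schedule as stated on this page and assumes a 1-based step counter. adaptation_schedule is a hypothetical helper, not part of TFP.

```python
def adaptation_schedule(step, num_estimation_steps=None):
    """Hypothetical summary of the documented num_estimation_steps behavior.

    Returns (update_running_variance, update_mass_matrix) for a 1-based step
    counter; the 1-based convention is an assumption of this sketch.
    """
    if num_estimation_steps is None:
        # No limit: the estimate and the mass matrix change on every step.
        return True, True
    # The running variance is adjusted during the initial steps only, and
    # the mass matrix is set once, when the step count is reached.
    update_variance = step <= num_estimation_steps
    update_mass_matrix = step == num_estimation_steps
    return update_variance, update_mass_matrix


for step in range(1, 6):
    print(step, adaptation_schedule(step, num_estimation_steps=3))
# 1 (True, False)
# 2 (True, False)
# 3 (True, True)
# 4 (False, False)
# 5 (False, False)
```

Once the count passes num_estimation_steps, neither the estimate nor the mass matrix changes. From that point on, the inner kernel runs with a fixed diagonal mass matrix.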