Runs one step of the Metropolis-adjusted Langevin algorithm.
Inherits From: TransitionKernel
tfp.substrates.jax.mcmc.MetropolisAdjustedLangevinAlgorithm(
    target_log_prob_fn,
    step_size,
    volatility_fn=None,
    parallel_iterations=10,
    experimental_shard_axis_names=None,
    name=None
)
The Metropolis-adjusted Langevin algorithm (MALA) is a Markov chain Monte Carlo
(MCMC) algorithm that uses one step of a discretised Langevin diffusion as its
proposal. This class implements one step of MALA using the Euler-Maruyama method
for a given current_state and a diagonal preconditioning volatility matrix.
Mathematical details and derivations can be found in
[Roberts and Rosenthal (1998)][1] and [Xifara et al. (2013)][2].
See the UncalibratedLangevin class description for details on the
proposal-generating step of the algorithm.
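As a rough illustration of the idea described above, the following is a minimal pure-Python sketch of one MALA transition for a 1-D standard normal target. It assumes unit volatility and the textbook Euler-Maruyama form of the proposal (mean `x + 0.5 * step_size * grad`, variance `step_size`); the helper names `log_prob`, `grad_log_prob`, `proposal_log_density` and `mala_step` are hypothetical and are not part of the library, whose internals may differ.

```python
import math
import random


def log_prob(x):
    """Unnormalized log-density of a standard normal target."""
    return -0.5 * x * x


def grad_log_prob(x):
    """Gradient of `log_prob` with respect to x."""
    return -x


def proposal_log_density(x_to, x_from, step_size):
    """Log-density (up to a constant) of proposing x_to from x_from."""
    mean = x_from + 0.5 * step_size * grad_log_prob(x_from)
    return -((x_to - mean) ** 2) / (2.0 * step_size)


def mala_step(x, step_size, rng):
    """One MALA transition: Langevin proposal plus Metropolis-Hastings test."""
    # Euler-Maruyama step of the Langevin diffusion (unit volatility assumed).
    mean = x + 0.5 * step_size * grad_log_prob(x)
    proposal = mean + math.sqrt(step_size) * rng.gauss(0.0, 1.0)
    # The proposal is asymmetric, so both directions enter the ratio.
    log_accept_ratio = (
        log_prob(proposal) - log_prob(x)
        + proposal_log_density(x, proposal, step_size)
        - proposal_log_density(proposal, x, step_size))
    # 1 - random() lies in (0, 1], which keeps the logarithm finite.
    if math.log(1.0 - rng.random()) < log_accept_ratio:
        return proposal, True
    return x, False


rng = random.Random(0)
x, draws, n_accepted = 1.0, [], 0
for i in range(6000):
    x, accepted = mala_step(x, 0.75, rng)
    if i >= 1000:  # discard warm-up iterations
        draws.append(x)
        n_accepted += accepted

mean = sum(draws) / len(draws)
std = math.sqrt(sum((d - mean) ** 2 for d in draws) / len(draws))
accept_rate = n_accepted / len(draws)
print('sample mean', mean)
print('sample standard deviation', std)
print('acceptance rate', accept_rate)
```

The Metropolis-Hastings test is what distinguishes MALA from an uncalibrated Langevin step: without it, the discretisation error would bias the chain away from the target.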
The one_step function can update multiple chains in parallel. It assumes
that all leftmost dimensions of current_state index independent chain states
(and are therefore updated independently). The output of
target_log_prob_fn(*current_state) should reduce log-probabilities across
all event dimensions. Slices along the leftmost (chain) dimensions may have
different target distributions; for example, current_state[0, :] could have a
different target distribution from current_state[1, :]. These semantics are
governed by target_log_prob_fn(*current_state). (The number of independent
chains is tf.size(target_log_prob_fn(*current_state)).)
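The chain-counting rule can be sketched with plain Python lists standing in for tensors. This is only an illustration under assumed shapes: the state has shape `[num_chains, event_size]`, and the hypothetical `target_log_prob` below sums over the event dimension so that exactly one log-probability per chain remains.

```python
import math


def target_log_prob(state):
    """Log-density of independent standard normals.

    `state` is a nested list of shape [num_chains, event_size]; the sum over
    the inner list reduces the event dimension, leaving one value per chain.
    """
    return [
        sum(-0.5 * x * x - 0.5 * math.log(2.0 * math.pi) for x in chain_state)
        for chain_state in state
    ]


# Four chains, each with a 3-dimensional event.
current_state = [
    [0.0, 0.0, 0.0],
    [1.0, 1.0, 1.0],
    [2.0, -1.0, 0.5],
    [0.1, 0.2, 0.3],
]

log_probs = target_log_prob(current_state)
num_chains = len(log_probs)  # plays the role of tf.size(...) in the text
print('number of chains', num_chains)
```

If the function did not reduce over the event dimension, it would return 12 values here and each coordinate would be treated as its own chain.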
Examples:
Simple chain with warm-up.
In this example we sample from a standard univariate normal
distribution using MALA with step_size equal to 0.75.
import jax
from tensorflow_probability.python.internal.backend.jax.compat import v2 as tf
import tensorflow_probability as tfp; tfp = tfp.substrates.jax
import numpy as np
import matplotlib.pyplot as plt

tfd = tfp.distributions
dtype = np.float32

# Target distribution is Standard Univariate Normal
target = tfd.Normal(loc=dtype(0), scale=dtype(1))

def target_log_prob(x):
  return target.log_prob(x)

# Define MALA sampler with `step_size` equal to 0.75
samples = tfp.mcmc.sample_chain(
    num_results=1000,
    current_state=dtype(1),
    kernel=tfp.mcmc.MetropolisAdjustedLangevinAlgorithm(
        target_log_prob_fn=target_log_prob,
        step_size=0.75),
    num_burnin_steps=500,
    trace_fn=None,
    # The JAX substrate expects a JAX PRNG key rather than an integer seed.
    seed=jax.random.PRNGKey(42))

sample_mean = tf.reduce_mean(samples, axis=0)
sample_std = tf.sqrt(
    tf.reduce_mean(
        tf.math.squared_difference(samples, sample_mean),
        axis=0))

print('sample mean', sample_mean)
print('sample standard deviation', sample_std)

plt.title('Traceplot')
# JAX arrays have no `.numpy()` method; convert with NumPy instead.
plt.plot(np.asarray(samples), 'b')
plt.xlabel('Iteration')
plt.ylabel('Position')
plt.show()
Sample from a 3-D Multivariate Normal distribution.
In this example we also consider a non-constant volatility function.
import jax
from tensorflow_probability.python.internal.backend.jax.compat import v2 as tf
import tensorflow_probability as tfp; tfp = tfp.substrates.jax
import numpy as np

tfd = tfp.distributions
dtype = np.float32

true_mean = dtype([0, 0, 0])
true_cov = dtype([[1, 0.25, 0.25], [0.25, 1, 0.25], [0.25, 0.25, 1]])
num_results = 500
num_chains = 500

# Target distribution is defined through the Cholesky decomposition
chol = tf.linalg.cholesky(true_cov)
target = tfd.MultivariateNormalTriL(loc=true_mean, scale_tril=chol)

# Here we define the volatility function to be non-constant
def volatility_fn(x):
  # Element-wise volatility that shrinks as |x| grows
  return 1. / (0.5 + 0.1 * tf.math.abs(x))

# Initial state of the chain
init_state = np.ones([num_chains, 3], dtype=dtype)

# Run MALA with normal proposal for `num_results` iterations for
# `num_chains` independent chains:
states = tfp.mcmc.sample_chain(
    num_results=num_results,
    current_state=init_state,
    kernel=tfp.mcmc.MetropolisAdjustedLangevinAlgorithm(
        target_log_prob_fn=target.log_prob,
        step_size=.1,
        volatility_fn=volatility_fn),
    num_burnin_steps=200,
    num_steps_between_results=1,
    trace_fn=None,
    # The JAX substrate expects a JAX PRNG key rather than an integer seed.
    seed=jax.random.PRNGKey(42))

sample_mean = tf.reduce_mean(states, axis=[0, 1])
x = (states - sample_mean)[..., tf.newaxis]
sample_cov = tf.reduce_mean(
    tf.matmul(x, tf.transpose(x, [0, 1, 3, 2])), [0, 1])

# JAX arrays have no `.numpy()` method; convert with NumPy instead.
print('sample mean', np.asarray(sample_mean))
print('sample covariance matrix', np.asarray(sample_cov))
References
[1]: Gareth Roberts and Jeffrey Rosenthal. Optimal Scaling of Discrete Approximations to Langevin Diffusions. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 60: 255-268, 1998. https://doi.org/10.1111/1467-9868.00123
[2]: T. Xifara et al. Langevin diffusions and the Metropolis-adjusted Langevin algorithm. arXiv preprint arXiv:1309.2983, 2013. https://arxiv.org/abs/1309.2983
Raises | |
---|---|
ValueError | if there isn't one step_size or a list with same length as current_state. |
TypeError | if volatility_fn is not callable. |
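The step_size rule in the table above can be illustrated with a small sketch. The helper `broadcast_step_sizes` is hypothetical and is not the library's validation code; it only assumes that a single step size is shared by every state part, while a list must supply exactly one entry per part.

```python
def broadcast_step_sizes(step_size, current_state):
    """Hypothetical helper: pair every state part with a step size."""
    if isinstance(current_state, (list, tuple)):
        state_parts = list(current_state)
    else:
        state_parts = [current_state]
    if isinstance(step_size, (list, tuple)):
        step_sizes = list(step_size)
    else:
        step_sizes = [step_size]
    # A single step size is shared by all state parts.
    if len(step_sizes) == 1:
        step_sizes = step_sizes * len(state_parts)
    if len(step_sizes) != len(state_parts):
        raise ValueError(
            'There should be exactly one `step_size` or it should have the '
            'same length as `current_state`.')
    return step_sizes


shared = broadcast_step_sizes(0.1, [1.0, 2.0])
per_part = broadcast_step_sizes([0.1, 0.2], [1.0, 2.0])
print(shared, per_part)
```

Passing three step sizes for a two-part state would raise ValueError in this sketch, which corresponds to the condition listed in the table.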
Attributes | |
---|---|
experimental_shard_axis_names | The shard axis names for members of the state. |
is_calibrated | Returns True if Markov chain converges to specified distribution. |
name | |
parallel_iterations | |
parameters | Return dict of __init__ arguments and their values. |
step_size | |
target_log_prob_fn | |
volatility_fn | |
Methods
bootstrap_results
bootstrap_results(
    init_state
)
Creates initial previous_kernel_results using a supplied state.
copy
copy(
    **override_parameter_kwargs
)
Non-destructively creates a deep copy of the kernel.
Args | |
---|---|
**override_parameter_kwargs | Python String/value dictionary of initialization arguments to override with new values. |
Returns | |
---|---|
new_kernel | TransitionKernel object of same type as self, initialized with the union of self.parameters and override_parameter_kwargs, with any shared keys overridden by the value of override_parameter_kwargs, i.e., dict(self.parameters, **override_parameter_kwargs). |
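The merge rule `dict(self.parameters, **override_parameter_kwargs)` can be seen in isolation with a toy class. `ToyKernel` below is a hypothetical stand-in that only records its constructor arguments; it is not a real TransitionKernel and assumes nothing about the library beyond the rule quoted above.

```python
class ToyKernel:
    """Hypothetical stand-in that stores its constructor arguments."""

    def __init__(self, step_size, parallel_iterations=10, name=None):
        self._parameters = dict(
            step_size=step_size,
            parallel_iterations=parallel_iterations,
            name=name)

    @property
    def parameters(self):
        # Return a fresh dict so callers cannot mutate the kernel's state.
        return dict(self._parameters)

    def copy(self, **override_parameter_kwargs):
        # Shared keys take the value from the overrides.
        merged = dict(self.parameters, **override_parameter_kwargs)
        return type(self)(**merged)


kernel = ToyKernel(step_size=0.75)
smaller = kernel.copy(step_size=0.1)
print(kernel.parameters)
print(smaller.parameters)
```

The original object keeps its step size of 0.75, which is the sense in which the copy is non-destructive.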
experimental_with_shard_axes
experimental_with_shard_axes(
    shard_axis_names
)
Returns a copy of the kernel with the provided shard axis names.
Args | |
---|---|
shard_axis_names | A structure of strings indicating the shard axis names for each component of this kernel's state. |
Returns | |
---|---|
A copy of the current kernel with the shard axis information. |
one_step
one_step(
    current_state, previous_kernel_results, seed=None
)
Runs one iteration of MALA.
Args | |
---|---|
current_state | Tensor or Python list of Tensors representing the current state(s) of the Markov chain(s). The first r dimensions index independent chains, r = tf.rank(target_log_prob_fn(*current_state)). |
previous_kernel_results | collections.namedtuple containing Tensors representing values from previous calls to this function (or from the bootstrap_results function). |
seed | PRNG seed; see tfp.random.sanitize_seed for details. |
Returns | |
---|---|
next_state | Tensor or Python list of Tensors representing the state(s) of the Markov chain(s) after taking exactly one step. Has same type and shape as current_state. |
kernel_results | collections.namedtuple of internal calculations used to advance the chain. |
Raises | |
---|---|
ValueError | if there isn't one step_size or a list with same length as current_state or diffusion_drift. |