Runs one step of the Replica Exchange Monte Carlo.
Inherits From: TransitionKernel
tfp.substrates.jax.mcmc.ReplicaExchangeMC(
    target_log_prob_fn,
    inverse_temperatures,
    make_kernel_fn,
    swap_proposal_fn=default_swap_proposal_fn(1.0),
    state_includes_replicas=False,
    untempered_log_prob_fn=None,
    tempered_log_prob_fn=None,
    validate_args=False,
    name=None
)
Replica Exchange Monte Carlo is a Markov chain Monte Carlo (MCMC) algorithm that is also known as Parallel Tempering. This algorithm takes multiple samples (from tempered distributions) in parallel, then swaps these samples according to the Metropolis-Hastings criterion. See also the review paper [1].
The `K` replicas are parameterized in terms of inverse temperatures, `(beta[0], beta[1], ..., beta[K-1])`. If the user provides `target_log_prob_fn`, then the `k`th replica samples from density `p_k(x)`, with `log(p_k(x)) = beta_k * target_log_prob(x)`.

In this case, geometrically decaying `beta` often works well. That is, with `R < 1`, we recommend trying `beta[k] = R^k` so that `1.0 = beta[0] > beta[1] > ... > 0`. See [2].
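As a minimal sketch of the geometric ladder recommended above, the following plain-Python helper builds `beta[k] = R^k`. The function name `geometric_inverse_temperatures` and its arguments are hypothetical and are not part of the library.

```python
def geometric_inverse_temperatures(num_replica, ratio):
    """Return [ratio**0, ratio**1, ..., ratio**(num_replica - 1)]."""
    if not 0.0 < ratio < 1.0:
        raise ValueError("ratio must lie strictly between 0 and 1")
    return [ratio ** k for k in range(num_replica)]


# Four replicas with R = 0.5: beta[0] is 1 and the values decay towards 0.
betas = geometric_inverse_temperatures(num_replica=4, ratio=0.5)
print(betas)  # [1.0, 0.5, 0.25, 0.125]
```

A smaller `R` spreads the replicas over a wider temperature range but leaves less overlap between neighbors, which tends to lower swap acceptance.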
The user can also provide two functions, `tempered_log_prob_fn` and `untempered_log_prob_fn`. In this case, the `k`th replica samples from density `p_k(x)` with `log(p_k(x)) = beta_k * tempered_log_prob_fn(x) + untempered_log_prob_fn(x)`.

In this case, `beta` may be zero, and one often sets `beta[-1]` to zero. This means the last replica samples using `untempered_log_prob_fn`. In the Bayesian setup, `untempered_log_prob_fn` will often be the log prior, and `tempered_log_prob_fn` the likelihood.
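The split into a tempered and an untempered term can be sketched in plain Python. The prior, the likelihood, and the observation value below are illustrative assumptions rather than anything prescribed by the library; only the formula `beta * tempered + untempered` comes from the description above.

```python
def log_prior(x):
    # Hypothetical untempered term: standard normal prior, up to a constant.
    return -0.5 * x * x


def log_likelihood(x, observation=2.0, noise_scale=0.5):
    # Hypothetical tempered term: Gaussian likelihood of a single observation.
    z = (observation - x) / noise_scale
    return -0.5 * z * z


def replica_log_prob(x, beta):
    # log(p_k(x)) = beta_k * tempered_log_prob_fn(x) + untempered_log_prob_fn(x)
    return beta * log_likelihood(x) + log_prior(x)


betas = [1.0, 0.5, 0.0]
x = 1.0
values = [replica_log_prob(x, beta) for beta in betas]
print(values)  # [-2.5, -1.5, -0.5]
```

With `beta[-1] == 0` the last value equals `log_prior(x)`, so that replica samples from the prior alone.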
In all cases,

- `beta[0] == 1` ==> First replica samples from the target density.
- `beta[k] < 1`, for `k = 1, ..., K-1` ==> Other replicas sample from "tempered" versions of target (peak is less high, valley less low). These distributions should allow easier exploration of separated modes.
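To make "peak is less high, valley less low" concrete, the sketch below measures the log-density gap between a mode and the valley of a bimodal target under each `beta`. The one-dimensional mixture and its scale are assumptions chosen for illustration.

```python
import math


def target_log_prob(x, scale=0.25):
    # Unnormalized equal-weight mixture of N(-1, scale) and N(+1, scale),
    # evaluated with a log-sum-exp for numerical stability.
    a = -0.5 * ((x + 1.0) / scale) ** 2
    b = -0.5 * ((x - 1.0) / scale) ** 2
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


betas = [1.0, 0.5, 0.25, 0.125]
# Log-density drop a chain must cross to move from the mode at x = 1 to the
# valley at x = 0, as seen by each tempered replica.
barrier = [beta * (target_log_prob(1.0) - target_log_prob(0.0))
           for beta in betas]
for beta, gap in zip(betas, barrier):
    print(f"beta={beta}: barrier={gap:.3f}")
```

The barrier shrinks in proportion to `beta`, which is why the low-`beta` replicas cross between modes more readily.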
By default, samples from adjacent replicas `i`, `i + 1` are used as proposals for each other in a Metropolis step. This allows the lower `beta` samples, which explore less dense areas of `p`, to eventually swap state with the `beta == 1` chain, allowing it to explore these new regions.

Samples from replica 0 are returned, and the others are discarded, unless `state_includes_replicas` is `True`.
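The swap step can be sketched with the textbook parallel-tempering acceptance rule for the `target_log_prob_fn` case. This is an illustration of the standard criterion, not a transcription of the library's internals, and the helper names `swap_log_accept_ratio` and `maybe_swap_adjacent` are hypothetical.

```python
import math
import random


def swap_log_accept_ratio(beta_i, beta_j, log_prob_i, log_prob_j):
    """Log Metropolis ratio for exchanging the states of replicas i and j.

    log_prob_i and log_prob_j are the untempered target log-probabilities of
    each replica's current state.
    """
    return (beta_i - beta_j) * (log_prob_j - log_prob_i)


def maybe_swap_adjacent(states, betas, i, target_log_prob, rng):
    """Propose swapping replicas i and i + 1; return (new_states, accepted)."""
    j = i + 1
    log_ratio = swap_log_accept_ratio(
        betas[i], betas[j],
        target_log_prob(states[i]), target_log_prob(states[j]))
    # Accept with probability min(1, exp(log_ratio)).
    accepted = rng.random() < math.exp(min(0.0, log_ratio))
    new_states = list(states)
    if accepted:
        new_states[i], new_states[j] = states[j], states[i]
    return new_states, accepted


def standard_normal_log_prob(x):
    return -0.5 * x * x


# The beta = 0.5 replica sits near the mode while the beta = 1 replica is in
# the tail, so the log ratio is positive and the swap is always accepted.
rng = random.Random(0)
new_states, accepted = maybe_swap_adjacent(
    states=[3.0, 0.5], betas=[1.0, 0.5], i=0,
    target_log_prob=standard_normal_log_prob, rng=rng)
print(new_states, accepted)  # [0.5, 3.0] True
```

The ratio follows from comparing the joint tempered density before and after the exchange: `beta_i * logp(x_j) + beta_j * logp(x_i) - beta_i * logp(x_i) - beta_j * logp(x_j)`.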
Examples
Sampling from the Standard Normal Distribution.
import jax
import numpy as np
from tensorflow_probability.python.internal.backend import jax as tf
import tensorflow_probability as tfp; tfp = tfp.substrates.jax
tfd = tfp.distributions

dtype = np.float32

target = tfd.Normal(loc=dtype(0), scale=dtype(1))

# Geometric decay is a good rule of thumb.
inverse_temperatures = 0.5**tf.range(4, dtype=dtype)

# If everything was Normal, step_size should be ~ sqrt(temperature).
step_size = 1.5 / tf.sqrt(inverse_temperatures)

def make_kernel_fn(target_log_prob_fn):
  return tfp.mcmc.HamiltonianMonteCarlo(
      target_log_prob_fn=target_log_prob_fn,
      step_size=step_size, num_leapfrog_steps=3)

remc = tfp.mcmc.ReplicaExchangeMC(
    target_log_prob_fn=target.log_prob,
    inverse_temperatures=inverse_temperatures,
    make_kernel_fn=make_kernel_fn)

def trace_swaps(unused_state, results):
  return (results.is_swap_proposed_adjacent,
          results.is_swap_accepted_adjacent)

samples, (is_swap_proposed_adjacent, is_swap_accepted_adjacent) = (
    tfp.mcmc.sample_chain(
        num_results=1000,
        current_state=1.0,
        kernel=remc,
        num_burnin_steps=500,
        trace_fn=trace_swaps,
        # Pass an explicit PRNG key, as the JAX substrate expects.
        seed=jax.random.PRNGKey(0))
)

# conditional_swap_prob[k] = P[ExchangeAccepted | ExchangeProposed],
# for the swap between replicas k and k+1.
conditional_swap_prob = (
    tf.reduce_sum(tf.cast(is_swap_accepted_adjacent, tf.float32), axis=0)
    /
    tf.reduce_sum(tf.cast(is_swap_proposed_adjacent, tf.float32), axis=0))
Sampling from a 2-D Mixture Normal Distribution.
import jax
import numpy as np
from tensorflow_probability.python.internal.backend import jax as tf
import tensorflow_probability as tfp; tfp = tfp.substrates.jax
import matplotlib.pyplot as plt
tfd = tfp.distributions

dtype = np.float32

target = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(probs=[0.5, 0.5]),
    components_distribution=tfd.MultivariateNormalDiag(
        loc=[[-1., -1], [1., 1.]],
        scale_diag=0.1*tf.ones([2, 2])))

inverse_temperatures = 0.2**tf.range(4, dtype=dtype)

# step_size must broadcast with all batch and event dimensions of target.
# Here, this means it must broadcast with:
#   [len(inverse_temperatures)] + target.event_shape
step_size = 0.075 / tf.reshape(tf.sqrt(inverse_temperatures), shape=(4, 1))

def make_kernel_fn(target_log_prob_fn):
  return tfp.mcmc.HamiltonianMonteCarlo(
      target_log_prob_fn=target_log_prob_fn,
      step_size=step_size, num_leapfrog_steps=3)

remc = tfp.mcmc.ReplicaExchangeMC(
    target_log_prob_fn=target.log_prob,
    inverse_temperatures=inverse_temperatures,
    make_kernel_fn=make_kernel_fn)

samples = tfp.mcmc.sample_chain(
    num_results=1000,
    # Start near the [1, 1] mode. Standard HMC would get stuck there.
    current_state=tf.ones(2, dtype=dtype),
    kernel=remc,
    trace_fn=None,
    num_burnin_steps=500,
    # Pass an explicit PRNG key, as the JAX substrate expects.
    seed=jax.random.PRNGKey(0))

plt.figure(figsize=(8, 8))
plt.xlim(-2, 2)
plt.ylim(-2, 2)
plt.plot(samples[:, 0], samples[:, 1], '.')
plt.show()
References
[1]: David J. Earl, Michael W. Deem. Parallel Tempering: Theory, Applications, and New Perspectives. https://arxiv.org/abs/physics/0508111

[2]: David A. Kofke. On the acceptance probability of replica-exchange Monte Carlo trials. J. of Chem. Phys. Vol. 117 No. 5.
Args | |
---|---|
`target_log_prob_fn` | Python callable which takes an argument like `current_state` (or `*current_state` if it's a list) and returns its (possibly unnormalized) log-density under the target distribution. Must be `None` if the pair `tempered_log_prob_fn`/`untempered_log_prob_fn` is provided. |
`inverse_temperatures` | `Tensor` of inverse temperatures to temper each replica. The leftmost dimension is `num_replica`, and the second through rightmost dimensions can provide different temperatures to different batch members, doing a left-justified broadcast. |
`make_kernel_fn` | Python callable which takes a `target_log_prob_fn` arg and returns a `tfp.mcmc.TransitionKernel` instance. |
`swap_proposal_fn` | Python callable which takes a number of replicas and returns `swaps`, a shape `[num_replica] + batch_shape` `Tensor`, where axis 0 indexes a permutation of `{0,..., num_replica-1}`, designating replicas to swap. |
`state_includes_replicas` | Boolean indicating whether the leftmost dimension of each state sample should index replicas. If `True`, the leftmost dimension of the `current_state` kwarg to `tfp.mcmc.sample_chain` will be interpreted as indexing replicas. |
`untempered_log_prob_fn` | Python callable which takes an argument like `current_state` (or `*current_state` if it's a list) and returns the (possibly unnormalized) untempered part of the log-density, i.e. the term that is not scaled by the inverse temperature. Must be `None` if `target_log_prob_fn` is provided. |
`tempered_log_prob_fn` | Optional Python callable with the same signature as `untempered_log_prob_fn`. Provide this arg if and only if `untempered_log_prob_fn` is provided. |
`validate_args` | Python `bool`, default `False`. When `True`, distribution parameters are checked for validity despite possibly degrading runtime performance. When `False`, invalid inputs may silently render incorrect outputs. |
`name` | Python `str` name prefixed to Ops created by this function. Default value: `None` (i.e., "remc_kernel"). |
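The `swaps` permutation returned by a `swap_proposal_fn` can be pictured with a small plain-Python sketch. It assumes that `swaps[k]` names the partner of replica `k` and that a replica mapped to itself is not swapped. It ignores batch dimensions, returns a list rather than a `Tensor`, and is not the library's `default_swap_proposal_fn`. The helper `even_odd_swaps` is hypothetical.

```python
def even_odd_swaps(num_replica, parity):
    """Pair up adjacent replicas, starting from index `parity` (0 or 1)."""
    swaps = list(range(num_replica))  # Identity: nobody swaps.
    for k in range(parity, num_replica - 1, 2):
        swaps[k], swaps[k + 1] = k + 1, k
    return swaps


print(even_odd_swaps(5, parity=0))  # [1, 0, 3, 2, 4]
print(even_odd_swaps(5, parity=1))  # [0, 2, 1, 4, 3]
```

Alternating the parity between steps lets every adjacent pair be proposed over time, which is one common way to schedule adjacent swaps.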
Raises | |
---|---|
`ValueError` | If `inverse_temperatures` doesn't have statically known 1D shape. |
`ValueError` | If a wrong combination of log prob functions is provided. |
Attributes | |
---|---|
`experimental_shard_axis_names` | The shard axis names for members of the state. |
`inverse_temperatures` | |
`is_calibrated` | Returns `True` if Markov chain converges to specified distribution. |
`make_kernel_fn` | |
`name` | |
`parameters` | Return `dict` of `__init__` arguments and their values. |
`swap_proposal_fn` | |
`target_log_prob_fn` | |
`tempered_log_prob_fn` | |
`untempered_log_prob_fn` | |
`validate_args` | |
Methods
bootstrap_results
bootstrap_results(
    init_state
)
Returns an object with the same type as returned by `one_step`.
Args | |
---|---|
`init_state` | `Tensor` or Python `list` of `Tensor`s representing the initial state(s) of the Markov chain(s). |
Returns | |
---|---|
`kernel_results` | A (possibly nested) `tuple`, `namedtuple` or `list` of `Tensor`s representing internal calculations made within this function. This includes replica states. |
copy
copy(
    **override_parameter_kwargs
)
Non-destructively creates a deep copy of the kernel.
Args | |
---|---|
`**override_parameter_kwargs` | Python String/value `dictionary` of initialization arguments to override with new values. |
Returns | |
---|---|
`new_kernel` | `TransitionKernel` object of same type as `self`, initialized with the union of `self.parameters` and `override_parameter_kwargs`, with any shared keys overridden by the value of `override_parameter_kwargs`, i.e., `dict(self.parameters, **override_parameter_kwargs)`. |
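The merge rule `dict(self.parameters, **override_parameter_kwargs)` is ordinary Python dictionary behavior and can be checked without the library. The parameter values below are placeholders chosen for illustration.

```python
# Stand-in for kernel.parameters; the values are illustrative only.
parameters = {
    "inverse_temperatures": [1.0, 0.5],
    "validate_args": False,
    "name": None,
}
override_parameter_kwargs = {"validate_args": True}

# Shared keys take the override's value; all other keys carry over unchanged.
new_parameters = dict(parameters, **override_parameter_kwargs)
print(new_parameters["validate_args"])  # True
print(parameters["validate_args"])      # False
```

The original dictionary is left untouched, which mirrors the non-destructive behavior described for `copy`.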
experimental_with_shard_axes
experimental_with_shard_axes(
    shard_axes
)
Returns a copy of the kernel with the provided shard axis names.
Args | |
---|---|
`shard_axis_names` | A structure of strings indicating the shard axis names for each component of this kernel's state. |
Returns | |
---|---|
A copy of the current kernel with the shard axis information. |
num_replica
num_replica()
Integer (`Tensor`) number of replicas being tracked.
one_step
one_step(
    current_state, previous_kernel_results, seed=None
)
Takes one step of the TransitionKernel.
Args | |
---|---|
`current_state` | `Tensor` or Python `list` of `Tensor`s representing the current state(s) of the Markov chain(s). |
`previous_kernel_results` | A (possibly nested) `tuple`, `namedtuple` or `list` of `Tensor`s representing internal calculations made within the previous call to this function (or as returned by `bootstrap_results`). |
`seed` | PRNG seed; see `tfp.random.sanitize_seed` for details. |
Returns | |
---|---|
`next_state` | `Tensor` or Python `list` of `Tensor`s representing the next state(s) of the Markov chain(s). |
`kernel_results` | A (possibly nested) `tuple`, `namedtuple` or `list` of `Tensor`s representing internal calculations made within this function. This includes replica states. |
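The calling pattern of `bootstrap_results` followed by repeated `one_step` calls can be sketched with a toy kernel in plain Python. `ToyRandomWalkKernel` is a hypothetical stand-in that only mimics the shape of the interface; it is a random-walk Metropolis sampler, not replica exchange, and it uses Python's `random` module with integer seeds rather than the library's PRNG handling.

```python
import math
import random


class ToyRandomWalkKernel:
    """Hypothetical kernel exposing bootstrap_results and one_step."""

    def __init__(self, target_log_prob_fn, step_size):
        self.target_log_prob_fn = target_log_prob_fn
        self.step_size = step_size

    def bootstrap_results(self, init_state):
        # Same structure as the results returned by one_step.
        return {"log_prob": self.target_log_prob_fn(init_state),
                "is_accepted": True}

    def one_step(self, current_state, previous_kernel_results, seed=None):
        rng = random.Random(seed)
        proposal = current_state + self.step_size * rng.gauss(0.0, 1.0)
        proposal_log_prob = self.target_log_prob_fn(proposal)
        log_ratio = proposal_log_prob - previous_kernel_results["log_prob"]
        accepted = rng.random() < math.exp(min(0.0, log_ratio))
        if accepted:
            return proposal, {"log_prob": proposal_log_prob,
                              "is_accepted": True}
        return current_state, {"log_prob": previous_kernel_results["log_prob"],
                               "is_accepted": False}


kernel = ToyRandomWalkKernel(lambda x: -0.5 * x * x, step_size=1.0)
state = 1.0
results = kernel.bootstrap_results(state)
samples = []
for step in range(200):
    # Each call consumes the state and results produced by the previous call.
    state, results = kernel.one_step(state, results, seed=step)
    samples.append(state)
print(len(samples))  # 200
```

Threading `results` through the loop is what lets a kernel reuse cached quantities, such as the current log-probability, instead of recomputing them at every step.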