tfp.mcmc.ReplicaExchangeMC


Runs one step of the Replica Exchange Monte Carlo.

Inherits From: TransitionKernel

Replica Exchange Monte Carlo is a Markov chain Monte Carlo (MCMC) algorithm that is also known as Parallel Tempering. This algorithm takes multiple samples (from tempered distributions) in parallel, then swaps these samples according to the Metropolis-Hastings criterion. See also the review paper [1].

The K replicas are parameterized in terms of inverse_temperatures, (beta[0], beta[1], ..., beta[K-1]). If the target distribution has probability density p(x), the kth replica has density p(x)**beta_k.

Typically beta[0] = 1.0 and 1.0 > beta[1] > beta[2] > ... > 0.0. A geometrically decaying beta is a good starting point.

  • beta[0] == 1 ==> The first replica samples from the target density, p.
  • beta[k] < 1, for k = 1, ..., K-1 ==> The other replicas sample from "flattened" versions of p (peaks are less high, valleys less low). These distributions are somewhat closer to a uniform on the support of p.
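
For intuition, tempering only rescales the target's log-density. A minimal sketch (the helper make_tempered_log_prob_fn is illustrative, not part of the API):

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

target = tfd.Normal(loc=0., scale=1.)

def make_tempered_log_prob_fn(beta):
  # log p_beta(x) = beta * log p(x), up to an additive normalizing constant.
  return lambda x: beta * target.log_prob(x)

# The beta == 0.25 replica sees a flattened (effectively wider) Normal, so its
# sampler wanders farther into the tails than the beta == 1 chain does.
coldest_log_prob_fn = make_tempered_log_prob_fn(1.0)
warmest_log_prob_fn = make_tempered_log_prob_fn(0.25)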

By default, samples from adjacent replicas i, i + 1 are used as proposals for each other in a Metropolis step. This allows the lower beta samples, which explore less dense areas of p, to eventually swap state with the beta == 1 chain, allowing it to explore these new regions.
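
Concretely, the Metropolis-Hastings acceptance probability for swapping the states of replicas i and j depends only on the inverse-temperature gap and the target log-density at the two current states. A minimal sketch of that ratio (not the kernel's internal implementation):

import tensorflow as tf

def swap_accept_prob(beta_i, beta_j, target_log_prob_x_i, target_log_prob_x_j):
  # Standard parallel-tempering swap ratio: a swap is favored when the state
  # of the hotter replica is also probable under the colder replica's
  # (sharper) density, and vice versa.
  log_accept_ratio = (
      (beta_i - beta_j) * (target_log_prob_x_j - target_log_prob_x_i))
  return tf.minimum(1., tf.exp(log_accept_ratio))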

Samples from replica 0 are returned, and the others are discarded.

Examples

Sampling from the Standard Normal Distribution.
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

dtype = tf.float32

target = tfd.Normal(loc=dtype(0), scale=dtype(1))

# Geometric decay is a good rule of thumb.
inverse_temperatures = 0.5**tf.range(4, dtype=dtype)

# If everything was Normal, step_size should be ~ sqrt(temperature).
step_size = 0.5 / tf.sqrt(inverse_temperatures)

def make_kernel_fn(target_log_prob_fn, seed):
  return tfp.mcmc.HamiltonianMonteCarlo(
      target_log_prob_fn=target_log_prob_fn,
      seed=seed, step_size=step_size, num_leapfrog_steps=3)

remc = tfp.mcmc.ReplicaExchangeMC(
    target_log_prob_fn=target.log_prob,
    inverse_temperatures=inverse_temperatures,
    make_kernel_fn=make_kernel_fn)

def trace_swaps(unused_state, results):
  return (results.is_swap_proposed_adjacent,
          results.is_swap_accepted_adjacent)

samples, (is_swap_proposed_adjacent, is_swap_accepted_adjacent) = (
    tfp.mcmc.sample_chain(
        num_results=1000,
        current_state=1.0,
        kernel=remc,
        num_burnin_steps=500,
        trace_fn=trace_swaps)
)

# conditional_swap_prob[k] = P[ExchangeAccepted | ExchangeProposed],
# for the swap between replicas k and k+1.
conditional_swap_prob = (
    tf.reduce_sum(tf.cast(is_swap_accepted_adjacent, tf.float32), axis=0)
    /
    tf.reduce_sum(tf.cast(is_swap_proposed_adjacent, tf.float32), axis=0))
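
As a rough mixing diagnostic for the returned replica-0 chain, an effective sample size can also be computed; what counts as a healthy swap rate or ESS is problem dependent.

# Effective sample size of the replica-0 samples; values far below
# num_results suggest poor mixing or overly aggressive temperature spacing.
ess = tfp.mcmc.effective_sample_size(samples)
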
Sampling from a 2-D Mixture Normal Distribution.
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
import matplotlib.pyplot as plt
tfd = tfp.distributions

dtype = tf.float32

target = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(probs=[0.5, 0.5]),
    components_distribution=tfd.MultivariateNormalDiag(
        loc=[[-1., -1], [1., 1.]],
        scale_identity_multiplier=[0.1, 0.1]))

inverse_temperatures = 0.5**tf.range(4, dtype=dtype)

# step_size must broadcast with all batch and event dimensions of target.
# Here, this means it must broadcast with:
#  [len(inverse_temperatures)] + target.event_shape
step_size = 0.5 / tf.reshape(tf.sqrt(inverse_temperatures), shape=(4, 1))

def make_kernel_fn(target_log_prob_fn, seed):
  return tfp.mcmc.HamiltonianMonteCarlo(
      target_log_prob_fn=target_log_prob_fn,
      seed=seed, step_size=step_size, num_leapfrog_steps=3)

remc = tfp.mcmc.ReplicaExchangeMC(
    target_log_prob_fn=target.log_prob,
    inverse_temperatures=inverse_temperatures,
    make_kernel_fn=make_kernel_fn)

samples = tfp.mcmc.sample_chain(
    num_results=1000,
    # Start near the [1, 1] mode.  Standard HMC would get stuck there.
    current_state=tf.ones(2, dtype=dtype),
    kernel=remc,
    trace_fn=None,
    num_burnin_steps=500)

plt.figure(figsize=(8, 8))
plt.xlim(-2, 2)
plt.ylim(-2, 2)
plt.plot(samples[:, 0], samples[:, 1], '.')
plt.show()
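
A quick sanity check that both modes are visited (the 0.0 threshold splits the two components of this particular mixture, so it is specific to this example):

# Fraction of samples with a negative first coordinate; a value near 0.5 means
# the chain moves between the two modes rather than sticking to one of them.
frac_left_mode = tf.reduce_mean(tf.cast(samples[:, 0] < 0., tf.float32))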

References

[1]: David J. Earl and Michael W. Deem. Parallel Tempering: Theory, Applications, and New Perspectives. https://arxiv.org/abs/physics/0508111

Args
target_log_prob_fn Python callable which takes an argument like current_state (or *current_state if it's a list) and returns its (possibly unnormalized) log-density under the target distribution.
inverse_temperatures Tensor of inverse temperatures to temper each replica. The leftmost dimension is the num_replica; the second through rightmost dimensions can provide different temperatures to different batch members, using a left-justified broadcast.
make_kernel_fn Python callable which takes target_log_prob_fn and seed args and returns a TransitionKernel instance.
swap_proposal_fn Python callable which takes a number of replicas and returns swaps, a shape [num_replica] + batch_shape Tensor, where axis 0 indexes a permutation of {0, ..., num_replica - 1} designating replicas to swap; see the sketch after this argument list.
seed Python integer to seed the random number generator. Default value: None (i.e., no seed).
validate_args Python bool, default False. When True, distribution parameters are checked for validity, despite possibly degrading runtime performance. When False, invalid inputs may silently render incorrect outputs.
name Python str name prefixed to Ops created by this function. Default value: None (i.e., "remc_kernel").
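
For illustration, a trivial swap_proposal_fn that never exchanges states could look like the sketch below. Ignoring extra positional and keyword arguments (e.g. a batch shape or seed) is an assumption about the calling convention; only the num_replica argument is documented above.

import tensorflow as tf

def no_swap_proposal_fn(num_replica, *args, **kwargs):
  # Identity permutation [0, 1, ..., num_replica - 1]: every replica is
  # "swapped" with itself, so no exchange ever takes place.
  return tf.range(num_replica)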

Raises
ValueError if inverse_temperatures doesn't have statically known 1D shape.

Attributes
inverse_temperatures

is_calibrated Returns True if the Markov chain converges to the specified distribution.

TransitionKernels which are "uncalibrated" are often calibrated by composing them with the tfp.mcmc.MetropolisHastings TransitionKernel; see the sketch after this attribute list.

make_kernel_fn

name

parameters Return dict of __init__ arguments and their values.
seed

swap_proposal_fn

target_log_prob_fn

validate_args
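
For example, an uncalibrated kernel such as tfp.mcmc.UncalibratedLangevin is typically made calibrated by wrapping it in tfp.mcmc.MetropolisHastings. A minimal sketch (the standard-Normal target here is only illustrative):

import tensorflow as tf
import tensorflow_probability as tfp

mala = tfp.mcmc.MetropolisHastings(
    inner_kernel=tfp.mcmc.UncalibratedLangevin(
        target_log_prob_fn=lambda x: -0.5 * x**2,
        step_size=0.1))
# mala.is_calibrated is True, while the bare UncalibratedLangevin kernel is not.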

Methods

bootstrap_results


Returns an object with the same type as returned by one_step.

Args
init_state Tensor or Python list of Tensors representing the initial state(s) of the Markov chain(s).

Returns
kernel_results A (possibly nested) tuple, namedtuple or list of Tensors representing internal calculations made within this function. This includes replica states.

num_replica


Integer (Tensor) number of replicas being tracked.

one_step


Takes one step of the TransitionKernel.

Args
current_state Tensor or Python list of Tensors representing the current state(s) of the Markov chain(s).
previous_kernel_results A (possibly nested) tuple, namedtuple or list of Tensors representing internal calculations made within the previous call to this function (or as returned by bootstrap_results).

Returns
next_state Tensor or Python list of Tensors representing the next state(s) of the Markov chain(s).
kernel_results A (possibly nested) tuple, namedtuple or list of Tensors representing internal calculations made within this function. This includes replica states.
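
Putting bootstrap_results and one_step together, a single manual transition might look like the sketch below (assumes eager execution and the remc kernel constructed in the examples above; whether one_step also takes a seed argument varies across TFP versions):

init_state = tf.constant(1.0)
kernel_results = remc.bootstrap_results(init_state)
# Advances every replica by one inner-kernel step, then proposes swaps;
# next_state is the updated state of the beta == 1 (replica 0) chain.
next_state, kernel_results = remc.one_step(init_state, kernel_results)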