
tfp.mcmc.ReplicaExchangeMC

Class ReplicaExchangeMC

Runs one step of the Replica Exchange Monte Carlo.

Inherits From: TransitionKernel

Defined in python/mcmc/replica_exchange_mc.py.

Replica Exchange Monte Carlo is a Markov chain Monte Carlo (MCMC) algorithm that is also known as Parallel Tempering. The algorithm runs multiple sampling chains at different temperatures in parallel and exchanges states between them according to the Metropolis-Hastings criterion.

The K replicas are parameterized in terms of inverse_temperatures, (beta[0], beta[1], ..., beta[K-1]). If the target distribution has probability density p(x), the kth replica has density p(x)**beta[k].

Typically beta[0] = 1.0, and 1.0 > beta[1] > beta[2] > ... > 0.0.

  • beta[0] == 1 ==> The first replica samples from the target density, p.
  • beta[k] < 1, for k = 1, ..., K-1 ==> Other replicas sample from "flattened" versions of p (peak is less high, valley less low). These distributions are somewhat closer to a uniform on the support of p; the sketch below makes this concrete.
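
A minimal numpy sketch of this flattening, using the unnormalized standard-normal log-density (illustrative only, not part of the API):

import numpy as np

# Unnormalized log-density of a standard normal target.
def log_p(x):
  return -0.5 * x**2

# Replica k uses beta[k] * log_p(x).  Smaller beta shrinks the log-density
# drop from the peak (x = 0) to the tail (x = 3), i.e. the tempered density
# is a flattened version of p.
for beta in [1.0, 0.3, 0.1]:
  drop = beta * (log_p(0.0) - log_p(3.0))
  print('beta = {:.1f}: log-density drop from x=0 to x=3 is {:.2f}'.format(beta, drop))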

Samples from adjacent replicas i, i + 1 are used as proposals for each other in a Metropolis step. This allows the lower beta samples, which explore less dense areas of p, to occasionally be used to help the beta == 1 chain explore new regions of the support.
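
For intuition, the standard parallel-tempering swap test compares the tempered densities of the two replicas at each other's states. Below is a minimal numpy sketch of that acceptance probability; the exact bookkeeping inside the kernel may differ:

import numpy as np

def swap_accept_prob(log_p_xi, log_p_xj, beta_i, beta_j):
  """Standard parallel-tempering acceptance probability for swapping the
  states x_i, x_j of replicas i and j, given log p(x) evaluated at each state.
  """
  log_ratio = (beta_i - beta_j) * (log_p_xj - log_p_xi)
  return min(1.0, np.exp(log_ratio))

# A cold replica (beta_i = 1) readily accepts a higher-density state
# found by a hot replica (beta_j = 0.3): the probability below is 1.0.
print(swap_accept_prob(log_p_xi=-4.5, log_p_xj=-0.5, beta_i=1.0, beta_j=0.3))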

Samples from replica 0 are returned, and the others are discarded.

Examples

Sampling from the Standard Normal Distribution.
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

dtype = np.float32

target = tfd.Normal(loc=dtype(0), scale=dtype(1))

def make_kernel_fn(target_log_prob_fn, seed):
  return tfp.mcmc.HamiltonianMonteCarlo(
      target_log_prob_fn=target_log_prob_fn,
      seed=seed, step_size=1.0, num_leapfrog_steps=3)

remc = tfp.mcmc.ReplicaExchangeMC(
    target_log_prob_fn=target.log_prob,
    inverse_temperatures=[1., 0.3, 0.1, 0.03],
    make_kernel_fn=make_kernel_fn,
    seed=42)

samples, _ = tfp.mcmc.sample_chain(
    num_results=1000,
    current_state=dtype(1),
    kernel=remc,
    num_burnin_steps=500,
    parallel_iterations=1)  # For determinism.

sample_mean = tf.reduce_mean(samples, axis=0)
sample_std = tf.sqrt(
    tf.reduce_mean(tf.squared_difference(samples, sample_mean),
                   axis=0))
with tf.Session() as sess:
  [sample_mean_, sample_std_] = sess.run([sample_mean, sample_std])

print('Estimated mean: {}'.format(sample_mean_))
print('Estimated standard deviation: {}'.format(sample_std_))
Sampling from a 2-D Mixture Normal Distribution.
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
import matplotlib.pyplot as plt
tfd = tfp.distributions

dtype = np.float32

target = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(probs=[0.5, 0.5]),
    components_distribution=tfd.MultivariateNormalDiag(
        loc=[[-1., -1], [1., 1.]],
        scale_identity_multiplier=[0.1, 0.1]))

def make_kernel_fn(target_log_prob_fn, seed):
  return tfp.mcmc.HamiltonianMonteCarlo(
      target_log_prob_fn=target_log_prob_fn,
      seed=seed, step_size=0.3, num_leapfrog_steps=3)

remc = tfp.mcmc.ReplicaExchangeMC(
    target_log_prob_fn=target.log_prob,
    inverse_temperatures=[1., 0.3, 0.1, 0.03, 0.01],
    make_kernel_fn=make_kernel_fn,
    seed=42)

samples, _ = tfp.mcmc.sample_chain(
    num_results=1000,
    # Start near the [1, 1] mode.  Standard HMC would get stuck there.
    current_state=np.ones(2, dtype=dtype),
    kernel=remc,
    num_burnin_steps=500,
    parallel_iterations=1)  # For determinism.

with tf.Session() as sess:
  samples_ = sess.run(samples)

plt.figure(figsize=(8, 8))
plt.xlim(-2, 2)
plt.ylim(-2, 2)
plt.plot(samples_[:, 0], samples_[:, 1], '.')
plt.show()

__init__

__init__(
    target_log_prob_fn,
    inverse_temperatures,
    make_kernel_fn,
    exchange_proposed_fn=default_exchange_proposed_fn(1.0),
    seed=None,
    name=None
)

Instantiates this object.

Args:

  • target_log_prob_fn: Python callable which takes an argument like current_state (or *current_state if it's a list) and returns its (possibly unnormalized) log-density under the target distribution.
  • inverse_temperatures: 1D Tensor of inverse temperatures, one for each replica's sampling. Must have statically known shape. inverse_temperatures[0] produces the states returned by samplers, and is typically == 1.
  • make_kernel_fn: Python callable which takes target_log_prob_fn and seed args and returns a TransitionKernel instance.
  • exchange_proposed_fn: Python callable which takes the number of replicas and returns combinations of replicas to propose for exchange; a sketch with a non-default value appears after the Raises section below.
  • seed: Python integer to seed the random number generator. Default value: None (i.e., no seed).
  • name: Python str name prefixed to Ops created by this function. Default value: None (i.e., "remc_kernel").

Raises:

  • ValueError: inverse_temperatures doesn't have statically known 1D shape.
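
To illustrate the remaining constructor arguments, here is a minimal sketch that swaps in tfp.mcmc.RandomWalkMetropolis for the inner kernel and passes an explicit exchange_proposed_fn built with default_exchange_proposed_fn. It assumes default_exchange_proposed_fn is exported from tfp.mcmc, as the default in the signature above suggests, and the prob_exchange value of 0.5 is purely illustrative:

import tensorflow_probability as tfp
tfd = tfp.distributions

target = tfd.Normal(loc=0., scale=1.)

# Any TransitionKernel factory works; here a random-walk Metropolis kernel
# is used instead of the HamiltonianMonteCarlo kernel from the examples.
def make_kernel_fn(target_log_prob_fn, seed):
  return tfp.mcmc.RandomWalkMetropolis(
      target_log_prob_fn=target_log_prob_fn, seed=seed)

remc = tfp.mcmc.ReplicaExchangeMC(
    target_log_prob_fn=target.log_prob,
    inverse_temperatures=[1., 0.3, 0.1],
    make_kernel_fn=make_kernel_fn,
    # Propose exchanges on roughly half of the steps instead of every step.
    exchange_proposed_fn=tfp.mcmc.default_exchange_proposed_fn(0.5),
    seed=42,
    name='my_remc_kernel')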

Properties

exchange_proposed_fn

inverse_temperatures

is_calibrated

Returns True if the Markov chain converges to the specified distribution.

TransitionKernels which are "uncalibrated" are often calibrated by composing them with the tfp.mcmc.MetropolisHastings TransitionKernel.
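
As a generic illustration of that pattern (not specific to ReplicaExchangeMC), an uncalibrated kernel can be wrapped like this:

import tensorflow_probability as tfp
tfd = tfp.distributions

target = tfd.Normal(loc=0., scale=1.)

# MetropolisHastings adds the accept/reject step that makes the
# uncalibrated proposal kernel target `target` exactly.
kernel = tfp.mcmc.MetropolisHastings(
    inner_kernel=tfp.mcmc.UncalibratedHamiltonianMonteCarlo(
        target_log_prob_fn=target.log_prob,
        step_size=1.0,
        num_leapfrog_steps=3))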

name

num_replica

parameters

Return dict of __init__ arguments and their values.

seed

target_log_prob_fn

Methods

bootstrap_results

bootstrap_results(init_state)

Returns an object with the same type as returned by one_step.

Args:

  • init_state: Tensor or Python list of Tensors representing the initial state(s) of the Markov chain(s).

Returns:

  • kernel_results: A (possibly nested) tuple, namedtuple or list of Tensors representing internal calculations made within this function. This includes the replica states.

one_step

one_step(
    current_state,
    previous_kernel_results
)

Takes one step of the TransitionKernel.

Args:

  • current_state: Tensor or Python list of Tensors representing the current state(s) of the Markov chain(s).
  • previous_kernel_results: A (possibly nested) tuple, namedtuple or list of Tensors representing internal calculations made within the previous call to this function (or as returned by bootstrap_results).

Returns:

  • next_state: Tensor or Python list of Tensors representing the next state(s) of the Markov chain(s).
  • kernel_results: A (possibly nested) tuple, namedtuple or list of Tensors representing internal calculations made within this function. This includes the replica states.
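
For completeness, here is a minimal sketch of driving the kernel by hand with bootstrap_results and one_step, reusing the setup from the standard-normal example above; in practice tfp.mcmc.sample_chain is usually preferable:

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

dtype = np.float32
target = tfd.Normal(loc=dtype(0), scale=dtype(1))

def make_kernel_fn(target_log_prob_fn, seed):
  return tfp.mcmc.HamiltonianMonteCarlo(
      target_log_prob_fn=target_log_prob_fn,
      seed=seed, step_size=1.0, num_leapfrog_steps=3)

remc = tfp.mcmc.ReplicaExchangeMC(
    target_log_prob_fn=target.log_prob,
    inverse_temperatures=[1., 0.3, 0.1],
    make_kernel_fn=make_kernel_fn,
    seed=42)

# Build a graph for three manual steps of the kernel.
state = tf.constant(dtype(1))
kernel_results = remc.bootstrap_results(state)
for _ in range(3):
  state, kernel_results = remc.one_step(state, kernel_results)

with tf.Session() as sess:
  state_ = sess.run(state)

print('State after 3 steps: {}'.format(state_))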