tfp.experimental.substrates.jax.mcmc.ReplicaExchangeMC

Runs one step of the Replica Exchange Monte Carlo.

Inherits From: TransitionKernel

tfp.experimental.substrates.jax.mcmc.ReplicaExchangeMC(
    target_log_prob_fn, inverse_temperatures, make_kernel_fn,
    swap_proposal_fn=default_swap_proposal_fn(1.0), seed=None, validate_args=False,
    name=None
)

Replica Exchange Monte Carlo is a Markov chain Monte Carlo (MCMC) algorithm that is also known as Parallel Tempering. This algorithm takes multiple samples (from tempered distributions) in parallel, then swaps these samples according to the Metropolis-Hastings criterion. See also the review paper [1].

The K replicas are parameterized in terms of inverse_temperatures, (beta[0], beta[1], ..., beta[K-1]). If the target distribution has probability density p(x), the kth replica has (unnormalized) density p(x)**beta_k.

Typically beta[0] = 1.0, and 1.0 > beta[1] > beta[2] > ... > 0.0. A geometrically decaying sequence of beta values is a good starting point.

  • beta[0] == 1 ==> The first replica samples from the target density, p.
  • beta[k] < 1, for k = 1, ..., K-1 ==> The other replicas sample from "flattened" versions of p (peaks are less high, valleys less low). These distributions are somewhat closer to a uniform distribution on the support of p.
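
As a rough illustration of the tempering described above, the sketch below builds a geometric ladder of inverse temperatures and tempers a log-density by multiplying it by beta. It uses only the Python standard library; `geometric_betas` and `tempered_log_prob` are hypothetical helper names for this example, not part of TFP.

```python
import math

def geometric_betas(num_replica, ratio=0.5):
    """Hypothetical helper: beta[k] = ratio**k, so beta[0] == 1."""
    return [ratio ** k for k in range(num_replica)]

def standard_normal_log_prob(x):
    """Log-density of a standard Normal, used here as the target p."""
    return -0.5 * x * x - 0.5 * math.log(2.0 * math.pi)

def tempered_log_prob(log_prob_fn, beta):
    """Hypothetical helper: log of the unnormalized density p(x)**beta."""
    return lambda x: beta * log_prob_fn(x)

betas = geometric_betas(4)
replica_log_probs = [
    tempered_log_prob(standard_normal_log_prob, b) for b in betas]

# Smaller beta gives a flatter density: the drop in log-density from the
# peak (x = 0) to x = 2 shrinks in proportion to beta.
drops = [lp(0.0) - lp(2.0) for lp in replica_log_probs]
```

For the standard Normal, the drop is 2 * beta for each replica, which is one way to see that the low-beta replicas move between regions more easily.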

By default, samples from adjacent replicas i, i + 1 are used as proposals for each other in a Metropolis step. This allows the lower beta samples, which explore less dense areas of p, to eventually swap state with the beta == 1 chain, allowing it to explore these new regions.
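
The swap step can be sketched with the standard parallel-tempering acceptance rule: for replicas i and j holding states x_i and x_j, the exchange is accepted with probability min(1, exp((beta_i - beta_j) * (log p(x_j) - log p(x_i)))). The code below is a standalone illustration in plain Python, not TFP's implementation; `swap_log_accept_ratio` and `maybe_swap` are hypothetical names.

```python
import math
import random

def swap_log_accept_ratio(log_prob_fn, beta_i, beta_j, x_i, x_j):
    """Log Metropolis ratio for exchanging x_i and x_j between two replicas.

    Replica i targets p(x)**beta_i and replica j targets p(x)**beta_j.
    """
    return (beta_i - beta_j) * (log_prob_fn(x_j) - log_prob_fn(x_i))

def maybe_swap(log_prob_fn, beta_i, beta_j, x_i, x_j, rng):
    """Returns (new_x_i, new_x_j, accepted) after one proposed swap."""
    log_ratio = swap_log_accept_ratio(log_prob_fn, beta_i, beta_j, x_i, x_j)
    # Accept with probability min(1, exp(log_ratio)).
    accepted = rng.random() < math.exp(min(0.0, log_ratio))
    return (x_j, x_i, True) if accepted else (x_i, x_j, False)

def log_p(x):
    return -0.5 * x * x  # Unnormalized standard Normal log-density.

rng = random.Random(0)
# The beta = 1 replica is in the tail (x = 3) and the beta = 0.25 replica is
# at the mode (x = 0), so the ratio exceeds 1 and the swap is always accepted.
new_cold, new_hot, accepted = maybe_swap(
    log_p, 1.0, 0.25, x_i=3.0, x_j=0.0, rng=rng)
```

The reverse exchange has the negated log ratio, so it is accepted only with probability exp(-3.375) in this example.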

Samples from replica 0 are returned, and the others are discarded.

Examples

Sampling from the Standard Normal Distribution.
import numpy as np
from tensorflow_probability.python.internal.backend import jax as tf
import tensorflow_probability as tfp; tfp = tfp.experimental.substrates.jax
tfd = tfp.distributions

dtype = tf.float32

target = tfd.Normal(loc=dtype(0), scale=dtype(1))

# Geometric decay is a good rule of thumb.
inverse_temperatures = 0.5**tf.range(4, dtype=dtype)

# If everything was Normal, step_size should be ~ sqrt(temperature).
step_size = 0.5 / tf.sqrt(inverse_temperatures)

def make_kernel_fn(target_log_prob_fn, seed):
  return tfp.mcmc.HamiltonianMonteCarlo(
      target_log_prob_fn=target_log_prob_fn,
      seed=seed, step_size=step_size, num_leapfrog_steps=3)

remc = tfp.mcmc.ReplicaExchangeMC(
    target_log_prob_fn=target.log_prob,
    inverse_temperatures=inverse_temperatures,
    make_kernel_fn=make_kernel_fn)

def trace_swaps(unused_state, results):
  return (results.is_swap_proposed_adjacent,
          results.is_swap_accepted_adjacent)

# sample_chain returns (states, trace); the trace is the tuple from trace_swaps.
samples, (is_swap_proposed_adjacent, is_swap_accepted_adjacent) = (
    tfp.mcmc.sample_chain(
        num_results=1000,
        current_state=1.0,
        kernel=remc,
        num_burnin_steps=500,
        trace_fn=trace_swaps)
)

# conditional_swap_prob[k] = P[ExchangeAccepted | ExchangeProposed],
# for the swap between replicas k and k+1.
conditional_swap_prob = (
    tf.reduce_sum(tf.cast(is_swap_accepted_adjacent, tf.float32), axis=0)
    /
    tf.reduce_sum(tf.cast(is_swap_proposed_adjacent, tf.float32), axis=0))

Sampling from a 2-D Mixture Normal Distribution.
import numpy as np
from tensorflow_probability.python.internal.backend import jax as tf
import tensorflow_probability as tfp; tfp = tfp.experimental.substrates.jax
import matplotlib.pyplot as plt
tfd = tfp.distributions

dtype = tf.float32

target = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(probs=[0.5, 0.5]),
    components_distribution=tfd.MultivariateNormalDiag(
        loc=[[-1., -1], [1., 1.]],
        scale_identity_multiplier=[0.1, 0.1]))

inverse_temperatures = 0.5**tf.range(4, dtype=dtype)

# step_size must broadcast with all batch and event dimensions of target.
# Here, this means it must broadcast with:
#  [len(inverse_temperatures)] + target.event_shape
step_size = 0.5 / tf.reshape(tf.sqrt(inverse_temperatures), shape=(4, 1))

def make_kernel_fn(target_log_prob_fn, seed):
  return tfp.mcmc.HamiltonianMonteCarlo(
      target_log_prob_fn=target_log_prob_fn,
      seed=seed, step_size=step_size, num_leapfrog_steps=3)

remc = tfp.mcmc.ReplicaExchangeMC(
    target_log_prob_fn=target.log_prob,
    inverse_temperatures=inverse_temperatures,
    make_kernel_fn=make_kernel_fn)

samples = tfp.mcmc.sample_chain(
    num_results=1000,
    # Start near the [1, 1] mode.  Standard HMC would get stuck there.
    current_state=tf.ones(2, dtype=dtype),
    kernel=remc,
    trace_fn=None,
    num_burnin_steps=500)

plt.figure(figsize=(8, 8))
plt.xlim(-2, 2)
plt.ylim(-2, 2)
plt.plot(samples[:, 0], samples[:, 1], '.')
plt.show()

References

[1]: David J. Earl and Michael W. Deem. Parallel Tempering: Theory, Applications, and New Perspectives. https://arxiv.org/abs/physics/0508111

Args:

  • target_log_prob_fn: Python callable which takes an argument like current_state (or *current_state if it's a list) and returns its (possibly unnormalized) log-density under the target distribution.
  • inverse_temperatures: Tensor of inverse temperatures used to temper each replica. The leftmost dimension has size num_replica; the second through rightmost dimensions can assign different temperatures to different batch members, using a left-justified broadcast.
  • make_kernel_fn: Python callable which takes target_log_prob_fn and seed args and returns a TransitionKernel instance.
  • swap_proposal_fn: Python callable which takes the number of replicas and returns swaps, a shape [num_replica] + batch_shape Tensor, where axis 0 indexes a permutation of {0,..., num_replica-1}, designating replicas to swap.
  • seed: Python integer to seed the random number generator. Default value: None (i.e., no seed).
  • validate_args: Python bool, default False. When True, distribution parameters are checked for validity despite possibly degrading runtime performance. When False, invalid inputs may silently render incorrect outputs.
  • name: Python str name prefixed to Ops created by this function. Default value: None (i.e., "remc_kernel").
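
To make the swap_proposal_fn contract more concrete, here is a minimal pure-Python sketch of one possible proposal: pair up adjacent replicas starting at an even or odd index, and express the pairing as a permutation of {0, ..., num_replica-1}. It ignores batch dimensions and returns a plain list rather than a Tensor; `adjacent_swaps` and its `parity` argument are hypothetical, and this is not claimed to match default_swap_proposal_fn in detail.

```python
def adjacent_swaps(num_replica, parity):
    """Hypothetical sketch: permutation pairing replicas (k, k + 1).

    parity=0 pairs (0, 1), (2, 3), ...; parity=1 pairs (1, 2), (3, 4), ...
    Unpaired replicas map to themselves.
    """
    swaps = list(range(num_replica))
    for k in range(parity, num_replica - 1, 2):
        swaps[k], swaps[k + 1] = k + 1, k
    return swaps

even = adjacent_swaps(4, parity=0)  # [1, 0, 3, 2]
odd = adjacent_swaps(4, parity=1)   # [0, 2, 1, 3]
```

Because each pair is exchanged symmetrically, applying the permutation twice returns every replica to its original index.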

Attributes:

  • inverse_temperatures
  • is_calibrated: Returns True if Markov chain converges to specified distribution.

    TransitionKernels which are "uncalibrated" are often calibrated by composing them with the tfp.mcmc.MetropolisHastings TransitionKernel.

  • make_kernel_fn

  • name

  • parameters: Return dict of __init__ arguments and their values.

  • seed

  • swap_proposal_fn

  • target_log_prob_fn

  • validate_args

Raises:

  • ValueError: if inverse_temperatures does not have a statically known 1-D shape.

Methods

bootstrap_results

bootstrap_results(
    init_state
)

Returns an object with the same type as returned by one_step.

Args:

  • init_state: Tensor or Python list of Tensors representing the initial state(s) of the Markov chain(s).

Returns:

  • kernel_results: A (possibly nested) tuple, namedtuple or list of Tensors representing internal calculations made within this function. This includes replica states.

num_replica

num_replica()

Integer (Tensor) number of replicas being tracked.

one_step

one_step(
    current_state, previous_kernel_results
)

Takes one step of the TransitionKernel.

Args:

  • current_state: Tensor or Python list of Tensors representing the current state(s) of the Markov chain(s).
  • previous_kernel_results: A (possibly nested) tuple, namedtuple or list of Tensors representing internal calculations made within the previous call to this function (or as returned by bootstrap_results).

Returns:

  • next_state: Tensor or Python list of Tensors representing the next state(s) of the Markov chain(s).
  • kernel_results: A (possibly nested) tuple, namedtuple or list of Tensors representing internal calculations made within this function. This includes replica states.