
tfp.mcmc.TransformedTransitionKernel

Class TransformedTransitionKernel

TransformedTransitionKernel applies a bijector to the MCMC's state space.

Inherits From: TransitionKernel

Defined in python/mcmc/transformed_kernel.py.

The TransformedTransitionKernel is a TransitionKernel that enables fitting a Bijector which serves to decorrelate the Markov chain Monte Carlo (MCMC) event dimensions, thus making the chain mix faster. This is particularly useful when the geometry of the target distribution is unfavorable; in such cases it may take many evaluations of the target_log_prob_fn for the chain to mix between faraway states.

The idea of training an affine function to decorrelate chain event dims was presented in [Parno and Marzouk (2014)][1]. Used in conjunction with the HamiltonianMonteCarlo TransitionKernel, the [Parno and Marzouk (2014)][1] idea is an instance of Riemannian manifold HMC [(Girolami and Calderhead, 2011)][2].

The TransformedTransitionKernel enables arbitrary bijective transformations of arbitrary TransitionKernels, e.g., one could use bijectors such as tfp.bijectors.Affine and tfp.bijectors.RealNVP with transition kernels such as tfp.mcmc.HamiltonianMonteCarlo and tfp.mcmc.RandomWalkMetropolis.
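
For instance, a target with positive support can be sampled by pairing RandomWalkMetropolis with an Exp bijector, so that the inner kernel proposes in unconstrained space. A minimal sketch; the Gamma target and chain settings here are illustrative assumptions, not part of the original example:

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
tfb = tfp.bijectors

# Illustrative target with positive support; the Exp bijector maps the
# inner kernel's unconstrained state into the target's support.
target = tfd.Gamma(concentration=2., rate=3.)

kernel = tfp.mcmc.TransformedTransitionKernel(
    inner_kernel=tfp.mcmc.RandomWalkMetropolis(
        target_log_prob_fn=target.log_prob),
    bijector=tfb.Exp())

samples, _ = tfp.mcmc.sample_chain(
    num_results=500,
    current_state=tf.ones([]),  # A point in the target's support.
    kernel=kernel,
    num_burnin_steps=200)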

Examples

RealNVP + HamiltonianMonteCarlo
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
tfb = tfp.bijectors

def make_likelihood(true_variances):
  return tfd.MultivariateNormalDiag(
      scale_diag=tf.sqrt(true_variances))

dims = 10
dtype = np.float32
true_variances = tf.linspace(dtype(1), dtype(3), dims)
likelihood = make_likelihood(true_variances)

realnvp_hmc = tfp.mcmc.TransformedTransitionKernel(
    inner_kernel=tfp.mcmc.HamiltonianMonteCarlo(
      target_log_prob_fn=likelihood.log_prob,
      step_size=0.5,
      num_leapfrog_steps=2),
    bijector=tfb.RealNVP(
      num_masked=2,
      shift_and_log_scale_fn=tfb.real_nvp_default_template(
          hidden_layers=[512, 512])))

states, kernel_results = tfp.mcmc.sample_chain(
    num_results=1000,
    current_state=tf.zeros(dims),
    kernel=realnvp_hmc,
    num_burnin_steps=500)

# Compute sample stats.
sample_mean = tf.reduce_mean(states, axis=0)
sample_var = tf.reduce_mean(
    tf.math.squared_difference(states, sample_mean),
    axis=0)

References

[1]: Matthew Parno and Youssef Marzouk. Transport map accelerated Markov chain Monte Carlo. arXiv preprint arXiv:1412.5492, 2014. https://arxiv.org/abs/1412.5492

[2]: Mark Girolami and Ben Calderhead. Riemann manifold Langevin and Hamiltonian Monte Carlo methods. Journal of the Royal Statistical Society: Series B, 2011. https://doi.org/10.1111/j.1467-9868.2010.00765.x

__init__

__init__(
    inner_kernel,
    bijector,
    name=None
)

Instantiates this object.

Args:

  • inner_kernel: TransitionKernel-like object which has a target_log_prob_fn argument.
  • bijector: tfp.bijectors.Bijector or list of tfp.bijectors.Bijectors. These bijectors use forward to map the inner_kernel state space to the state expected by inner_kernel.target_log_prob_fn. When a list is supplied, one bijector is applied to each part of the chain state (see the sketch below).
  • name: Python str name prefixed to Ops created by this function. Default value: None (i.e., "transformed_kernel").

Returns:

  • transformed_kernel: Instance of TransitionKernel which copies the input transition kernel, then modifies its target_log_prob_fn by applying the provided bijector(s).
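
When the chain state is a Python list of Tensors, a matching list of bijectors can be supplied, one per state part. A minimal sketch; the two-part target here is an illustrative assumption:

import tensorflow_probability as tfp
tfd = tfp.distributions
tfb = tfp.bijectors

# Hypothetical two-part state: an unconstrained location and a
# positive scale, each mapped by its own bijector.
def target_log_prob_fn(loc, scale):
  return (tfd.Normal(loc=0., scale=1.).log_prob(loc) +
          tfd.Gamma(concentration=2., rate=2.).log_prob(scale))

kernel = tfp.mcmc.TransformedTransitionKernel(
    inner_kernel=tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=target_log_prob_fn,
        step_size=0.1,
        num_leapfrog_steps=3),
    bijector=[tfb.Identity(),  # `loc` is already unconstrained.
              tfb.Exp()])      # `scale` lives on the positive reals.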

Properties

bijector

inner_kernel

is_calibrated

Returns True if the Markov chain converges to the specified distribution.

TransitionKernels which are "uncalibrated" are often calibrated by composing them with the tfp.mcmc.MetropolisHastings TransitionKernel.
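
For example, an uncalibrated proposal can first be calibrated with MetropolisHastings and then transformed. A minimal sketch; the Gamma target is an illustrative assumption:

import tensorflow_probability as tfp
tfd = tfp.distributions
tfb = tfp.bijectors

target = tfd.Gamma(concentration=2., rate=3.)

# Wrap the uncalibrated kernel in MetropolisHastings so the composed
# chain targets the desired distribution, then transform its state.
kernel = tfp.mcmc.TransformedTransitionKernel(
    inner_kernel=tfp.mcmc.MetropolisHastings(
        tfp.mcmc.UncalibratedRandomWalk(
            target_log_prob_fn=target.log_prob)),
    bijector=tfb.Exp())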

name

parameters

Return dict of __init__ arguments and their values.

Methods

bootstrap_results

bootstrap_results(
    init_state=None,
    transformed_init_state=None
)

Returns an object with the same type as returned by one_step.

Unlike other TransitionKernels, TransformedTransitionKernel.bootstrap_results has the option of initializing the TransformedTransitionKernelResults from either an initial state (e.g., requiring computation of bijector.inverse(init_state)) or directly from transformed_init_state, i.e., a Tensor or list of Tensors which is interpreted as the bijector.inverse-transformed state.

Args:

  • init_state: Tensor or Python list of Tensors representing the state(s) of the Markov chain(s). Must specify init_state or transformed_init_state but not both.
  • transformed_init_state: Tensor or Python list of Tensors representing the state(s) of the Markov chain(s) in the transformed space, i.e., the image of the initial state under bijector.inverse. Must specify init_state or transformed_init_state but not both.

Returns:

  • kernel_results: A (possibly nested) tuple, namedtuple or list of Tensors representing internal calculations made within this function.

Raises:

  • ValueError: if the inner_kernel results don't contain the member "target_log_prob".

Examples

To use transformed_init_state in the context of tfp.mcmc.sample_chain, you need to explicitly pass previous_kernel_results, e.g.,

transformed_kernel = tfp.mcmc.TransformedTransitionKernel(...)
init_state = ...        # Doesn't matter.
transformed_init_state = ... # Does matter.
results, _ = tfp.mcmc.sample_chain(
    num_results=...,
    current_state=init_state,
    previous_kernel_results=transformed_kernel.bootstrap_results(
        transformed_init_state=transformed_init_state),
    kernel=transformed_kernel)
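
A concrete, runnable variant of the same pattern; the Gamma target and chain settings are illustrative assumptions:

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
tfb = tfp.bijectors

target = tfd.Gamma(concentration=2., rate=3.)
transformed_kernel = tfp.mcmc.TransformedTransitionKernel(
    inner_kernel=tfp.mcmc.RandomWalkMetropolis(
        target_log_prob_fn=target.log_prob),
    bijector=tfb.Exp())

# Initialize directly in the transformed (unconstrained) space.
results, _ = tfp.mcmc.sample_chain(
    num_results=500,
    current_state=tf.ones([]),  # Ignored; one_step reads the
                                # transformed state instead.
    previous_kernel_results=transformed_kernel.bootstrap_results(
        transformed_init_state=tf.zeros([])),
    kernel=transformed_kernel)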

one_step

one_step(
    current_state,
    previous_kernel_results
)

Runs one iteration of the Transformed Kernel.

Args:

  • current_state: Tensor or Python list of Tensors representing the current state(s) of the Markov chain(s), after application of bijector.forward. The first r dimensions index independent chains, where r = tf.rank(target_log_prob_fn(*current_state)). Note that inner_kernel.one_step does not actually use current_state; it instead takes previous_kernel_results.transformed_state as input, because TransformedTransitionKernel creates a copy of the input inner_kernel with a modified target_log_prob_fn which internally applies bijector.forward.
  • previous_kernel_results: collections.namedtuple containing Tensors representing values from previous calls to this function (or from the bootstrap_results function).

Returns:

  • next_state: Tensor or Python list of Tensors representing the state(s) of the Markov chain(s) after taking exactly one step. Has same type and shape as current_state.
  • kernel_results: collections.namedtuple of internal calculations used to advance the chain.
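
As a final sketch, the kernel can also be driven manually with bootstrap_results and one_step, outside tfp.mcmc.sample_chain; the Gamma target is again an illustrative assumption:

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
tfb = tfp.bijectors

target = tfd.Gamma(concentration=2., rate=3.)
kernel = tfp.mcmc.TransformedTransitionKernel(
    inner_kernel=tfp.mcmc.RandomWalkMetropolis(
        target_log_prob_fn=target.log_prob),
    bijector=tfb.Exp())

state = tf.ones([])  # A point in the target's support.
kernel_results = kernel.bootstrap_results(state)
for _ in range(10):
  state, kernel_results = kernel.one_step(state, kernel_results)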