
# tfp.substrates.numpy.mcmc.RandomWalkMetropolis


Runs one step of the Random Walk Metropolis (RWM) algorithm with a symmetric proposal.

Inherits From: `TransitionKernel`

Random Walk Metropolis is a gradient-free Markov chain Monte Carlo (MCMC) algorithm. Each iteration consists of a proposal-generating step, `proposal_state = current_state + perturb`, in which the current state is shifted by a random perturbation, followed by a Metropolis-Hastings accept/reject step. For more details, see Section 2.1 of Roberts and Rosenthal (2004).
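The two stages of an iteration, proposal and accept/reject, can be sketched in plain NumPy. This is a minimal illustration under stated assumptions: a single chain, a normal perturbation, and hypothetical helpers `rwm_step` and `standard_normal_log_prob` that are not TFP functions. The TFP kernel additionally handles batches of chains, list-valued states, and stateless seeds.

```python
import numpy as np


def rwm_step(current_state, target_log_prob_fn, scale, rng):
    """One Random Walk Metropolis step (hypothetical helper, not TFP code)."""
    # Proposal step: proposal_state = current_state + perturb, using a
    # symmetric normal perturbation with standard deviation `scale`.
    proposal_state = current_state + scale * rng.standard_normal(np.shape(current_state))
    # With a symmetric proposal the Metropolis-Hastings log acceptance ratio
    # reduces to the difference of target log-densities.
    log_accept_ratio = target_log_prob_fn(proposal_state) - target_log_prob_fn(current_state)
    # Accept with probability min(1, exp(log_accept_ratio)).
    accept = np.log(rng.uniform()) < log_accept_ratio
    return (proposal_state if accept else current_state), bool(accept)


def standard_normal_log_prob(x):
    # Unnormalized log-density of N(0, 1).
    return -0.5 * np.sum(x ** 2)


rng = np.random.default_rng(0)
state = np.float64(1.0)
samples = []
for _ in range(20000):
    state, _ = rwm_step(state, standard_normal_log_prob, scale=1.0, rng=rng)
    samples.append(state)
samples = np.array(samples[1000:])  # Discard the first 1000 draws as burn-in.
print('Estimated mean:', samples.mean(), 'estimated std:', samples.std())
```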

This class implements RWM with normal and uniform proposals. Alternatively, the user can supply any custom proposal-generating function.
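A normal and a uniform perturbation can be sketched as NumPy functions that follow the `new_state_fn(state_parts, seed)` calling convention described under the constructor arguments. The factory names below are hypothetical, and the integer seed is a simplification (TFP uses stateless seeds); TFP's own versions are `tfp.mcmc.random_walk_normal_fn` and `tfp.mcmc.random_walk_uniform_fn`.

```python
import numpy as np


def make_normal_new_state_fn(scale):
    """Hypothetical analogue of a normal random-walk proposal: x + scale * N(0, 1)."""
    def _fn(state_parts, seed):
        rng = np.random.default_rng(seed)
        return [sp + scale * rng.standard_normal(np.shape(sp)) for sp in state_parts]
    return _fn


def make_uniform_new_state_fn(scale):
    """Hypothetical analogue of a uniform random-walk proposal: x + U(-scale, scale)."""
    def _fn(state_parts, seed):
        rng = np.random.default_rng(seed)
        return [sp + rng.uniform(-scale, scale, size=np.shape(sp)) for sp in state_parts]
    return _fn


# A state made of two parts with different shapes.
state_parts = [np.zeros(3), np.ones((2, 2))]
normal_proposal = make_normal_new_state_fn(scale=0.5)(state_parts, seed=1)
uniform_proposal = make_uniform_new_state_fn(scale=0.5)(state_parts, seed=1)
print([p.shape for p in normal_proposal])  # [(3,), (2, 2)]
```

Both perturbations are symmetric around the current state, which is what allows the kernel to skip the Hastings correction.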

The function `one_step` can update multiple chains in parallel. It assumes that all leftmost dimensions of `current_state` index independent chain states (and are therefore updated independently). The output of `target_log_prob_fn(*current_state)` should sum log-probabilities across all event dimensions. Different chains, that is, slices obtained by indexing the leftmost dimensions, may have different target distributions; for example, `current_state[0, :]` could have a different target distribution from `current_state[1, :]`. These semantics are governed by `target_log_prob_fn(*current_state)`. (The number of independent chains is `tf.size(target_log_prob_fn(*current_state))`.)
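The batch semantics can be illustrated with a NumPy sketch in which `target_log_prob_fn` returns one log-probability per chain and the accept/reject decision is made elementwise. The three per-chain targets `N(loc[i], 1)` are an assumption chosen for illustration; this is not the TFP implementation.

```python
import numpy as np

num_chains = 3
# Hypothetical per-chain targets: chain i targets N(loc[i], 1).
loc = np.array([-5.0, 0.0, 5.0])


def target_log_prob_fn(x):
    # x has shape [num_chains]; the result has shape [num_chains], so the
    # number of independent chains equals the size of this output.
    return -0.5 * (x - loc) ** 2


rng = np.random.default_rng(0)
state = np.zeros(num_chains)
draws = []
for step in range(20000):
    proposal = state + rng.standard_normal(num_chains)
    log_accept_ratio = target_log_prob_fn(proposal) - target_log_prob_fn(state)
    # Elementwise accept/reject: each chain decides independently.
    accept = np.log(rng.uniform(size=num_chains)) < log_accept_ratio
    state = np.where(accept, proposal, state)
    if step >= 2000:  # Discard burn-in.
        draws.append(state)
draws = np.array(draws)  # Shape [num_results, num_chains].
print('Per-chain means:', draws.mean(axis=0))
```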

#### Examples:

##### Sampling from the Standard Normal Distribution.
```python
import numpy as np
from tensorflow_probability.python.internal.backend.numpy.compat import v2 as tf
import tensorflow_probability as tfp; tfp = tfp.substrates.numpy
tfd = tfp.distributions

dtype = np.float32

target = tfd.Normal(loc=dtype(0), scale=dtype(1))

samples = tfp.mcmc.sample_chain(
    num_results=1000,
    current_state=dtype(1),
    kernel=tfp.mcmc.RandomWalkMetropolis(target.log_prob),
    num_burnin_steps=500,
    trace_fn=None,
    seed=42)

sample_mean = tf.math.reduce_mean(samples, axis=0)
sample_std = tf.sqrt(
    tf.math.reduce_mean(
        tf.math.squared_difference(samples, sample_mean),
        axis=0))

print('Estimated mean: {}'.format(sample_mean))
print('Estimated standard deviation: {}'.format(sample_std))
```
##### Sampling from a 2-D Normal Distribution.
```python
import numpy as np
from tensorflow_probability.python.internal.backend.numpy.compat import v2 as tf
import tensorflow_probability as tfp; tfp = tfp.substrates.numpy
tfd = tfp.distributions

dtype = np.float32
true_mean = dtype([0, 0])
true_cov = dtype([[1, 0.5],
                  [0.5, 1]])
num_results = 500
num_chains = 100

# Target distribution is defined through the Cholesky decomposition `L`:
L = tf.linalg.cholesky(true_cov)
target = tfd.MultivariateNormalTriL(loc=true_mean, scale_tril=L)

# Initial state of each chain
init_state = np.ones([num_chains, 2], dtype=dtype)

# Run Random Walk Metropolis with normal proposal for `num_results`
# iterations for `num_chains` independent chains:
samples = tfp.mcmc.sample_chain(
    num_results=num_results,
    current_state=init_state,
    kernel=tfp.mcmc.RandomWalkMetropolis(target_log_prob_fn=target.log_prob),
    num_burnin_steps=200,
    num_steps_between_results=1,  # Thinning.
    trace_fn=None,
    seed=54)

# `samples` has shape [num_results, num_chains, 2].
sample_mean = tf.math.reduce_mean(samples, axis=0)
x = tf.squeeze(samples - sample_mean)
# Per-chain sample covariance, shape [num_chains, 2, 2].
sample_cov = tf.matmul(tf.transpose(x, [1, 2, 0]),
                       tf.transpose(x, [1, 0, 2])) / num_results

# Average the per-chain means (shape [2]) and covariances (shape [2, 2]).
mean_sample_mean = tf.math.reduce_mean(sample_mean, axis=0)
mean_sample_cov = tf.math.reduce_mean(sample_cov, axis=0)
x = tf.reshape(sample_cov - mean_sample_cov, [num_chains, 2 * 2])
cov_sample_cov = tf.reshape(tf.matmul(x, x, transpose_a=True) / num_chains,
                            shape=[2 * 2, 2 * 2])

print('Estimated mean: {}'.format(mean_sample_mean))
print('Estimated avg covariance: {}'.format(mean_sample_cov))
print('Estimated covariance of covariance: {}'.format(cov_sample_cov))
```
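The transpose-and-matmul step above computes one sample covariance per chain. A NumPy sketch with random stand-in data (an assumption; it does not reproduce the sampler's output) shows the shapes involved and checks the computation against an equivalent `np.einsum` and against `np.cov` for a single chain.

```python
import numpy as np

rng = np.random.default_rng(0)
num_results, num_chains = 500, 100
# Stand-in for `samples`: shape [num_results, num_chains, 2].
samples = rng.standard_normal((num_results, num_chains, 2))

sample_mean = samples.mean(axis=0)  # [num_chains, 2]
x = samples - sample_mean           # [num_results, num_chains, 2]

# Transpose-and-matmul form: [C, 2, N] @ [C, N, 2] -> [C, 2, 2].
cov_matmul = np.matmul(np.transpose(x, [1, 2, 0]), np.transpose(x, [1, 0, 2])) / num_results

# Equivalent einsum: for each chain c, sum x[n, c, i] * x[n, c, j] over results n.
cov_einsum = np.einsum('nci,ncj->cij', x, x) / num_results

print(cov_matmul.shape)  # (100, 2, 2)
```

Dividing by `num_results` (rather than `num_results - 1`) gives the biased estimator, which matches `np.cov(..., bias=True)`.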
##### Sampling from the Standard Normal Distribution using Cauchy proposal.
```python
import numpy as np
from tensorflow_probability.python.internal.backend.numpy.compat import v2 as tf
import tensorflow_probability as tfp; tfp = tfp.substrates.numpy
tfd = tfp.distributions

dtype = np.float32
num_burnin_steps = 500
num_chain_results = 1000

def cauchy_new_state_fn(scale, dtype):
  cauchy = tfd.Cauchy(loc=dtype(0), scale=dtype(scale))
  def _fn(state_parts, seed):
    next_state_parts = []
    # Derive one independent seed per state part.
    part_seeds = tfp.random.split_seed(
        seed, n=len(state_parts), salt='rwmcauchy')
    for sp, ps in zip(state_parts, part_seeds):
      # Symmetric Cauchy perturbation centered at the current state part.
      next_state_parts.append(sp + cauchy.sample(
          sample_shape=sp.shape, seed=ps))
    return next_state_parts
  return _fn

target = tfd.Normal(loc=dtype(0), scale=dtype(1))

samples = tfp.mcmc.sample_chain(
    num_results=num_chain_results,
    num_burnin_steps=num_burnin_steps,
    current_state=dtype(1),
    kernel=tfp.mcmc.RandomWalkMetropolis(
        target.log_prob,
        new_state_fn=cauchy_new_state_fn(scale=0.5, dtype=dtype)),
    trace_fn=None,
    seed=42)

sample_mean = tf.math.reduce_mean(samples, axis=0)
sample_std = tf.sqrt(
    tf.math.reduce_mean(
        tf.math.squared_difference(samples, sample_mean),
        axis=0))

print('Estimated mean: {}'.format(sample_mean))
print('Estimated standard deviation: {}'.format(sample_std))
```
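A Cauchy perturbation is a valid choice for this kernel because it is symmetric: the density of moving from `x` to `x_new` equals the density of moving back. The NumPy sketch below checks this for one pair of points; `cauchy_log_density` is a hypothetical helper written out from the Cauchy density formula, and the chosen values are arbitrary.

```python
import numpy as np


def cauchy_log_density(y, loc, scale):
    # Log-density of a Cauchy(loc, scale) distribution evaluated at y.
    z = (y - loc) / scale
    return -np.log(np.pi * scale * (1.0 + z ** 2))


x, x_new, scale = 0.3, 2.1, 0.5
# Forward proposal density q(x_new | x) and reverse density q(x | x_new).
forward = cauchy_log_density(x_new, loc=x, scale=scale)
reverse = cauchy_log_density(x, loc=x_new, scale=scale)
# For a symmetric proposal the Hastings correction log q(x | x_new) - log q(x_new | x)
# vanishes, so acceptance depends only on the target log-probabilities.
print('Hastings correction:', reverse - forward)
```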

Args
`target_log_prob_fn` Python callable which takes an argument like `current_state` (or `*current_state` if it's a list) and returns its (possibly unnormalized) log-density under the target distribution.
`new_state_fn` Python callable which takes a list of state parts and a seed; returns a same-type `list` of `Tensor`s, each being a perturbation of the input state parts. The perturbation distribution is assumed to be a symmetric distribution centered at the input state part. Default value: `None` which is mapped to `tfp.mcmc.random_walk_normal_fn()`.
`experimental_shard_axis_names` A structure of string names indicating how members of the state are sharded.
`name` Python `str` name prefixed to Ops created by this function. Default value: `None` (i.e., 'rwm_kernel').
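For a list-valued state, `target_log_prob_fn` receives the parts as separate positional arguments and `new_state_fn` maps a list of parts to a list of perturbed parts of the same shapes. The NumPy sketch below follows that contract for a single chain; the two-part state, both functions, and the integer seed are hypothetical stand-ins, not TFP code.

```python
import numpy as np

# Hypothetical two-part state: a scalar `mu` and a length-3 vector `z`.
current_state = [np.float64(0.0), np.zeros(3)]


def target_log_prob_fn(mu, z):
    # Called as target_log_prob_fn(*current_state); sums log-probabilities over
    # all parts and event dimensions, giving one value for this single chain.
    return -0.5 * mu ** 2 - 0.5 * np.sum((z - mu) ** 2)


def new_state_fn(state_parts, seed):
    # Takes a list of state parts and a seed; returns a list of the same length
    # whose entries are symmetric perturbations with the input shapes.
    rng = np.random.default_rng(seed)
    return [sp + 0.1 * rng.standard_normal(np.shape(sp)) for sp in state_parts]


proposed = new_state_fn(current_state, seed=7)
print(target_log_prob_fn(*current_state), target_log_prob_fn(*proposed))
```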

Raises
`ValueError` if there isn't one `scale` or a list with the same length as `current_state`.

Attributes
`experimental_shard_axis_names` The shard axis names for members of the state.
`is_calibrated` Returns `True` if the Markov chain converges to the specified distribution.

`TransitionKernel`s which are "uncalibrated" are often calibrated by composing them with the `tfp.mcmc.MetropolisHastings` `TransitionKernel`.

`name`

`new_state_fn`

`parameters` Return `dict` of `__init__` arguments and their values.
`target_log_prob_fn`
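The difference between an uncalibrated kernel and one calibrated by a Metropolis-Hastings step can be seen in a NumPy sketch: a random walk that always moves to its proposal diffuses without bound, while the same proposal followed by an accept/reject step keeps the chains distributed as the target. This is a conceptual illustration with an assumed standard normal target, not the TFP implementation of either kernel.

```python
import numpy as np

rng = np.random.default_rng(0)
num_chains, num_steps, scale = 2000, 500, 1.0


def log_prob(x):
    # Unnormalized standard normal target, evaluated per chain.
    return -0.5 * x ** 2


uncalibrated = np.zeros(num_chains)
calibrated = np.zeros(num_chains)
for _ in range(num_steps):
    # Uncalibrated: always move to the proposal, ignoring the target.
    uncalibrated = uncalibrated + scale * rng.standard_normal(num_chains)
    # Calibrated: the same kind of proposal followed by a Metropolis-Hastings
    # accept/reject step, which makes N(0, 1) the stationary distribution.
    proposal = calibrated + scale * rng.standard_normal(num_chains)
    accept = np.log(rng.uniform(size=num_chains)) < log_prob(proposal) - log_prob(calibrated)
    calibrated = np.where(accept, proposal, calibrated)

print('Across-chain variance, uncalibrated:', uncalibrated.var())
print('Across-chain variance, calibrated:', calibrated.var())
```

The uncalibrated chains' variance grows roughly linearly with the number of steps, whereas the calibrated chains stay close to the target variance of 1.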

## Methods

### `bootstrap_results`


Creates initial `previous_kernel_results` using a supplied `state`.

### `copy`


Non-destructively creates a deep copy of the kernel.

Args
`**override_parameter_kwargs` Python string/value `dict` of initialization arguments to override with new values.

Returns
`new_kernel` `TransitionKernel` object of the same type as `self`, initialized with the union of `self.parameters` and `override_parameter_kwargs`, with any shared keys overridden by the value of `override_parameter_kwargs`, i.e., `dict(self.parameters, **override_parameter_kwargs)`.
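The merge rule `dict(self.parameters, **override_parameter_kwargs)` is ordinary Python dictionary semantics. The sketch below uses placeholder dictionaries (hypothetical values, not a real kernel's `parameters`) to show that shared keys take the override's value and that the original dictionary is left unchanged.

```python
# Hypothetical stand-ins for `self.parameters` and `override_parameter_kwargs`.
parameters = {'target_log_prob_fn': 'log_prob_a', 'new_state_fn': None, 'name': None}
override_parameter_kwargs = {'name': 'my_rwm'}

# Shared keys take the override's value; all other keys are kept as-is.
new_parameters = dict(parameters, **override_parameter_kwargs)
print(new_parameters)
# {'target_log_prob_fn': 'log_prob_a', 'new_state_fn': None, 'name': 'my_rwm'}
```

With a real kernel, the corresponding call would be `kernel.copy(name='my_rwm')`, which leaves `kernel` itself untouched.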

### `experimental_with_shard_axes`


Returns a copy of the kernel with the provided shard axis names.

Args
`shard_axis_names` a structure of strings indicating the shard axis names for each component of this kernel's state.

Returns
A copy of the current kernel with the shard axis information.

### `one_step`


Runs one iteration of Random Walk Metropolis, using `new_state_fn` to generate the proposal (a normal proposal by default).

Args
`current_state` `Tensor` or Python `list` of `Tensor`s representing the current state(s) of the Markov chain(s). The first `r` dimensions index independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`.
`previous_kernel_results` `collections.namedtuple` containing `Tensor`s representing values from previous calls to this function (or from the `bootstrap_results` function).
`seed` PRNG seed; see `tfp.random.sanitize_seed` for details.

Returns
`next_state` `Tensor` or Python `list` of `Tensor`s representing the state(s) of the Markov chain(s) after taking exactly one step. Has the same type and shape as `current_state`.
`kernel_results` `collections.namedtuple` of internal calculations used to advance the chain.

Raises
`ValueError` if there isn't one `scale` or a list with the same length as `current_state`.
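The way `bootstrap_results` and `one_step` fit together can be sketched with a toy NumPy kernel: `bootstrap_results` caches the target log-probability of the initial state, and each `one_step` call consumes the previous results and returns the next state together with fresh results. `ToyRWMKernel` and `ToyResults` are hypothetical; TFP's actual kernel-results namedtuples have different names and more fields, and this sketch assumes each chain's state is a scalar.

```python
import collections

import numpy as np

# Hypothetical results container (not TFP's kernel-results type).
ToyResults = collections.namedtuple('ToyResults', ['target_log_prob', 'is_accepted'])


class ToyRWMKernel:
    """Minimal sketch of the bootstrap_results / one_step protocol (not TFP code)."""

    def __init__(self, target_log_prob_fn, scale=1.0):
        self.target_log_prob_fn = target_log_prob_fn
        self.scale = scale

    def bootstrap_results(self, init_state):
        # Cache the initial target log-prob so one_step need not recompute it.
        lp = self.target_log_prob_fn(init_state)
        return ToyResults(target_log_prob=lp, is_accepted=np.ones(np.shape(lp), dtype=bool))

    def one_step(self, current_state, previous_kernel_results, seed):
        rng = np.random.default_rng(seed)
        proposal = current_state + self.scale * rng.standard_normal(np.shape(current_state))
        proposal_lp = self.target_log_prob_fn(proposal)
        log_accept_ratio = proposal_lp - previous_kernel_results.target_log_prob
        is_accepted = np.log(rng.uniform(size=np.shape(log_accept_ratio))) < log_accept_ratio
        # Keep state and cached log-prob consistent, chain by chain.
        next_state = np.where(is_accepted, proposal, current_state)
        next_lp = np.where(is_accepted, proposal_lp, previous_kernel_results.target_log_prob)
        return next_state, ToyResults(target_log_prob=next_lp, is_accepted=is_accepted)


kernel = ToyRWMKernel(lambda x: -0.5 * x ** 2)
state = np.zeros(4)  # Four independent chains with scalar states.
results = kernel.bootstrap_results(state)
for step in range(100):
    state, results = kernel.one_step(state, results, seed=step)
print(state.shape, results.is_accepted.shape)  # (4,) (4,)
```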
