Runs one step of the Metropolis-Hastings algorithm.
Inherits From: TransitionKernel
tfp.mcmc.MetropolisHastings(
inner_kernel, name=None
)
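The Metropolis-Hastings algorithm is a Markov chain Monte Carlo (MCMC) technique that uses a proposal distribution, together with an accept/reject step, to produce a Markov chain whose stationary distribution is the target distribution.
Note: inner_kernel.one_step must return kernel_results as a collections.namedtuple which must:
- have a target_log_prob field,
- optionally have a log_acceptance_correction field, and
- have only fields which are Tensor-valued.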
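As a rough illustration of these field requirements, the plain-Python sketch below builds a results namedtuple and checks for the required target_log_prob field. The names MyKernelResults, SymmetricResults and meets_field_requirements are hypothetical. Python floats stand in for Tensors, and the check does not test the Tensor-valued requirement.

```python
import collections

# Hypothetical results type; only the field names matter here.
MyKernelResults = collections.namedtuple(
    "MyKernelResults", ["target_log_prob", "log_acceptance_correction"])

# A results type without the optional correction field also qualifies.
SymmetricResults = collections.namedtuple("SymmetricResults", ["target_log_prob"])


def meets_field_requirements(results):
    """Rough check of the namedtuple/field rules above (not TFP code).

    Tensor-valuedness is not checked; floats stand in for Tensors.
    """
    is_namedtuple = isinstance(results, tuple) and hasattr(results, "_fields")
    return is_namedtuple and "target_log_prob" in results._fields
```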
The Metropolis-Hastings log acceptance-probability is computed as:
log_accept_ratio = (current_kernel_results.target_log_prob
- previous_kernel_results.target_log_prob
+ current_kernel_results.log_acceptance_correction)
If current_kernel_results.log_acceptance_correction does not exist, it is presumed to be 0 (i.e., the proposal distribution is assumed to be symmetric).
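The ratio above, including the zero default for a missing correction, can be sketched in plain Python. Here namedtuples with float fields stand in for TFP's Tensor-valued kernel results, and the helpers log_accept_ratio and accept_prob are illustrative names, not TFP functions.

```python
import collections
import math

Results = collections.namedtuple(
    "Results", ["target_log_prob", "log_acceptance_correction"])
SymmetricResults = collections.namedtuple("SymmetricResults", ["target_log_prob"])


def log_accept_ratio(current_results, previous_results):
    # A missing log_acceptance_correction is presumed 0 (symmetric proposal).
    correction = getattr(current_results, "log_acceptance_correction", 0.0)
    return (current_results.target_log_prob
            - previous_results.target_log_prob
            + correction)


def accept_prob(log_ratio):
    # min(1, exp(log_ratio)), computed without overflow for large ratios.
    return math.exp(min(0.0, log_ratio))
```

For example, log_accept_ratio(Results(-1.0, 0.5), Results(-2.0, 0.0)) is 1.5, so accept_prob returns 1.0 and the proposal is always accepted.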
The most common use case for log_acceptance_correction is in the Metropolis-Hastings algorithm itself, i.e.,
accept_prob(x' | x) = min(1, p(x') / p(x) * (g(x | x') / g(x' | x)))
where
p represents the target distribution,
g represents the proposal (conditional) distribution,
x' is the proposed state, and
x is the current state.
The log of the parenthetical term, log g(x | x') - log g(x' | x), is the log_acceptance_correction.
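To see the correction at work, consider an asymmetric proposal. The sketch below assumes an unnormalized Exponential(1) target, p(x) proportional to exp(-x) for x > 0, and a log-normal random-walk proposal x' = x * exp(sigma * z) with z ~ N(0, 1); both choices, and the helper names, are illustrative. For this proposal the log correction simplifies to log x' - log x.

```python
import math


def log_p(x):
    # Unnormalized Exponential(1) log-density on x > 0.
    return -x if x > 0 else -math.inf


def log_g(x_new, x_old, sigma):
    # Log density of the log-normal proposal x_new = x_old * exp(sigma * z).
    z = (math.log(x_new) - math.log(x_old)) / sigma
    return -0.5 * z * z - math.log(x_new * sigma * math.sqrt(2.0 * math.pi))


def mh_accept_prob(x_new, x_old, sigma):
    # log_acceptance_correction = log g(x | x') - log g(x' | x).
    correction = log_g(x_old, x_new, sigma) - log_g(x_new, x_old, sigma)
    log_ratio = log_p(x_new) - log_p(x_old) + correction
    return math.exp(min(0.0, log_ratio)), correction


prob, correction = mh_accept_prob(2.0, 1.0, 0.5)
```

For x = 1, x' = 2 and sigma = 0.5, the correction is log 2 and the acceptance probability is exp(-1 + log 2) = 2/e (about 0.74). Dropping the correction would give exp(-1) (about 0.37) and bias the chain.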
The log_acceptance_correction need not correspond to a ratio of proposal densities. For example, in Hamiltonian Monte Carlo it is the difference between the initial and final kinetic energies of the auxiliary momentum variables.
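A minimal sketch of the HMC interpretation, under assumed choices (scalar state, unit mass, and the target -x - x**2 used in the example below), computes the correction as a kinetic-energy difference. The helpers leapfrog and hmc_proposal are hypothetical and are not TFP's UncalibratedHamiltonianMonteCarlo.

```python
def target_log_prob(x):
    return -x - x**2


def grad_target_log_prob(x):
    return -1.0 - 2.0 * x


def leapfrog(x, p, step_size, num_steps):
    # Half momentum step, alternating full steps, final half momentum step.
    p = p + 0.5 * step_size * grad_target_log_prob(x)
    for i in range(num_steps):
        x = x + step_size * p
        if i != num_steps - 1:
            p = p + step_size * grad_target_log_prob(x)
    p = p + 0.5 * step_size * grad_target_log_prob(x)
    return x, p


def hmc_proposal(x, p, step_size=0.1, num_leapfrog_steps=3):
    x_new, p_new = leapfrog(x, p, step_size, num_leapfrog_steps)
    # Kinetic energy K(p) = p**2 / 2; correction = K(initial) - K(final).
    log_acceptance_correction = 0.5 * p**2 - 0.5 * p_new**2
    return x_new, p_new, log_acceptance_correction


x0, p0 = 0.5, 1.0
x1, p1, corr = hmc_proposal(x0, p0)
log_accept_ratio = target_log_prob(x1) - target_log_prob(x0) + corr
```

Because leapfrog integration approximately conserves total energy, the large change in target_log_prob is nearly cancelled by the correction, and log_accept_ratio ends up close to 0.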
Examples
import tensorflow_probability as tfp

# Wrap an uncalibrated HMC proposal with a Metropolis-Hastings
# accept/reject step; the target log-density is -x - x**2.
hmc = tfp.mcmc.MetropolisHastings(
    tfp.mcmc.UncalibratedHamiltonianMonteCarlo(
        target_log_prob_fn=lambda x: -x - x**2,
        step_size=0.1,
        num_leapfrog_steps=3))
# ==> functionally equivalent to:
# hmc = tfp.mcmc.HamiltonianMonteCarlo(
#     target_log_prob_fn=lambda x: -x - x**2,
#     step_size=0.1,
#     num_leapfrog_steps=3)
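The wrapping pattern in this example, an uncalibrated kernel that only proposes plus an outer accept/reject step, can be imitated without TFP. The NumPy sketch below is an illustration under assumed interfaces: make_mh_step, propose_fn and lognormal_walk are hypothetical names, and this is not how TFP implements MetropolisHastings internally. The proposal reports its own log_acceptance_correction, and the wrapper supplies the Metropolis-Hastings test.

```python
import numpy as np


def make_mh_step(target_log_prob_fn, propose_fn):
    """Returns an MH step that wraps an 'uncalibrated' proposal function.

    propose_fn(x, rng) is an assumed interface that returns
    (proposed_x, log_acceptance_correction).
    """
    def step(x, log_prob_x, rng):
        x_new, correction = propose_fn(x, rng)
        log_prob_new = target_log_prob_fn(x_new)
        log_accept_ratio = log_prob_new - log_prob_x + correction
        # Accept with probability min(1, exp(log_accept_ratio)).
        if rng.uniform() < np.exp(min(0.0, log_accept_ratio)):
            return x_new, log_prob_new, True
        return x, log_prob_x, False
    return step


def lognormal_walk(x, rng, sigma=1.0):
    # Multiplicative random walk: x_new = x * exp(sigma * z), z ~ N(0, 1).
    z = rng.standard_normal()
    x_new = x * np.exp(sigma * z)
    # Hastings term log g(x | x_new) - log g(x_new | x) = log(x_new / x).
    return x_new, np.log(x_new) - np.log(x)


def exp1_log_prob(x):
    return -x  # Unnormalized Exponential(1); proposals stay positive.


rng = np.random.default_rng(0)
step = make_mh_step(exp1_log_prob, lognormal_walk)
x, log_prob_x = 1.0, exp1_log_prob(1.0)
samples = []
for _ in range(20000):
    x, log_prob_x, _ = step(x, log_prob_x, rng)
    samples.append(x)
# Expected to land near 1.0, the Exponential(1) mean.
print("mean estimate:", np.mean(samples[2000:]))
```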
Attributes | |
---|---|
experimental_shard_axis_names | The shard axis names for members of the state. |
inner_kernel | The TransitionKernel passed to the constructor, whose proposals are accepted or rejected. |
is_calibrated | Returns True if the Markov chain converges to the specified distribution. |
name | The name passed to the constructor. |
parameters | Returns a dict of __init__ arguments and their values. |
Methods
bootstrap_results
bootstrap_results(
init_state
)
Returns an object with the same type as returned by one_step.
Args | |
---|---|
init_state | Tensor or Python list of Tensors representing the initial state(s) of the Markov chain(s). |
Returns | |
---|---|
kernel_results | A (possibly nested) tuple, namedtuple or list of Tensors representing internal calculations made within this function. |
Raises | |
---|---|
ValueError | If the inner_kernel results do not contain the member "target_log_prob". |
copy
copy(
**override_parameter_kwargs
)
Non-destructively creates a deep copy of the kernel.
Args | |
---|---|
**override_parameter_kwargs | Python string/value dictionary of initialization arguments to override with new values. |
Returns | |
---|---|
new_kernel | TransitionKernel object of the same type as self, initialized with the union of self.parameters and override_parameter_kwargs, with any shared keys overridden by the value of override_parameter_kwargs, i.e., dict(self.parameters, **override_parameter_kwargs). |
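The override rule above is ordinary Python dict merging. In the sketch below, plain strings stand in for real constructor arguments; with an actual kernel the corresponding call would be something like mh.copy(name="my_mh").

```python
# Stand-in values; a real kernel's parameters hold its constructor arguments.
parameters = {"inner_kernel": "<inner kernel object>", "name": None}
override_parameter_kwargs = {"name": "my_mh"}

# Shared keys take the override's value; all other keys are kept as-is.
new_parameters = dict(parameters, **override_parameter_kwargs)
```

The original parameters dict is left untouched, which mirrors the non-destructive behavior of copy.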
experimental_with_shard_axes
experimental_with_shard_axes(
shard_axis_names
)
Returns a copy of the kernel with the provided shard axis names.
Args | |
---|---|
shard_axis_names | A structure of strings indicating the shard axis names for each component of this kernel's state. |
Returns | |
---|---|
A copy of the current kernel with the shard axis information. |
one_step
one_step(
current_state, previous_kernel_results, seed=None
)
Takes one step of the TransitionKernel.
Args | |
---|---|
current_state | Tensor or Python list of Tensors representing the current state(s) of the Markov chain(s). |
previous_kernel_results | A (possibly nested) tuple, namedtuple or list of Tensors representing internal calculations made within the previous call to this function (or as returned by bootstrap_results). |
seed | PRNG seed; see tfp.random.sanitize_seed for details. |
Returns | |
---|---|
next_state | Tensor or Python list of Tensors representing the next state(s) of the Markov chain(s). |
kernel_results | A (possibly nested) tuple, namedtuple or list of Tensors representing internal calculations made within this function. |
Raises | |
---|---|
ValueError | If the inner_kernel results do not contain the member "target_log_prob". |
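The bootstrap_results/one_step calling pattern can be illustrated with a toy, pure-Python kernel. ToyRandomWalkMH and ToyResults are hypothetical stand-ins that use Python floats instead of Tensors and random.Random instead of TFP's seed handling. The point is only the loop structure: bootstrap once, then feed each step's kernel_results into the next call.

```python
import collections
import math
import random

# Hypothetical results type: only the fields this toy kernel needs.
ToyResults = collections.namedtuple(
    "ToyResults", ["target_log_prob", "is_accepted"])


class ToyRandomWalkMH:
    """Pure-Python stand-in for a TransitionKernel (not TFP code)."""

    def __init__(self, target_log_prob_fn, scale=1.0):
        self.target_log_prob_fn = target_log_prob_fn
        self.scale = scale

    def bootstrap_results(self, init_state):
        # Same type as the results returned by one_step.
        return ToyResults(self.target_log_prob_fn(init_state), False)

    def one_step(self, current_state, previous_kernel_results, seed=None):
        rng = random.Random(seed)
        proposed = current_state + self.scale * rng.gauss(0.0, 1.0)
        proposed_log_prob = self.target_log_prob_fn(proposed)
        # Symmetric Gaussian proposal, so log_acceptance_correction is 0.
        log_accept_ratio = (proposed_log_prob
                            - previous_kernel_results.target_log_prob)
        if rng.random() < math.exp(min(0.0, log_accept_ratio)):
            return proposed, ToyResults(proposed_log_prob, True)
        # On rejection, keep the state and reuse its cached log-density.
        return current_state, previous_kernel_results._replace(is_accepted=False)


# Usage: standard normal target; bootstrap once, then chain one_step calls.
kernel = ToyRandomWalkMH(lambda x: -0.5 * x * x)
state = 0.0
results = kernel.bootstrap_results(state)
samples = []
for step in range(5000):
    state, results = kernel.one_step(state, results, seed=step)
    samples.append(state)
```

In TFP itself, this loop is usually delegated to tfp.mcmc.sample_chain rather than written by hand.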