Gelman and Rubin (1992)'s potential scale reduction for chain convergence.
tfp.experimental.substrates.jax.mcmc.potential_scale_reduction(chains_states, independent_chain_ndims=1, split_chains=False, validate_args=False, name=None)
Given N > 1 states from each of C > 1 independent chains, the potential scale reduction factor, commonly referred to as R-hat, measures convergence of the chains (to the same target) by testing for equality of means.
Specifically, R-hat measures the degree to which variance (of the means)
between chains exceeds what one would expect if the chains were identically
distributed. See [Gelman and Rubin (1992)]; [Brooks and Gelman (1998)].
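The comparison R-hat makes can be illustrated in plain Python. This is a toy sketch under stated assumptions, not the TFP implementation. It draws independent Gaussian samples within each chain, although real MCMC draws are usually positively autocorrelated, which makes chain means noisier than W / N suggests. It then computes B, the sample variance of the chain means, and W, the average within-chain variance, and reports N * B / W. For independent draws from one shared distribution, a chain mean has variance of about W / N, so the ratio stays near 1. Chains that sit around different means push it far above 1. The helper name `between_within_ratio` is hypothetical.

```python
import random
import statistics


def between_within_ratio(chains):
    """Hypothetical helper: N * Var(chain means) / mean within-chain variance.

    `chains` is a list of C equal-length lists, each holding N scalar draws.
    """
    n = len(chains[0])
    # B: sample variance of the C chain means.
    between = statistics.variance([statistics.fmean(chain) for chain in chains])
    # W: average of the per-chain sample variances.
    within = statistics.fmean([statistics.variance(chain) for chain in chains])
    return n * between / within


rng = random.Random(0)
num_chains, num_draws = 100, 200

# All chains draw independently from the same N(0, 1) distribution.
agree = [[rng.gauss(0.0, 1.0) for _ in range(num_draws)] for _ in range(num_chains)]

# Half of the chains are stuck around a mean of 3 instead of 0.
disagree = [
    [rng.gauss(3.0 if c % 2 else 0.0, 1.0) for _ in range(num_draws)]
    for c in range(num_chains)
]

print(round(between_within_ratio(agree), 2))     # close to 1
print(round(between_within_ratio(disagree), 1))  # far above 1 (hundreds)
```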
- The initial state of the chains should be drawn from a distribution overdispersed with respect to the target.
- If all chains converge to the target, then as N --> infinity, R-hat --> 1. Before that, R-hat > 1 (except in pathological cases, e.g. if the chain paths were identical).
- The above holds for any number of chains C > 1. Increasing C does improve the effectiveness of the diagnostic.
- Sometimes, R-hat < 1.2 is used to indicate approximate convergence, but of course this is problem-dependent. See [Brooks and Gelman (1998)].
- R-hat only measures non-convergence of the mean. If higher moments or other statistics are of interest, a different diagnostic should be used. See [Brooks and Gelman (1998)].
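The last two points can be checked on tiny hand-built chains. This sketch uses the large-sample form of R-hat given later on this page, (W + B) / W. Here W is the average within-chain variance and B is the variance of the chain means, both computed as plain (population) variances. It is an illustration, not TFP's finite-sample estimator. Two chains with the same mean but different spreads score exactly 1, even though they clearly disagree. A shift in the mean, by contrast, pushes R-hat well past the common 1.2 rule of thumb. The name `rhat_limit` is hypothetical.

```python
import statistics


def rhat_limit(chains):
    """Hypothetical helper: large-sample R-hat = (W + B) / W for equal-length chains."""
    w = statistics.fmean([statistics.pvariance(chain) for chain in chains])
    b = statistics.pvariance([statistics.fmean(chain) for chain in chains])
    return (w + b) / w


# Same mean (0), different variances (1 vs 4): R-hat cannot tell them apart.
same_mean = [[-1.0, 1.0] * 50, [-2.0, 2.0] * 50]

# Same variance (1), different means (0 vs 2): R-hat flags the disagreement.
shifted_mean = [[-1.0, 1.0] * 50, [1.0, 3.0] * 50]

print(rhat_limit(same_mean))     # 1.0
print(rhat_limit(shifted_mean))  # 2.0, above the 1.2 rule of thumb
```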
chains_states: Tensor or Python list of Tensors representing the states of a Markov Chain at each result step. The ith state is assumed to have shape [Ni, Ci1, Ci2,...,CiD] + A. Dimension 0 indexes the Ni > 1 result steps of the Markov Chain. Dimensions 1 through D index the Ci1 x ... x CiD independent chains to be tested for convergence to the same target. The remaining dimensions, A, can have any shape (even empty).
independent_chain_ndims: Integer >= 1 giving the number of dimensions, from dim = 1 to dim = D, holding independent chain results to be tested for convergence.
split_chains: Python bool. If True, divide samples from each chain into first and second halves, treating these as separate chains. This makes R-hat more robust to non-stationary chains, and is recommended in [Vehtari et al.].
validate_args: Whether to add runtime checks of argument validity. If False, and arguments are incorrect, correct behavior is not guaranteed.
name: String name to prepend to created ops. Default: None.
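The split_chains option can be sketched in plain Python, assuming each chain is a list of scalar draws. The helper `split_halves` is hypothetical. When N is odd, it drops the final draw so both halves have equal length, which is a choice made for this sketch. R-hat is computed with the large-sample form (W + B) / W, not TFP's estimator. The two chains below drift upward in exactly the same way. This is the identical-paths case, so un-split R-hat is exactly 1. Comparing first halves against second halves exposes the drift.

```python
import statistics


def rhat_limit(chains):
    """Large-sample R-hat, (W + B) / W, for equal-length chains (not TFP's estimator)."""
    w = statistics.fmean([statistics.pvariance(chain) for chain in chains])
    b = statistics.pvariance([statistics.fmean(chain) for chain in chains])
    return (w + b) / w


def split_halves(chains):
    """Hypothetical helper: turn C chains of N draws into 2C chains of N // 2 draws."""
    half = len(chains[0]) // 2
    split = []
    for chain in chains:
        split.append(chain[:half])          # first half
        split.append(chain[half:2 * half])  # second half (drops the last draw if N is odd)
    return split


# Two identical, non-stationary chains: a steady upward drift from 0 to 10.
drift = [i / 10.0 for i in range(101)]  # N = 101 (odd)
chains = [drift, list(drift)]

print(rhat_limit(chains))                # 1.0: the chain means agree
print(rhat_limit(split_halves(chains)))  # about 4.0: the halves disagree
```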
Returns: Tensor or Python list of Tensors representing the R-hat statistic for the state(s). Same dtype as state, and shape equal to state.shape[1 + independent_chain_ndims:].
Raises: ValueError if independent_chain_ndims < 1.
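The shape rules above can be summarized with a small pure-Python helper. The name `rhat_shapes` is hypothetical. It only mirrors the documented contract: dimension 0 holds the N results, and the next independent_chain_ndims dimensions together index the chains. R-hat keeps the remaining trailing dimensions. The helper does not call TFP.

```python
from math import prod


def rhat_shapes(state_shape, independent_chain_ndims=1):
    """Hypothetical helper: (num_results, num_chains, rhat_shape) for one state."""
    if independent_chain_ndims < 1:
        raise ValueError("independent_chain_ndims must be >= 1")
    num_results = state_shape[0]
    # Dimensions 1 .. independent_chain_ndims all index independent chains.
    num_chains = prod(state_shape[1:1 + independent_chain_ndims])
    # R-hat is reported for every remaining position, i.e. shape[1 + ndims:].
    rhat_shape = tuple(state_shape[1 + independent_chain_ndims:])
    return num_results, num_chains, rhat_shape


print(rhat_shapes((1000, 10, 2)))                               # (1000, 10, (2,))
print(rhat_shapes((1000, 4, 5, 3), independent_chain_ndims=2))  # (1000, 20, (3,))
print(rhat_shapes((500, 8)))                                    # (500, 8, ())
```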
Diagnosing convergence by monitoring 10 chains that each attempt to sample from a 2-variate normal.
    import jax
    import tensorflow_probability as tfp
    tfp = tfp.experimental.substrates.jax
    tfd = tfp.distributions

    target = tfd.MultivariateNormalDiag(scale_diag=[1., 2.])

    # TFP on JAX needs explicit PRNG keys: one for the initial state, one for MCMC.
    init_key, mcmc_key = jax.random.split(jax.random.PRNGKey(0))

    # Get 10 (2x) overdispersed initial states.
    initial_state = target.sample(10, seed=init_key) * 2.
    # initial_state.shape ==> (10, 2)

    # Get 1000 samples from the 10 independent chains.
    # Assumes a TFP version whose sample_chain accepts a `seed` argument.
    chains_states, _ = tfp.mcmc.sample_chain(
        num_burnin_steps=200,
        num_results=1000,
        current_state=initial_state,
        kernel=tfp.mcmc.HamiltonianMonteCarlo(
            target_log_prob_fn=target.log_prob,
            step_size=0.05,
            num_leapfrog_steps=20),
        seed=mcmc_key)
    # chains_states.shape ==> (1000, 10, 2)

    rhat = tfp.mcmc.potential_scale_reduction(
        chains_states, independent_chain_ndims=1)

    # The second dimension needed a longer burn-in.
    print(rhat)  # ==> [1.05, 1.3] (exact values depend on the seed)
To see why R-hat is reasonable, let X be a random variable drawn uniformly from the combined states (combined over all chains). Then, in the limit N, C --> infinity, with E and Var denoting expectation and variance,

R-hat = ( E[Var[X | chain]] + Var[E[X | chain]] ) / E[Var[X | chain]].
Using the law of total variance, the numerator is the variance of the combined states, and the denominator is the total variance minus the variance of the individual chain means. If the chains are all drawing from the same distribution, they will have the same mean, so the between-chain term Var[E[X | chain]] vanishes and the ratio is one.
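The limiting formula and the law-of-total-variance argument can be checked numerically. In this toy sketch, which is not TFP's finite-sample estimator, X is uniform over the pooled draws. Var[X | chain] is then each chain's population variance, and E[X | chain] is its mean. With equal-length chains, the numerator E[Var[X | chain]] + Var[E[X | chain]] equals the population variance of all pooled draws exactly. The chain values are made up for illustration, and `rhat_and_numerator` is a hypothetical name.

```python
import math
import statistics


def rhat_and_numerator(chains):
    """Return (R-hat, numerator) using the limiting formula, for equal-length chains."""
    within = statistics.fmean([statistics.pvariance(c) for c in chains])    # E[Var[X | chain]]
    between = statistics.pvariance([statistics.fmean(c) for c in chains])  # Var[E[X | chain]]
    numerator = within + between
    return numerator / within, numerator


chains = [
    [0.1, -0.4, 0.3, 0.8, -0.2],
    [1.2, 0.9, 1.5, 0.7, 1.1],  # this chain sits around a higher mean
    [0.0, 0.5, -0.3, 0.2, 0.4],
]
rhat, numerator = rhat_and_numerator(chains)

# Law of total variance: the numerator is the variance of the pooled draws.
pooled = [x for c in chains for x in c]
print(math.isclose(numerator, statistics.pvariance(pooled)))  # True
print(rhat > 1.0)                                             # True: the chain means differ
```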
References:
[Brooks and Gelman (1998)]: Stephen P. Brooks and Andrew Gelman. General Methods for Monitoring Convergence of Iterative Simulations. Journal of Computational and Graphical Statistics, 7(4), 1998.
[Gelman and Rubin (1992)]: Andrew Gelman and Donald B. Rubin. Inference from Iterative Simulation Using Multiple Sequences. Statistical Science, 7(4):457-472, 1992.
[Vehtari et al.]: Vehtari et al. Rank-normalization, folding, and localization: An improved Rhat for assessing convergence of MCMC.