Estimate a lower bound on effective sample size for each independent chain.
tfp.substrates.jax.mcmc.effective_sample_size(
states,
filter_threshold=0.0,
filter_beyond_lag=None,
filter_beyond_positive_pairs=False,
cross_chain_dims=None,
validate_args=False,
name=None
)
Roughly speaking, "effective sample size" (ESS) is the size of an iid sample with the same variance as state.
More precisely, given a stationary sequence of possibly correlated random variables X_1, X_2, ..., X_N, identically distributed, ESS is the number such that
Variance{ N**-1 * Sum{X_i} } = ESS**-1 * Variance{ X_1 }.
If the sequence is uncorrelated, ESS = N. If the sequence is positively auto-correlated, ESS will be less than N. If there are negative correlations, then ESS can exceed N.
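The definition can be checked by simulation. The sketch below is not part of this API: it assumes a stationary Gaussian AR(1) chain, X_t = rho * X_{t-1} + eps_t with standard normal eps_t, as a stand-in for MCMC output, draws many independent replicates with NumPy, and solves the definition for ESS = Variance{ X_1 } / Variance{ N**-1 * Sum{X_i} }. The helper simulate_ar1 is hypothetical.

```python
import numpy as np

def simulate_ar1(rho, num_steps, num_replicates, rng):
    """Hypothetical helper: independent stationary AR(1) chains, shape [num_steps, num_replicates]."""
    x = np.empty((num_steps, num_replicates))
    # Start in the stationary distribution, whose variance is 1 / (1 - rho**2).
    x[0] = rng.standard_normal(num_replicates) / np.sqrt(1.0 - rho**2)
    for t in range(1, num_steps):
        x[t] = rho * x[t - 1] + rng.standard_normal(num_replicates)
    return x

rng = np.random.default_rng(0)
n = 200
chains = simulate_ar1(rho=0.5, num_steps=n, num_replicates=4000, rng=rng)

var_x1 = np.var(chains)                  # Variance{ X_1 } (shared by every X_i)
var_mean = np.var(chains.mean(axis=0))   # Variance{ N**-1 * Sum{X_i} } across replicates
ess = var_x1 / var_mean                  # the definition, solved for ESS
print(f"N = {n}, ESS ~ {ess:.1f}")
```

With rho = 0.5 the printed ESS lands well below N (close to N / 3); a negative rho such as -0.5 pushes it above N.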
Some math shows that, with R_k the auto-correlation sequence, R_k := Covariance{X_1, X_{1+k}} / Variance{X_1}, we have
ESS(N) = N / [ 1 + 2 * ( (N - 1) / N * R_1 + ... + 1 / N * R_{N-1} ) ]
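As a worked check of this formula, the NumPy sketch below evaluates ESS(N) for an assumed auto-correlation sequence R_k = rho**k (the AR(1) case); the helper name ess_from_autocorr is hypothetical. It reproduces the three regimes stated above: ESS = N when rho = 0, ESS < N when rho > 0, and ESS > N when rho < 0.

```python
import numpy as np

def ess_from_autocorr(r):
    """Evaluate ESS(N) = N / [1 + 2 * Sum_{k=1}^{N-1} (N - k) / N * R_k].

    `r` holds R_0, R_1, ..., R_{N-1}; R_0 (always 1) does not enter the sum.
    """
    r = np.asarray(r, dtype=float)
    n = r.shape[0]
    k = np.arange(1, n)
    return n / (1.0 + 2.0 * np.sum((n - k) / n * r[k]))

n = 1000
lags = np.arange(n)
# Assumed AR(1)-style auto-correlation R_k = rho**k for three values of rho.
ess_uncorrelated = ess_from_autocorr(0.0 ** lags)   # R_k = 0 for k >= 1
ess_positive = ess_from_autocorr(0.5 ** lags)       # positively auto-correlated
ess_negative = ess_from_autocorr((-0.5) ** lags)    # alternating signs
print(ess_uncorrelated, ess_positive, ess_negative)
```

The three values come out to exactly 1000, roughly 334 and roughly 2996; for large N they approach N * (1 - rho) / (1 + rho).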
This function estimates the above by first estimating the auto-correlation. Since R_k must be estimated using only N - k samples, its estimate becomes progressively noisier for larger k. For this reason, the summation over R_k should be truncated at some number filter_beyond_lag < N. Besides a fixed filter_beyond_lag, this function provides two methods to choose the truncation point from the data:
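A minimal sketch of the estimate-then-truncate idea in plain NumPy, not this function's implementation: each R_k is estimated by averaging the N - k available lagged products (computed with an FFT), and only R_1, ..., R_{max_lag} enter the sum. The helper names and the max_lag parameter, which plays the role of filter_beyond_lag, are hypothetical; the AR(1) test chain with rho = 0.5 (true ESS close to N / 3) is an assumption.

```python
import numpy as np

def autocorr(x):
    """Estimate R_0, ..., R_{N-1} of a 1-D chain; R_k averages the N - k available products."""
    x = np.asarray(x, dtype=float)
    n = x.shape[0]
    xc = x - x.mean()
    # Zero-pad to length 2n so the circular FFT correlation equals the linear one.
    f = np.fft.rfft(xc, n=2 * n)
    acov = np.fft.irfft(f * np.conj(f), n=2 * n)[:n] / (n - np.arange(n))
    return acov / acov[0]

def ess_max_lag(x, max_lag):
    """ESS(N) keeping only the estimated R_1, ..., R_{max_lag}."""
    r = autocorr(x)
    n = r.shape[0]
    k = np.arange(1, min(max_lag, n - 1) + 1)
    return n / (1.0 + 2.0 * np.sum((n - k) / n * r[k]))

# Assumed test chain: stationary AR(1) with rho = 0.5, so the true ESS is close to N / 3.
rng = np.random.default_rng(1)
n, rho = 10_000, 0.5
x = np.empty(n)
x[0] = rng.standard_normal() / np.sqrt(1.0 - rho**2)
for t in range(1, n):
    x[t] = rho * x[t - 1] + rng.standard_normal()

ess = ess_max_lag(x, max_lag=20)
print(f"N = {n}, ESS ~ {ess:.0f}")
```

A small cutoff works here because R_k decays geometrically; for slowly mixing chains the cutoff has to be much larger, which is what the data-driven rules below address.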
filter_threshold -- since many MCMC methods generate chains where R_k > 0, a reasonable criterion is to truncate at the first index where the estimated auto-correlation drops below filter_threshold (with the default 0.0, where it becomes negative). This method does not estimate the ESS of super-efficient chains (where ESS > N) correctly.
filter_beyond_positive_pairs -- reversible MCMC chains produce an auto-correlation sequence with the property that pairwise sums of its elements are positive [Geyer][1], i.e. R_{2k} + R_{2k + 1} > 0 for k in {0, ..., N/2}. Deviations are only possible due to noise. This method truncates the auto-correlation sequence where the pairwise sums become non-positive.
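The two rules can be sketched as functions that return how many leading terms R_0, R_1, ... to keep, assuming R_k has already been estimated. This is an illustrative NumPy sketch with hypothetical function names, not the library's code. The example uses an assumed super-efficient sequence R_k = (-0.5)**k, whose true ESS is close to 3N: the threshold rule stops at R_1 and caps the estimate at N, while the positive-pairs rule keeps every pair.

```python
import numpy as np

def threshold_cutoff(r, threshold=0.0):
    """Keep R_0, ..., R_{m-1}, where m is the first k with R_k < threshold."""
    below = np.flatnonzero(np.asarray(r) < threshold)
    return int(below[0]) if below.size else len(r)

def positive_pairs_cutoff(r):
    """Keep pairs (R_{2k}, R_{2k+1}) up to the first pair whose sum is non-positive."""
    r = np.asarray(r, dtype=float)
    num_pairs = len(r) // 2
    pair_sums = r[0 : 2 * num_pairs : 2] + r[1 : 2 * num_pairs : 2]
    bad = np.flatnonzero(pair_sums <= 0.0)
    return 2 * int(bad[0]) if bad.size else 2 * num_pairs

def ess_kept(r, m):
    """ESS(N) with N = len(r), summing only R_1, ..., R_{m-1}."""
    r = np.asarray(r, dtype=float)
    n = len(r)
    k = np.arange(1, m)
    return n / (1.0 + 2.0 * np.sum((n - k) / n * r[k]))

r_anti = (-0.5) ** np.arange(1000)          # super-efficient chain
m_threshold = threshold_cutoff(r_anti)      # stops at R_1 < 0
m_pairs = positive_pairs_cutoff(r_anti)     # every pair sum is positive
print(ess_kept(r_anti, m_threshold), ess_kept(r_anti, m_pairs))
```

On a positively correlated sequence with noisy tail terms, such as [1.0, 0.6, 0.3, 0.1, -0.02, 0.03, -0.04, -0.01], the threshold rule keeps 4 terms and the positive-pairs rule keeps 6.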
The arguments filter_beyond_lag, filter_threshold and filter_beyond_positive_pairs are filters intended to remove noisy tail terms from R_k. You can combine filter_beyond_lag with filter_threshold or filter_beyond_positive_pairs. For example, combining filter_beyond_lag and filter_beyond_positive_pairs means that a term is removed if it would be filtered under either the filter_beyond_lag OR the filter_beyond_positive_pairs criterion.
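Because each filter only removes terms, combining two of them keeps R_k only where both keep it. A small NumPy sketch of that logic for a lag cutoff plus the positive-pairs rule follows; the function name, its max_lag parameter and the R_k values are hypothetical, and an unpaired last term is dropped by the pairs rule in this sketch.

```python
import numpy as np

def combined_keep_mask(r, max_lag):
    """True where R_k survives both filters (a term is dropped if EITHER filter drops it)."""
    r = np.asarray(r, dtype=float)
    k = np.arange(len(r))
    keep_lag = k <= max_lag                      # lag filter: keep R_0, ..., R_{max_lag}
    num_pairs = len(r) // 2
    pair_sums = r[0 : 2 * num_pairs : 2] + r[1 : 2 * num_pairs : 2]
    bad = np.flatnonzero(pair_sums <= 0.0)
    cutoff = 2 * int(bad[0]) if bad.size else 2 * num_pairs
    keep_pairs = k < cutoff                      # positive-pairs filter
    return keep_lag & keep_pairs

r = [1.0, 0.6, 0.3, 0.1, -0.02, 0.03, -0.04, -0.01]
print(combined_keep_mask(r, max_lag=3).astype(int))   # [1 1 1 1 0 0 0 0]: lag filter binds
print(combined_keep_mask(r, max_lag=10).astype(int))  # [1 1 1 1 1 1 0 0]: pairs filter binds
```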
This function can also compute cross-chain ESS following [Vehtari et al. (2021)][2] by specifying the cross_chain_dims argument. Cross-chain ESS takes into account the cross-chain variance to reduce the ESS in cases where the chains are not mixing well. In general, this will be a smaller number than computing the ESS for individual chains and then summing them. In an extreme case where the chains have fallen into K non-mixing modes, this function will return ESS ~ K. Even when chains are mixing well, it is still preferable to compute cross-chain ESS via this method, because pooling the chains reduces the noise in the estimate of R_k, reducing the need for truncation.
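The cross-chain idea can be sketched in NumPy from the within-chain variance W, the between-chain variance B and the per-chain auto-correlations, using the pooled estimate R_k = 1 - (W - mean_c(s_c**2 * R_{k,c})) / var_plus from Vehtari et al. [2]. This is an illustration under stated assumptions, not this function's source: draws are laid out as [N, C] with C >= 2 chains, the helper names are hypothetical, and truncation uses the positive-pairs rule. Two chains stuck around -10 and +10 (each in its own mode) give an ESS near 2, while four well-mixing iid chains give an ESS on the order of the 4000 total draws.

```python
import numpy as np

def autocorr(x):
    """Per-chain R_k estimates for x of shape [N, C] (lags along axis 0)."""
    n = x.shape[0]
    xc = x - x.mean(axis=0)
    f = np.fft.rfft(xc, n=2 * n, axis=0)
    acov = np.fft.irfft(f * np.conj(f), n=2 * n, axis=0)[:n]
    acov = acov / (n - np.arange(n))[:, None]   # average of the N - k lagged products
    return acov / acov[0]

def cross_chain_ess(x):
    """Cross-chain ESS for x of shape [N, C], N draws from each of C >= 2 chains."""
    n, c = x.shape
    s2 = x.var(axis=0, ddof=1)                 # per-chain variance s_c**2
    w = s2.mean()                              # within-chain variance W
    b = n * x.mean(axis=0).var(ddof=1)         # between-chain variance B
    var_plus = (n - 1) / n * w + b / n
    # Pooled auto-correlation: R_k = 1 - (W - mean_c(s_c**2 * R_{k,c})) / var_plus.
    r = 1.0 - (w - (s2 * autocorr(x)).mean(axis=1)) / var_plus
    # Positive-pairs truncation, then the usual (N - k) / N weighted sum.
    num_pairs = n // 2
    pair_sums = r[0 : 2 * num_pairs : 2] + r[1 : 2 * num_pairs : 2]
    bad = np.flatnonzero(pair_sums <= 0.0)
    m = 2 * int(bad[0]) if bad.size else 2 * num_pairs
    k = np.arange(1, m)
    return n * c / (1.0 + 2.0 * np.sum((n - k) / n * r[k]))

rng = np.random.default_rng(2)
n = 1000
stuck = rng.standard_normal((n, 2)) + np.array([-10.0, 10.0])   # 2 non-mixing modes
mixing = rng.standard_normal((n, 4))                              # 4 well-mixed iid chains
ess_stuck = cross_chain_ess(stuck)
ess_mixing = cross_chain_ess(mixing)
print(f"stuck: {ess_stuck:.1f}, mixing: {ess_mixing:.0f}")
```

For the stuck chains the between-chain term dominates var_plus, so every pooled R_k stays close to 1 and the sum collapses the ESS to roughly the number of modes.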
Raises | Condition |
---|---|
ValueError | If states and filter_threshold, or states and filter_beyond_lag, are both structures but have different shapes. |
ValueError | If cross_chain_dims is not None and there are fewer than 2 chains. |
Examples
We use ESS to estimate standard error.
import jax
from tensorflow_probability.python.internal.backend import jax as tf
import tensorflow_probability as tfp; tfp = tfp.substrates.jax
tfd = tfp.distributions
target = tfd.MultivariateNormalDiag(scale_diag=[1., 2.])
# Get 1000 states from one chain.
states = tfp.mcmc.sample_chain(
    num_burnin_steps=200,
    num_results=1000,
    current_state=tf.constant([0., 0.]),
    trace_fn=None,
    kernel=tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=target.log_prob,
        step_size=0.05,
        num_leapfrog_steps=20),
    seed=jax.random.PRNGKey(0))  # TFP on JAX needs an explicit PRNG key.
print(states.shape)
# ==> (1000, 2)
ess = tfp.mcmc.effective_sample_size(states, filter_beyond_positive_pairs=True)
print(ess.shape)
# ==> (2,)
# Standard error of the mean of each coordinate.
mean, variance = tf.nn.moments(states, axes=0)
standard_error = tf.sqrt(variance / ess)
References
[1]: Charles J. Geyer, Practical Markov chain Monte Carlo (with discussion). Statistical Science, 7:473-511, 1992.
[2]: Aki Vehtari, Andrew Gelman, Daniel Simpson, Bob Carpenter, Paul-Christian Bürkner. Rank-normalization, folding, and localization: An improved R-hat for assessing convergence of MCMC. Bayesian Analysis, 16(2):667-718, 2021.