Estimate a lower bound on effective sample size for each independent chain.


Roughly speaking, "effective sample size" (ESS) is the size of an iid sample whose sample mean has the same variance as the sample mean of states.

More precisely, given a stationary sequence of possibly correlated random variables X_1, X_2, ..., X_N, identically distributed, ESS is the number such that

Variance{ N**-1 * Sum{X_i} } = ESS**-1 * Variance{ X_1 }.

If the sequence is uncorrelated, ESS = N. If the sequence is positively auto-correlated, ESS will be less than N. If there are negative correlations, then ESS can exceed N.
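
The three cases above can be checked directly from the definition. The following is a minimal sketch in plain Python, not the library's implementation: the helper ess_from_autocorrelation and the geometric autocorrelation R_k = rho**k are assumptions chosen for illustration.

```python
def ess_from_autocorrelation(n, autocorr):
    """ESS implied by Variance{mean} = Variance{X_1} / ESS.

    `autocorr(k)` returns the lag-k auto-correlation R_k, with R_0 = 1.
    """
    # Variance{ N**-1 * Sum{X_i} } = Variance{X_1} / N**2 * Sum_{i,j} R_|i-j|
    total = sum(autocorr(abs(i - j)) for i in range(n) for j in range(n))
    return n * n / total


n = 100
uncorrelated = ess_from_autocorrelation(n, lambda k: 1.0 if k == 0 else 0.0)
positive = ess_from_autocorrelation(n, lambda k: 0.5 ** k)     # R_k > 0
negative = ess_from_autocorrelation(n, lambda k: (-0.5) ** k)  # alternating sign

print(uncorrelated, positive, negative)
```

With these inputs the uncorrelated case returns N, the positively correlated case falls below N, and the alternating case exceeds N.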


Args:

  • states: Tensor or list of Tensor objects. Dimension zero should index identically distributed states.
  • filter_threshold: Tensor or list of Tensor objects. Must broadcast with states. The sequence of auto-correlations is truncated after the first appearance of a term less than filter_threshold. Setting to None means we use no threshold filter. Since |R_k| <= 1, setting to any number less than -1 has the same effect. Ignored if filter_beyond_positive_pairs is True.
  • filter_beyond_lag: Tensor or list of Tensor objects. Must be int-like and scalar valued. The sequence of auto-correlations is truncated to this length. Setting to None means we do not filter based on the size of lags.
  • filter_beyond_positive_pairs: Python boolean. If True, only consider the initial auto-correlation sequence where the pairwise sums are positive.
  • name: String name to prepend to created ops.


Returns:

  • ess: Tensor or list of Tensor objects. The effective sample size of each component of states. Shape will be states.shape[1:].


Raises:

  • ValueError: If states and filter_threshold or states and filter_beyond_lag are both lists with different lengths.


We use ESS to estimate standard error.

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

target = tfd.MultivariateNormalDiag(scale_diag=[1., 2.])

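# Get 1000 states from one chain.
# NOTE: the kernel and its settings below are illustrative assumptions; any
# tfp.mcmc transition kernel that targets `target` could be used instead.
states = tfp.mcmc.sample_chain(
    num_results=1000,
    num_burnin_steps=200,
    current_state=tf.constant([0., 0.]),
    kernel=tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=target.log_prob,
        step_size=0.05,
        num_leapfrog_steps=20),
    trace_fn=None)  # Return only the states, without kernel results.
states.shape
# ==> (1000, 2)

ess = tfp.mcmc.effective_sample_size(states, filter_beyond_positive_pairs=True)
# ==> Shape (2,) Tensor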

mean, variance = tf.nn.moments(states, axis=0)
standard_error = tf.sqrt(variance / ess)

Some math shows that, with R_k the auto-correlation sequence, R_k := Covariance{X_1, X_{1+k}} / Variance{X_1}, we have

ESS(N) = N / [ 1 + 2 * ( (N - 1) / N * R_1 + ... + 1 / N * R_{N-1} ) ]
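
The step behind this formula is an expansion of the variance of the sample mean. The sketch below assumes only stationarity, so that Covariance{X_i, X_j} depends on |i - j| alone; each lag k then appears in 2 * (N - k) of the N**2 covariance terms.

```latex
\begin{aligned}
\operatorname{Var}\Big\{\tfrac{1}{N}\sum_{i=1}^{N} X_i\Big\}
  &= \frac{1}{N^2}\sum_{i=1}^{N}\sum_{j=1}^{N}\operatorname{Cov}\{X_i, X_j\} \\
  &= \frac{\operatorname{Var}\{X_1\}}{N^2}\Big[N + 2\sum_{k=1}^{N-1}(N-k)\,R_k\Big] \\
  &= \frac{\operatorname{Var}\{X_1\}}{N}\Big[1 + 2\sum_{k=1}^{N-1}\frac{N-k}{N}\,R_k\Big].
\end{aligned}
```

Setting the last line equal to Variance{X_1} / ESS and solving for ESS gives the expression for ESS(N).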

This function estimates the above by first estimating the auto-correlation. Since R_k must be estimated using only N - k samples, it becomes progressively noisier for larger k. For this reason, the summation over R_k should be truncated at some number filter_beyond_lag < N. This function provides two methods to perform this truncation.
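
To make the estimate-then-truncate idea concrete, here is a minimal sketch in plain Python. It is not the library's implementation: the helpers sample_autocorrelation and ess_truncated, the AR(1) test chain, and the choice max_lag=20 are all assumptions made for illustration.

```python
import random


def sample_autocorrelation(x, max_lag):
    """Estimate R_0, ..., R_max_lag from a single chain `x` (list of floats)."""
    n = len(x)
    mean = sum(x) / n
    centered = [v - mean for v in x]
    variance = sum(c * c for c in centered) / n
    # R_k averages only the N - k available products, so it is noisier for large k.
    return [
        sum(centered[i] * centered[i + k] for i in range(n - k))
        / ((n - k) * variance)
        for k in range(max_lag + 1)
    ]


def ess_truncated(x, max_lag):
    """Plug estimated R_k into the ESS formula, summing lags 1..max_lag only."""
    n = len(x)
    r = sample_autocorrelation(x, max_lag)
    weighted = sum((n - k) / n * r[k] for k in range(1, max_lag + 1))
    return n / (1.0 + 2.0 * weighted)


# A positively auto-correlated AR(1) chain: x_t = 0.5 * x_{t-1} + noise.
random.seed(0)
chain = [0.0]
for _ in range(4999):
    chain.append(0.5 * chain[-1] + random.gauss(0.0, 1.0))

ess = ess_truncated(chain, max_lag=20)
print(ess)
```

Because the chain is positively auto-correlated, the estimate should come out well below the chain length of 5000.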

  • filter_threshold -- since many MCMC methods generate chains where R_k > 0, a reasonable criterion is to truncate at the first index where the estimated auto-correlation becomes negative. This method does not estimate the ESS of super-efficient chains (where ESS > N) correctly.

  • filter_beyond_positive_pairs -- reversible MCMC chains produce an auto-correlation sequence with the property that pairwise sums of the elements of that sequence are positive [1]. Deviations are only possible due to noise. This method truncates the auto-correlation sequence where the pairwise sums become non-positive.
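
The two truncation rules can behave differently on the same estimated sequence. The sketch below is a simplified reading of the descriptions above in plain Python, not the library's code; the helper names, the pairing (R_0, R_1), (R_2, R_3), ... and the sample values in r_hat are assumptions for illustration.

```python
def truncate_below_threshold(r, threshold):
    """Keep the terms that come before the first one below `threshold`."""
    for k, value in enumerate(r):
        if value < threshold:
            return r[:k]
    return list(r)


def truncate_positive_pairs(r):
    """Keep leading pairs (R_0, R_1), (R_2, R_3), ... while each pair sum is positive."""
    kept = []
    for k in range(0, len(r) - 1, 2):
        if r[k] + r[k + 1] <= 0.0:
            break
        kept.extend(r[k:k + 2])
    return kept


# Hypothetical noisy estimates of R_0, R_1, ..., R_7.
r_hat = [1.0, 0.5, -0.1, 0.3, 0.05, -0.2, 0.1, 0.02]

by_threshold = truncate_below_threshold(r_hat, 0.0)
by_pairs = truncate_positive_pairs(r_hat)

print(by_threshold)  # [1.0, 0.5]
print(by_pairs)      # [1.0, 0.5, -0.1, 0.3]
```

Here a single noisy negative term at lag 2 stops the threshold rule early, while the pairwise rule continues because the pair (R_2, R_3) still sums to a positive value.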

The arguments filter_beyond_lag, filter_threshold and filter_beyond_positive_pairs are filters intended to remove noisy tail terms from R_k. You can combine filter_beyond_lag with filter_threshold or filter_beyond_positive_pairs. E.g., combining filter_beyond_lag and filter_beyond_positive_pairs means that terms are removed if they were to be filtered under the filter_beyond_lag OR filter_beyond_positive_pairs criteria.
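
The OR combination can be pictured as two boolean "drop" masks over the lags that are merged element-wise. This is a minimal plain-Python sketch under assumptions: the helper names, the convention that lags 0..max_lag survive the lag filter, and the values in r_hat are illustrative and not taken from the library.

```python
def drop_beyond_lag(r, max_lag):
    """Mark every term whose lag exceeds `max_lag`."""
    return [k > max_lag for k in range(len(r))]


def drop_beyond_positive_pairs(r):
    """Mark every term from the first pair (R_2m, R_2m+1) with a non-positive sum."""
    cutoff = len(r)
    for k in range(0, len(r) - 1, 2):
        if r[k] + r[k + 1] <= 0.0:
            cutoff = k
            break
    return [k >= cutoff for k in range(len(r))]


# Hypothetical noisy estimates of R_0, R_1, ..., R_7.
r_hat = [1.0, 0.5, -0.1, 0.3, 0.05, -0.2, 0.1, 0.02]

lag_mask = drop_beyond_lag(r_hat, max_lag=2)
pairs_mask = drop_beyond_positive_pairs(r_hat)

# A term is removed if EITHER filter would remove it.
dropped = [a or b for a, b in zip(lag_mask, pairs_mask)]
kept = [value for value, drop in zip(r_hat, dropped) if not drop]

print(kept)  # [1.0, 0.5, -0.1]
```

In this example the lag filter is the stricter of the two, so it determines where the sequence ends.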


[1]: Geyer, C. J. Practical Markov chain Monte Carlo (with discussion). Statistical Science, 7:473-511, 1992.