tfp.substrates.jax.stats.variance

Estimate variance using samples.

Given N samples of a scalar-valued random variable X, its variance may be estimated as

Var[X] := N^{-1} sum_{n=1}^N (X_n - Xbar) Conj{(X_n - Xbar)}
Xbar := N^{-1} sum_{n=1}^N X_n
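
The estimator above can be written directly in a few lines. The following is a minimal sketch in JAX. It assumes a hypothetical helper `sample_variance` (not part of TFP) that reduces over a single integer `axis`. It mirrors the two formulas term by term, including the conjugate, so it also accepts complex input.

```python
import jax
import jax.numpy as jnp


def sample_variance(x, axis=0):
    """Divide-by-N sample variance of `x` along `axis`.

    Hypothetical helper for illustration only; not part of TFP.
    """
    x = jnp.asarray(x)
    n = x.shape[axis]
    # Xbar := N^{-1} sum_n X_n, kept broadcastable against x.
    xbar = jnp.mean(x, axis=axis, keepdims=True)
    centered = x - xbar
    # (X_n - Xbar) * Conj(X_n - Xbar) equals |X_n - Xbar|^2. For complex x the
    # result keeps a complex dtype with a zero imaginary part.
    return jnp.sum(centered * jnp.conj(centered), axis=axis) / n


r = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 2, 3))
v = sample_variance(r, axis=0)  # shape (2, 3)
```

For real input the result should agree with `jnp.var(x, axis=axis)`, which also divides by N by default.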

For example:

import jax
from tensorflow_probability.substrates import jax as tfp

x = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 2, 3))

# var[i, j] is the sample variance of the (i, j) batch member of x.
var = tfp.stats.variance(x, sample_axis=0)

Note that this estimator divides by N (the NumPy default) rather than N - 1. This avoids NaN when N = 1, but the estimate is slightly biased.
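
The bias and the N = 1 behavior can be checked numerically. This sketch does not call TFP; it uses `jnp.var` as a stand-in for the divide-by-N estimator (its default `ddof=0`), and `ddof=1` for the divide-by-(N - 1) alternative.

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0, 4.0])
n = x.shape[0]

biased = jnp.var(x)            # divides by N (ddof=0), the convention described above
unbiased = jnp.var(x, ddof=1)  # divides by N - 1 (Bessel's correction)

# The two estimates differ by exactly the factor N / (N - 1).
rescaled = biased * n / (n - 1)

# With a single sample, dividing by N = 1 gives 0, while dividing by
# N - 1 = 0 gives 0 / 0, i.e. NaN.
single = jnp.array([5.0])
single_biased = jnp.var(single)
single_unbiased = jnp.var(single, ddof=1)
```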

Args:
  x: A numeric Tensor holding samples.
  sample_axis: Scalar or vector Tensor designating the axis (or axes) holding samples, or None (meaning all axes hold samples). Default value: 0 (leftmost dimension).
  keepdims: Boolean. Whether to keep the sample axes as singletons. Default value: False.
  name: Python str name prefixed to Ops created by this function. Default value: None (i.e., 'variance').

Returns:
  var: A Tensor of the same dtype as x, with rank equal to rank(x) - len(sample_axis) (or rank(x) if keepdims is True).
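
How `sample_axis` and `keepdims` determine the output shape can be illustrated with a hypothetical `variance_sketch` function (not TFP's implementation). It assumes `sample_axis` is a Python int, a list of ints, or None, whereas TFP also accepts a Tensor. It reduces over all listed axes at once, so N is the total number of samples across them.

```python
import jax
import jax.numpy as jnp


def variance_sketch(x, sample_axis=0, keepdims=False):
    """Hypothetical re-implementation of the documented shape semantics."""
    x = jnp.asarray(x)
    if sample_axis is None:
        axis = None                # every axis holds samples
    elif isinstance(sample_axis, int):
        axis = (sample_axis,)      # scalar sample_axis
    else:
        axis = tuple(sample_axis)  # vector of sample axes
    xbar = jnp.mean(x, axis=axis, keepdims=True)
    centered = x - xbar
    # Averaging over all sample axes divides by N = number of samples.
    return jnp.mean(centered * jnp.conj(centered), axis=axis, keepdims=keepdims)


x = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 2, 3))
print(variance_sketch(x).shape)                      # (2, 3): rank 3 - 1
print(variance_sketch(x, sample_axis=[0, 1]).shape)  # (3,): rank 3 - 2
print(variance_sketch(x, sample_axis=None).shape)    # (): all axes reduced
print(variance_sketch(x, keepdims=True).shape)       # (1, 2, 3): sample axis kept as a singleton
```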