Estimate variance using samples.
```python
tfp.substrates.jax.stats.variance(
    x, sample_axis=0, keepdims=False, name=None
)
```
Given $N$ samples of a scalar-valued random variable $X$, the variance may be estimated as
$$\operatorname{Var}[X] := N^{-1} \sum_{n=1}^{N} (X_n - \bar{X}) \operatorname{Conj}\{(X_n - \bar{X})\}$$

$$\bar{X} := N^{-1} \sum_{n=1}^{N} X_n$$
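As a minimal sketch, assuming NumPy as a stand-in for the JAX backend, the two formulas above can be written out directly. The helper name `variance_reference` is hypothetical and not part of TFP; it only mirrors the `sample_axis` and `keepdims` arguments of the signature.

```python
import numpy as np


def variance_reference(x, sample_axis=0, keepdims=False):
    """Hypothetical helper: biased (divide-by-N) variance along `sample_axis`."""
    x = np.asarray(x)
    # Xbar := N^{-1} sum_n X_n, kept broadcastable against x.
    xbar = np.mean(x, axis=sample_axis, keepdims=True)
    centered = x - xbar
    # Var[X] := N^{-1} sum_n (X_n - Xbar) Conj{(X_n - Xbar)}
    return np.mean(centered * np.conj(centered), axis=sample_axis, keepdims=keepdims)


rng = np.random.default_rng(0)
x = rng.normal(size=(100, 2, 3))
var = variance_reference(x, sample_axis=0)
print(var.shape)  # (2, 3)
print(np.allclose(var, np.var(x, axis=0)))  # True

# For complex samples, multiplying by the conjugate gives |X_n - Xbar|^2,
# so the imaginary part of the result is zero.
z = np.array([1 + 1j, 1 - 1j, -1 + 1j, -1 - 1j])
print(variance_reference(z))  # (2+0j)
```

The conjugate in the formula is what keeps the estimate real and non-negative for complex inputs; for real inputs it is a no-op.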
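For example:

```python
import jax
from tensorflow_probability.substrates import jax as tfp

x = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 2, 3))
# var[i, j] is the sample variance of the (i, j) batch member of x.
var = tfp.stats.variance(x, sample_axis=0)  # shape (2, 3)
```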
Notice that we divide by $N$ (the NumPy default) rather than $N - 1$. This does not produce NaN when $N = 1$, but the estimate is biased: its expected value is $(N - 1)/N$ times the true variance.
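The divide-by-$N$ choice can be checked with NumPy, whose `np.var` also divides by $N$ by default (`ddof=0`). This is a sketch in NumPy rather than TFP: it shows that a single sample gives 0 rather than NaN, and that multiplying by $N/(N - 1)$ recovers the unbiased (divide-by-$(N - 1)$) estimate.

```python
import warnings

import numpy as np

# One sample: dividing by N gives 0.0, dividing by N - 1 gives 0 / 0 = NaN.
single = np.array([3.0])
print(np.var(single))  # 0.0
with warnings.catch_warnings():
    warnings.simplefilter("ignore", RuntimeWarning)  # NumPy warns when ddof >= N
    unbiased_single = np.var(single, ddof=1)
print(unbiased_single)  # nan

# Several samples: the biased estimate is smaller by a factor of (N - 1) / N.
samples = np.array([1.0, 2.0, 3.0, 4.0])
n = samples.size
biased = np.var(samples)            # sum of squared deviations / N = 5 / 4
unbiased = np.var(samples, ddof=1)  # sum of squared deviations / (N - 1) = 5 / 3
print(np.isclose(biased * n / (n - 1), unbiased))  # True
```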
Returns | |
---|---|
`var` | A `Tensor` of the same dtype as `x`, with rank equal to `rank(x) - len(sample_axis)` when `keepdims=False`. |
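To illustrate the rank rule in the Returns table, here is a NumPy sketch (assuming `np.var` as a stand-in, since it reduces axes and handles `keepdims` the same way): each axis listed in `sample_axis` is removed from the output shape unless `keepdims=True`, in which case it is kept with size 1, and the floating-point dtype is preserved.

```python
import numpy as np

x = np.zeros((100, 2, 3), dtype=np.float32)  # rank 3

# One sample axis: rank(x) - 1 = 2.
print(np.var(x, axis=0).shape)  # (2, 3)

# Two sample axes: rank(x) - 2 = 1.
print(np.var(x, axis=(0, 1)).shape)  # (3,)

# keepdims=True keeps each reduced axis with size 1, so the rank stays 3.
print(np.var(x, axis=0, keepdims=True).shape)  # (1, 2, 3)

# The result has the same dtype as the input.
print(np.var(x, axis=0).dtype)  # float32
```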