Estimate standard deviation using samples.
```python
tfp.substrates.jax.stats.stddev(
    x, sample_axis=0, keepdims=False, name=None
)
```
Given `N` samples of a scalar-valued random variable `X`, the standard deviation may be estimated as
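$$\operatorname{Stddev}[X] := \sqrt{\operatorname{Var}[X]},$$
$$\operatorname{Var}[X] := \frac{1}{N} \sum_{n=1}^{N} (X_n - \bar{X}) \operatorname{Conj}\left(X_n - \bar{X}\right),$$
$$\bar{X} := \frac{1}{N} \sum_{n=1}^{N} X_n.$$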
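As a minimal sketch of the estimator above (not the TFP implementation), the following pure-Python function assumes a flat list of real or complex samples and divides by `N`. The helper name `sample_stddev` is hypothetical. Because each term `(X_n - Xbar) * Conj(X_n - Xbar)` equals `|X_n - Xbar|^2`, the sketch returns a real float even for complex input.

```python
import math


def sample_stddev(samples):
    """Divide-by-N standard deviation of a flat list of real or complex samples.

    Hypothetical helper illustrating Stddev[X] = Sqrt[Var[X]]; not a TFP API.
    """
    n = len(samples)
    if n == 0:
        raise ValueError("need at least one sample")
    xbar = sum(samples) / n  # Xbar := N^{-1} sum_n X_n
    # Var[X] := N^{-1} sum_n (X_n - Xbar) * Conj(X_n - Xbar); each term is |X_n - Xbar|^2.
    var = sum(((x - xbar) * (x - xbar).conjugate()).real for x in samples) / n
    return math.sqrt(var)
```

For example, `sample_stddev([2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0])` has mean 5, squared deviations summing to 32, variance 32 / 8 = 4, and standard deviation 2.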
```python
import jax
from tensorflow_probability.substrates import jax as tfp

x = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 2, 3))
# stddev[i, j] is the sample standard deviation of the (i, j) batch member.
stddev = tfp.stats.stddev(x, sample_axis=0)
```
Scaling a unit normal by a standard deviation produces normal samples with that standard deviation.
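```python
import jax
from tensorflow_probability.substrates import jax as tfp

observed_data = read_data_samples(...)  # hypothetical loader, not part of TFP
stddev = tfp.stats.stddev(observed_data)
# Make fake_data with the same standard deviation as observed_data.
fake_data = stddev * jax.random.normal(jax.random.PRNGKey(1), shape=(100,))
```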
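The scaling claim can be checked with the standard library alone. This sketch substitutes `random.gauss` for the JAX sampler and `statistics.pstdev` (which also divides by `N`) for `tfp.stats.stddev`; the target `sigma = 2.5`, the seed, and the sample count are arbitrary choices, not values from the original example.

```python
import random
import statistics

rng = random.Random(0)  # fixed seed so the check is reproducible
sigma = 2.5  # arbitrary target standard deviation

# Unit-normal draws, then the same draws scaled by sigma.
unit = [rng.gauss(0.0, 1.0) for _ in range(10_000)]
scaled = [sigma * u for u in unit]

# Scaling by sigma multiplies the divide-by-N standard deviation by |sigma|
# (up to floating-point rounding)...
ratio = statistics.pstdev(scaled) / statistics.pstdev(unit)
# ...so the scaled samples' standard deviation lands close to sigma itself.
estimate = statistics.pstdev(scaled)
```

The ratio check is deterministic, while `estimate` only approaches `sigma` as the number of samples grows.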
Notice that we divide by `N` (the NumPy default), which does not create `NaN` when `N = 1`, but is slightly biased: on average it underestimates the true standard deviation.
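A small sketch of this trade-off using Python's `statistics` module rather than TFP: `statistics.pstdev` divides by `N`, like this function, while `statistics.stdev` divides by `N - 1` (Bessel's correction). The two estimates differ by a factor of `sqrt(N / (N - 1))`, and only the divide-by-`N` version is defined for a single sample. The data values are arbitrary.

```python
import math
import statistics

data = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]
n = len(data)

biased = statistics.pstdev(data)    # divide by N: 2.0 for this data
corrected = statistics.stdev(data)  # divide by N - 1

# The N - 1 version is larger by exactly sqrt(N / (N - 1)).
assert math.isclose(corrected, biased * math.sqrt(n / (n - 1)))

# With a single sample, divide-by-N gives 0.0 ...
single = statistics.pstdev([3.0])
# ... while the N - 1 estimate is undefined, and `statistics` raises.
try:
    statistics.stdev([3.0])
    stdev_single_ok = True
except statistics.StatisticsError:
    stdev_single_ok = False
```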
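Returns | |
---|---|
`stddev` | A `Tensor` of the same dtype as `x`, with rank equal to `rank(x) - len(sample_axis)` when `keepdims=False`; with `keepdims=True` the sample axes are kept as singleton dimensions. |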
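The rank rule in the table can be sketched as a plain shape computation. The helper `stddev_output_shape` below is hypothetical and not part of TFP. It assumes `sample_axis` is an int or a sequence of distinct ints, possibly negative, and that `keepdims=True` keeps each reduced axis as a size-1 dimension, following the NumPy convention.

```python
def stddev_output_shape(shape, sample_axis=0, keepdims=False):
    """Hypothetical helper: shape left after reducing `shape` over `sample_axis`."""
    rank = len(shape)
    axes = [sample_axis] if isinstance(sample_axis, int) else list(sample_axis)
    for a in axes:
        if not -rank <= a < rank:
            raise ValueError(f"axis {a} out of range for rank {rank}")
    # Normalize negative axes (e.g. -1 means the last axis).
    reduced = {a % rank for a in axes}
    if keepdims:
        # Reduced axes survive as singleton dimensions, so the rank is unchanged.
        return tuple(1 if i in reduced else d for i, d in enumerate(shape))
    # Reduced axes are dropped: rank(x) - len(sample_axis).
    return tuple(d for i, d in enumerate(shape) if i not in reduced)
```

For the `(100, 2, 3)` input in the first example, `stddev_output_shape((100, 2, 3))` gives `(2, 3)`, one standard deviation per batch member.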