Compute one update to the exponentially weighted moving mean and variance.
```python
tfp.substrates.jax.stats.assign_moving_mean_variance(
    value,
    moving_mean,
    moving_variance=None,
    zero_debias_count=None,
    decay=0.99,
    axis=(),
    name=None
)
```
The updated exponentially weighted moving `moving_mean` and `moving_variance` values are conceptually given by the following recurrence relations ([Welford (1962)][1]):
    new_mean = old_mean + (1 - decay) * (value - old_mean)
    new_var  = old_var  + (1 - decay) * ((value - old_mean) * (value - new_mean) - old_var)
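A minimal pure-Python sketch of one scalar update, assuming the form of Finch's Eq. 143, in which the previous variance is also decayed: `new_var = old_var + (1 - decay) * ((value - old_mean) * (value - new_mean) - old_var)`. The helper `ew_update` is hypothetical and not part of TFP. Starting from zero, the mean it produces matches the closed-form weighted sum `(1 - decay) * sum(decay ** (t - i) * x_i)`.

```python
def ew_update(old_mean, old_var, value, decay):
    """One exponentially weighted mean/variance step (hypothetical helper, not TFP API).

    new_mean = old_mean + (1 - decay) * (value - old_mean)
    new_var  = old_var + (1 - decay) * ((value - old_mean) * (value - new_mean) - old_var)
    """
    delta = value - old_mean
    new_mean = old_mean + (1.0 - decay) * delta
    new_var = old_var + (1.0 - decay) * (delta * (value - new_mean) - old_var)
    return new_mean, new_var


decay = 0.9
xs = [2.0, 4.0, 3.0, 5.0, 1.0]

mean, var = 0.0, 0.0  # zero-initialized state, as the function presumes
for x in xs:
    mean, var = ew_update(mean, var, x, decay)

# Closed form of the zero-initialized exponentially weighted mean:
# the newest observation gets weight (1 - decay), the one before (1 - decay) * decay, ...
t = len(xs)
closed_form_mean = (1.0 - decay) * sum(
    decay ** (t - 1 - i) * x for i, x in enumerate(xs))
print(mean, closed_form_mean, var)
```

Because `value - new_mean = decay * (value - old_mean)`, the variance step equals `decay * (old_var + (1 - decay) * (value - old_mean) ** 2)`, which is Finch's form with `alpha = 1 - decay`.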
This function implements the above recurrences in a numerically stable manner and also uses the `assign_add` op to allow concurrent lockless updates to the supplied variables.
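To illustrate why a deviation-based update is numerically preferable, the sketch below compares it with the textbook shortcut `E[x**2] - E[x]**2` computed from two moving averages. In exact arithmetic both give the same exponentially weighted variance, but the shortcut subtracts two nearly equal large numbers when the data carry a large offset. The helper names are hypothetical, the data are made up for illustration, and the mean is initialized to the first observation (rather than zero) so that no zero-initialization transient is involved.

```python
def ew_welford(xs, decay):
    """Deviation-based exponentially weighted variance (Welford/Finch style)."""
    mean, var = xs[0], 0.0
    for x in xs[1:]:
        delta = x - mean
        mean += (1.0 - decay) * delta
        var += (1.0 - decay) * (delta * (x - mean) - var)
    return var


def ew_naive(xs, decay):
    """Variance as (moving average of x**2) - (moving average of x)**2."""
    m1, m2 = xs[0], xs[0] ** 2
    for x in xs[1:]:
        m1 += (1.0 - decay) * (x - m1)
        m2 += (1.0 - decay) * (x * x - m2)
    return m2 - m1 * m1  # catastrophic cancellation when m2 and m1**2 are huge


decay = 0.99
base = [float(i % 7) for i in range(1000)]
shifted = [x + 1e9 for x in base]  # same spread, large constant offset

print("deviation-based:", ew_welford(base, decay), ew_welford(shifted, decay))
print("naive:          ", ew_naive(base, decay), ew_naive(shifted, decay))
```

The deviation-based result is essentially unchanged by the offset, while the naive result on the shifted data is dominated by rounding error (in float64 it is forced to a multiple of 128).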
For additional references, see John D. Cook's blog, where our `1 - decay` corresponds to `1 / k`, and [Finch (2009; Eq. 143)][2], where our `1 - decay` corresponds to `alpha`.
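As a sanity check of the `1 - decay = 1 / k` correspondence, a step-dependent weight of `1 / k` at step `k` turns the same recurrence into the ordinary running mean and the population (divide-by-n) variance. The sketch below, with the hypothetical helper `running_mean_var`, compares the result with the standard library's `statistics.fmean` and `statistics.pvariance`.

```python
import statistics


def running_mean_var(xs):
    """Recurrence with step-dependent weight 1 - decay = 1 / k, k = 1, 2, ... (hypothetical helper)."""
    mean, var = 0.0, 0.0
    for k, x in enumerate(xs, start=1):
        alpha = 1.0 / k  # plays the role of 1 - decay
        delta = x - mean
        mean += alpha * delta
        var += alpha * (delta * (x - mean) - var)
    return mean, var


xs = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]
mean, var = running_mean_var(xs)
print(mean, var)                                        # streaming estimates
print(statistics.fmean(xs), statistics.pvariance(xs))  # batch reference
```

With a fixed `decay` instead of `1 / k`, older observations keep losing weight geometrically, which is what makes the statistics "moving".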
Since variables that are initialized to `0` are biased toward zero, providing `zero_debias_count` triggers a correction of the returned `moving_mean` and `moving_variance` by the factor `1 - decay ** (zero_debias_count + 1)`. For more details, see `tfp.stats.moving_mean_variance_zero_debiased`.
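A minimal sketch of zero debiasing for the mean only, assuming the usual correction of dividing a zero-initialized exponentially weighted average by `1 - decay ** num_updates`. Here `num_updates` is assumed to play the role of `zero_debias_count + 1` in the factor quoted above; the helper `debias` is hypothetical, and the variance correction performed by TFP is not reproduced.

```python
def debias(raw, decay, num_updates):
    """Divide a zero-initialized EW average by 1 - decay ** num_updates (hypothetical helper)."""
    return raw / (1.0 - decay ** num_updates)


decay = 0.99
mean = 0.0  # zero-initialized state
for _ in range(10):  # ten updates with a constant observation of 5.0
    mean += (1.0 - decay) * (5.0 - mean)

print(mean)                     # heavily biased toward 0 (about 0.48)
print(debias(mean, decay, 10))  # approximately 5.0
```

After `t` updates with a constant input `c`, the raw average equals `c * (1 - decay ** t)`, so the division recovers `c` exactly; with a non-constant stream it rescales the weights so that they sum to one.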
| Args | |
|---|---|
| `value` | `float`-like `Tensor` representing one or more streaming observations. When `axis` is non-empty, `value` is reduced (by mean) over `axis` when computing the updated `moving_mean` and `moving_variance`. Presumed to have the same shape as `moving_mean` and `moving_variance` (after any reduction over `axis`). |
| `moving_mean` | `float`-like `tf.Variable` representing the exponentially weighted moving mean. Same shape as `moving_variance` and `value`. This function presumes the `tf.Variable` was created with all-zero initial value(s). |
| `moving_variance` | `float`-like `tf.Variable` representing the exponentially weighted moving variance. Same shape as `moving_mean` and `value`. This function presumes the `tf.Variable` was created with all-zero initial value(s). Default value: `None` (i.e., no moving variance is computed). |
| `zero_debias_count` | `int`-like `tf.Variable` representing the number of times this function has been called on streaming input (not the number of reduced values used in this function's computation). When not `None`, the returned values for `moving_mean` and `moving_variance` are "zero-debiased", i.e., corrected for their presumed all-zeros initialization. Note: the `tf.Variable`s `moving_mean` and `moving_variance` always store the raw (not zero-debiased) values, regardless of this argument. To obtain zero-debiased values from these `tf.Variable`s, see `tfp.stats.moving_mean_variance_zero_debiased`. Default value: `None` (i.e., no zero-debiasing calculation is made). |
| `decay` | A `float`-like `Tensor` representing the moving mean decay. Typically close to `1.`, e.g., `0.99`. Default value: `0.99`. |
| `axis` | The dimensions to reduce. If `()` (the default), no dimensions are reduced. If `None`, all dimensions are reduced. Must be in the range `[-rank(value), rank(value))`. Default value: `()` (i.e., no reduction is made). |
| `name` | Python `str` prepended to op names created by this function. Default value: `None` (i.e., `'assign_moving_mean_variance'`). |
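To make the `decay` argument concrete: an observation `i` steps in the past receives weight `(1 - decay) * decay ** i`, so the `n` most recent observations together carry a fraction `1 - decay ** n` of the total weight. The sketch below (hypothetical helper `recent_weight_fraction`) evaluates this at the commonly quoted "effective window" `n = 1 / (1 - decay)`.

```python
def recent_weight_fraction(decay, n):
    """Share of the total weight an EW average places on its n most recent observations.

    Summing (1 - decay) * decay ** i over i = 0 .. n-1 gives 1 - decay ** n
    (over an infinitely long history the weights sum to 1).
    """
    return 1.0 - decay ** n


for decay in (0.9, 0.99, 0.999):
    n = round(1.0 / (1.0 - decay))  # rule-of-thumb effective window
    print(decay, n, recent_weight_fraction(decay, n))
```

For `decay` close to 1, the last `1 / (1 - decay)` observations carry roughly `1 - 1/e`, about 63%, of the weight; with the default `0.99`, that is about the last 100 updates.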
| Raises | |
|---|---|
| `TypeError` | if `moving_mean` does not have a floating-point `dtype`. |
| `TypeError` | if `moving_mean`, `moving_variance`, `value`, `decay` have different `base_dtype`. |
Examples
```python
import jax
import numpy as np
from tensorflow_probability.python.internal.backend import jax as tf
import tensorflow_probability as tfp; tfp = tfp.substrates.jax
tfd = tfp.distributions

d = tfd.MultivariateNormalTriL(
    loc=[-1., 1.],
    scale_tril=tf.linalg.cholesky([[0.75, 0.05],
                                   [0.05, 0.5]]))
d.mean()
# ==> [-1., 1.]
d.variance()
# ==> [0.75, 0.5]

moving_mean = tf.Variable(tf.zeros(2))
moving_variance = tf.Variable(tf.zeros(2))
zero_debias_count = tf.Variable(0)
seed = jax.random.PRNGKey(0)  # JAX requires an explicit seed for sampling.
for _ in range(100):
  seed, sample_seed = jax.random.split(seed)
  m, v = tfp.stats.assign_moving_mean_variance(
      value=d.sample(3, seed=sample_seed),
      moving_mean=moving_mean,
      moving_variance=moving_variance,
      zero_debias_count=zero_debias_count,
      decay=0.99,
      axis=-2)
  print(m, v)
# Illustrative output (exact values depend on the random seed):
# ==> [-1.0334632  0.9545268] [0.8126194  0.5118788]
# ==> [-1.0293456  0.96070296] [0.8115873  0.50947404]
# ...
# ==> [-1.025172  0.96351 ] [0.7142789  0.48570773]

m1, v1 = tfp.stats.moving_mean_variance_zero_debiased(
    moving_mean,
    moving_variance,
    zero_debias_count,
    decay=0.99)
print(m1, v1)
# ==> [-1.025172  0.96351 ] [0.7142789  0.48570773]
assert np.allclose(m, m1)
assert np.allclose(v, v1)
```
References
[1]: B. P. Welford. Note on a Method for Calculating Corrected Sums of Squares and Products. Technometrics, Vol. 4, No. 3 (Aug. 1962), pp. 419-420. http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.302.7503&rep=rep1&type=pdf http://www.jstor.org/stable/1266577

[2]: Tony Finch. Incremental calculation of weighted mean and variance. Technical Report, 2009. http://people.ds.cam.ac.uk/fanf2/hermes/doc/antiforgery/stats.pdf