tfp.experimental.substrates.jax.stats.assign_moving_mean_variance

Compute one update to the exponentially weighted moving mean and variance.

The value-updated exponentially weighted moving mean and variance (moving_mean and moving_variance) are conceptually given by the following recurrence relations ([Welford (1962)][1]):

```python
new_mean = old_mean + (1 - decay) * (value - old_mean)
new_var  = old_var  + (1 - decay) * ((value - old_mean) * (value - new_mean) - old_var)
```
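To see the recurrences in action, here is a minimal plain-Python sketch on scalar observations. ew_update and run are hypothetical helpers, not part of TFP. The sketch writes the variance step as new_var = old_var + (1 - decay) * ((value - old_mean) * (value - new_mean) - old_var). Because of the - old_var term, the previous variance decays geometrically, which keeps the estimate bounded.

```python
def ew_update(old_mean, old_var, value, decay=0.99):
    """One step of the exponentially weighted mean/variance recurrence (scalar sketch)."""
    new_mean = old_mean + (1 - decay) * (value - old_mean)
    # The "- old_var" term decays the previous estimate toward the new product term.
    new_var = old_var + (1 - decay) * ((value - old_mean) * (value - new_mean) - old_var)
    return new_mean, new_var


def run(stream, decay=0.99):
    """Feed scalar observations through ew_update, starting from all-zero state."""
    mean, var = 0.0, 0.0
    for value in stream:
        mean, var = ew_update(mean, var, value, decay)
    return mean, var


# Alternating +1, -1 stream: true mean 0, true variance 1.
stream = [1.0 if i % 2 == 0 else -1.0 for i in range(200)]
mean, var = run(stream, decay=0.9)
print(mean, var)  # mean ~ -0.053 (tracks the latest values slightly), var ~ 0.997
```

With decay = 0.9 the estimate forgets quickly, so the mean wobbles by about (1 - decay) / (1 + decay) around 0. A decay closer to 1 averages over a longer window.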

This function implements the above recurrences in a numerically stable manner and also uses the assign_add op to allow concurrent lockless updates to the supplied variables.
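Both recurrences are pure additions to the stored state, which is what makes an assign_add-style update possible. The sketch below shows only that additive structure. It says nothing about numerical stability or locking. ScalarVar and update_by_increments are hypothetical plain-Python stand-ins, not TFP or TensorFlow APIs. The loop checks the increments against the equivalent decay-weighted-average form, new_mean = decay * old_mean + (1 - decay) * value and new_var = decay * (old_var + (1 - decay) * (value - old_mean) ** 2).

```python
class ScalarVar:
    """Hypothetical stand-in for a tf.Variable holding a single float."""

    def __init__(self, value=0.0):
        self.value = value

    def assign_add(self, delta):
        # Add in place and return the updated value, like tf.Variable.assign_add.
        self.value += delta
        return self.value


def update_by_increments(value, moving_mean, moving_variance, decay):
    """One update written purely as additive increments (no axis, no debiasing)."""
    old_mean = moving_mean.value  # single read, reused by both increments
    new_mean = moving_mean.assign_add((1 - decay) * (value - old_mean))
    new_var = moving_variance.assign_add(
        (1 - decay) * ((value - old_mean) * (value - new_mean) - moving_variance.value))
    return new_mean, new_var


moving_mean, moving_variance = ScalarVar(), ScalarVar()
mean_ref, var_ref, decay = 0.0, 0.0, 0.9
for x in [2.0, -1.0, 0.5, 3.0]:
    m, v = update_by_increments(x, moving_mean, moving_variance, decay)
    # Same recurrences written as decay-weighted averages.
    delta = x - mean_ref
    mean_ref = decay * mean_ref + (1 - decay) * x
    var_ref = decay * (var_ref + (1 - decay) * delta ** 2)
    assert abs(m - mean_ref) < 1e-12 and abs(v - var_ref) < 1e-12
```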

For additional references, see this John D. Cook blog post (https://www.johndcook.com/blog/standard_deviation/) and [Finch (2009; Eq. 143)][2].

Because variables initialized to 0 are biased toward zero, providing zero_debias_count makes the function rescale the returned moving_mean and moving_variance by dividing them by the factor 1 - decay ** (zero_debias_count + 1). For more details, see tfp.stats.moving_mean_variance_zero_debiased.
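The effect of zero debiasing is easiest to see on a constant stream. In this plain-Python sketch, biased_mean_after is a hypothetical helper and t counts the updates applied so far. If each call increments zero_debias_count, then t = zero_debias_count + 1 for the call that produces the t-th update. A zero-initialized mean after t updates with constant value x equals x * (1 - decay ** t), so dividing by 1 - decay ** t recovers x up to rounding.

```python
def biased_mean_after(t, x, decay):
    """Zero-initialized exponentially weighted mean after t updates with constant x."""
    mean = 0.0
    for _ in range(t):
        mean += (1 - decay) * (x - mean)
    return mean


decay, x = 0.99, 5.0
for t in (1, 10, 100):
    biased = biased_mean_after(t, x, decay)
    debiased = biased / (1 - decay ** t)  # divide by the bias-correction factor
    # biased grows like x * (1 - decay ** t); debiased stays at x up to rounding
    print(t, biased, debiased)
```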

Args

value: float-like Tensor representing one or more streaming observations. When axis is non-empty, value is reduced (by mean) along axis before updating moving_mean and moving_variance. After any reduction, value is presumed to have the same shape as moving_mean and moving_variance.
moving_mean: float-like tf.Variable representing the exponentially weighted moving mean. Same shape as moving_variance and value. This function presumes the tf.Variable was created with all-zero initial value(s).
moving_variance: float-like tf.Variable representing the exponentially weighted moving variance. Same shape as moving_mean and value. This function presumes the tf.Variable was created with all-zero initial value(s). Default value: None (i.e., no moving variance is computed).
zero_debias_count: int-like tf.Variable representing the number of times this function has been called on streaming input (not the number of reduced values used in this function's computation). When not None, the returned values for moving_mean and moving_variance are "zero debiased", i.e., corrected for their presumed all-zeros initialization. Note: the tf.Variables moving_mean and moving_variance always store the biased (not zero-debiased) values, regardless of this argument. To obtain zero-debiased values from these tf.Variables, see tfp.stats.moving_mean_variance_zero_debiased. Default value: None (i.e., no zero debiasing calculation is made).
decay: float-like Tensor representing the moving mean decay. Typically close to 1., e.g., 0.99. Default value: 0.99.
axis: The dimensions to reduce. If () (the default), no dimensions are reduced. If None, all dimensions are reduced. Must be in the range [-rank(value), rank(value)). Default value: () (i.e., no reduction is made).
name: Python str prepended to op names created by this function. Default value: None (i.e., 'assign_moving_mean_variance').
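The axis reduction can be sketched the same way. This plain-Python version, with the hypothetical helper batch_update, treats value as a list of rows and reduces over the row axis, like axis=-2 for a value of shape [batch, 2]. It assumes, beyond what the argument description states, that the variance product term (value - old_mean) * (value - new_mean) is averaged over the batch just as value is, with new_mean computed from the batch mean.

```python
def batch_update(batch, old_mean, old_var, decay=0.99):
    """One update from a [batch, dim] list of rows, reducing over the batch axis."""
    n, dim = len(batch), len(old_mean)
    new_mean, new_var = [], []
    for j in range(dim):
        column = [row[j] for row in batch]
        batch_mean = sum(column) / n  # reduce value by mean along the batch axis
        m = old_mean[j] + (1 - decay) * (batch_mean - old_mean[j])
        # Assumption: the product term is averaged over the same axis.
        prod = sum((x - old_mean[j]) * (x - m) for x in column) / n
        new_mean.append(m)
        new_var.append(old_var[j] + (1 - decay) * (prod - old_var[j]))
    return new_mean, new_var


batch = [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]  # shape [3, 2]; reduce over the 3 rows
mean, var = [0.0, 0.0], [0.0, 0.0]
for _ in range(300):
    mean, var = batch_update(batch, mean, var, decay=0.9)
print(mean, var)  # approaches column means [2.0, 3.0] and population variances 8/3
```

Feeding the same batch repeatedly makes the fixed point easy to check. At convergence the mean equals the batch mean, so the averaged product term becomes the population (divide-by-n) variance of each column.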

Returns

moving_mean: The value-updated exponentially weighted moving mean. Debiased if zero_debias_count is not None.
moving_variance: The value-updated exponentially weighted moving variance. Debiased if zero_debias_count is not None.
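The split between stored and returned values can be sketched with plain-Python state. State, assign_sketch, and debias are hypothetical stand-ins for moving_mean, moving_variance, zero_debias_count, this function, and tfp.stats.moving_mean_variance_zero_debiased. The sketch assumes, as a simplification, that both mean and variance are debiased by dividing by 1 - decay ** count, where count has already been incremented for the current call. The stored values stay biased, and debiasing them afterwards reproduces the returned values.

```python
class State:
    """Hypothetical container mirroring moving_mean, moving_variance, zero_debias_count."""

    def __init__(self):
        self.mean = 0.0
        self.var = 0.0
        self.count = 0


def assign_sketch(state, value, decay=0.99):
    """Update the stored (biased) state and return zero-debiased estimates."""
    old_mean = state.mean
    state.mean += (1 - decay) * (value - old_mean)
    state.var += (1 - decay) * ((value - old_mean) * (value - state.mean) - state.var)
    state.count += 1
    correction = 1 - decay ** state.count  # equals 1 - decay ** (old count + 1)
    return state.mean / correction, state.var / correction


def debias(state, decay=0.99):
    """Post-hoc debiasing of the stored values."""
    correction = 1 - decay ** state.count
    return state.mean / correction, state.var / correction


state = State()
for x in [1.0, 3.0, 2.0, 4.0]:
    m, v = assign_sketch(state, x)
m1, v1 = debias(state)
assert (m, v) == (m1, v1)  # same arithmetic on the same inputs, so an exact match
assert abs(state.mean) < abs(m)  # the stored mean is still biased toward zero
```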

Raises

TypeError: if moving_mean does not have a floating-point dtype.
TypeError: if moving_mean, moving_variance, value, and decay have different base_dtype.

Examples

```python
from tensorflow_probability.python.internal.backend import jax as tf
import tensorflow_probability as tfp; tfp = tfp.experimental.substrates.jax
tfd = tfp.distributions

# A 2-D Gaussian whose mean and variance are estimated from streaming samples.
d = tfd.MultivariateNormalTriL(
    loc=[-1., 1.],
    scale_tril=tf.linalg.cholesky([[0.75, 0.05],
                                   [0.05, 0.5]]))
d.mean()
# ==> [-1.,  1.]
d.variance()
# ==> [0.75, 0.5]

# All-zero initial state, as this function presumes.
moving_mean = tf.Variable(tf.zeros(2))
moving_variance = tf.Variable(tf.zeros(2))
zero_debias_count = tf.Variable(0)
for _ in range(100):
  # Each call consumes a batch of 3 samples (shape [3, 2]) and averages over axis=-2.
  m, v = tfp.stats.assign_moving_mean_variance(
    value=d.sample(3),
    moving_mean=moving_mean,
    moving_variance=moving_variance,
    zero_debias_count=zero_debias_count,
    decay=0.99,
    axis=-2)
  print(m.numpy(), v.numpy())
# ==> [-1.0334632  0.9545268] [0.8126194 0.5118788]
# ==> [-1.0293456   0.96070296] [0.8115873  0.50947404]
# ...
# ==> [-1.025172  0.96351 ] [0.7142789  0.48570773]

# Debiasing the stored (biased) variables reproduces the last returned values.
m1, v1 = tfp.stats.moving_mean_variance_zero_debiased(
  moving_mean,
  moving_variance,
  zero_debias_count,
  decay=0.99)
print(m1.numpy(), v1.numpy())
# ==> [-1.025172  0.96351 ] [0.7142789  0.48570773]
assert all(m == m1)
assert all(v == v1)
```

References

[1] B. P. Welford. Note on a Method for Calculating Corrected Sums of Squares and Products. Technometrics, Vol. 4, No. 3 (Aug., 1962), pp. 419-420. http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.302.7503&rep=rep1&type=pdf http://www.jstor.org/stable/1266577

[2] Tony Finch. Incremental calculation of weighted mean and variance. Technical Report, 2009. http://people.ds.cam.ac.uk/fanf2/hermes/doc/antiforgery/stats.pdf