tfp.substrates.jax.stats.assign_log_moving_mean_exp

Compute the log of the exponentially weighted moving mean of the exp.

If log_value is a draw from a stationary random variable, this function approximates log(E[exp(log_value)]) by a weighted log-sum-exp over past observations, with weights that decay geometrically with age. More precisely, the tf.Variable moving_log_mean_exp is updated with each new log_value using the following identity:

$$
\begin{aligned}
m' &= \log\big(d\,e^{m} + (1 - d)\,e^{x}\big) \\
   &= \log\big(e^{m + \log d} + e^{x + \operatorname{log1p}(-d)}\big) \\
   &= m + \log\big(e^{m - m + \log d} + e^{x - m + \operatorname{log1p}(-d)}\big) \\
   &= m + \operatorname{logsumexp}\big([\log d,\; x - m + \operatorname{log1p}(-d)]\big)
\end{aligned}
$$

Here m and m' denote moving_log_mean_exp before and after the update, x = log_value, and d = decay.
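The identity can be checked numerically. The following is a minimal pure-Python sketch for scalar floats only; the helpers log_add_exp and log_mean_exp_increment are hypothetical names chosen for illustration and are not part of TFP. It computes the increment from the last line of the identity and compares the result with the direct formula on the first line.

```python
import math


def log_add_exp(a, b):
    """Numerically stable log(exp(a) + exp(b)) for Python floats (hypothetical helper)."""
    hi = max(a, b)
    return hi + math.log(math.exp(a - hi) + math.exp(b - hi))


def log_mean_exp_increment(moving_log_mean_exp, log_value, decay=0.99):
    """Hypothetical helper: the amount added to moving_log_mean_exp.

    Implements logsumexp([log(decay), log_value - m + log1p(-decay)]).
    """
    delta = log_value - moving_log_mean_exp
    return log_add_exp(math.log(decay), delta + math.log1p(-decay))


m, x, decay = 0.3, 2.0, 0.9
updated = m + log_mean_exp_increment(m, x, decay)
# First line of the identity, evaluated directly in linear space.
direct = math.log(decay * math.exp(m) + (1 - decay) * math.exp(x))
print(updated, direct)  # agree up to floating-point rounding
```

Because the increment only exponentiates differences, a log_value of 1000.0 still yields a finite increment (1000 + log1p(-0.99)), whereas evaluating the direct formula would overflow math.exp.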

Besides being numerically stable, this formulation lets moving_log_mean_exp be updated in a lock-free manner, i.e., with a single assign_add of the increment on the right-hand side of the last line. (Note: the updates are not thread-safe; the point is only that an assign_add to the tf.Variable is presumed efficient because it is lock-free.)
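Since each step only adds an increment to the stored value, a streaming estimator needs nothing more than one in-place addition per observation. The sketch below is pure Python rather than the TFP API; the helper names and the starting value of 0.0 are assumptions for illustration. It feeds a simple periodic stream, standing in for stationary input, and compares the final estimate with log E[exp(X)] for that stream.

```python
import math


def log_add_exp(a, b):
    """Numerically stable log(exp(a) + exp(b)) for Python floats (hypothetical helper)."""
    hi = max(a, b)
    return hi + math.log(math.exp(a - hi) + math.exp(b - hi))


def stream_log_mean_exp(log_values, decay=0.99, init=0.0):
    """Hypothetical helper: run the additive update over a stream of log-values."""
    m = init  # assumed starting value for moving_log_mean_exp
    for x in log_values:
        # The whole update is a single addition, analogous to assign_add.
        m += log_add_exp(math.log(decay), x - m + math.log1p(-decay))
    return m


# Periodic stream alternating between log-values 0.0 and 1.0.
stream = [0.0, 1.0] * 1000
estimate = stream_log_mean_exp(stream)
target = math.log((math.exp(0.0) + math.exp(1.0)) / 2)  # log E[exp(X)]
print(estimate, target)
```

With decay = 0.99 the estimate lands within a few thousandths of the target; the residual gap comes from the exponential weighting, which favors the most recent observations.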

Args:
  log_value: float-like Tensor representing a new (streaming) observation. Same shape as moving_log_mean_exp.
  moving_log_mean_exp: float-like Variable representing the log of the exponentially weighted moving mean of the exp. Same shape as log_value.
  zero_debias_count: int-like tf.Variable representing the number of times this function has been called on streaming input (not the number of reduced values used in this function's computation). When not None, the returned value is "zero debiased", i.e., corrected for its presumed all-zeros initialization. Note: the tf.Variable moving_log_mean_exp always stores the biased (not zero-debiased) calculation, regardless of this argument. For the analogous correction of moving mean and variance variables, see tfp.stats.moving_mean_variance_zero_debiased. Default value: None (i.e., no zero debiasing calculation is made).
  decay: A float-like Tensor representing the moving mean decay. Typically close to 1, e.g., 0.99. Default value: 0.99.
  name: Python str prepended to op names created by this function. Default value: None (i.e., 'assign_log_moving_mean_exp').
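The exact zero-debias formula is not spelled out here. Assuming it follows the usual exponential-moving-average convention (the accumulator of exp(log_value) starts at 0, and the estimate is divided by 1 - decay**t after t updates), the correction amounts to subtracting log1p(-decay**t) in log space. The sketch below illustrates only that assumption and does not claim to mirror the library's internals. It uses the second line of the identity because the additive form is undefined when the accumulator starts at -inf, and its counter count is a hypothetical stand-in for zero_debias_count.

```python
import math


def log_add_exp(a, b):
    """Stable log(exp(a) + exp(b)); tolerates a = -inf when b is finite."""
    hi = max(a, b)
    return hi + math.log(math.exp(a - hi) + math.exp(b - hi))


def debiased_log_mean_exp(log_values, decay=0.99):
    """Hypothetical helper returning (biased, debiased) log-mean-exp estimates."""
    # Assumption: "all-zeros initialization" means the mean of exp starts at 0,
    # so the log-space accumulator starts at -inf.
    m = -math.inf
    count = 0  # hypothetical stand-in for zero_debias_count
    for x in log_values:
        # Second line of the identity; well defined even when m is -inf.
        m = log_add_exp(m + math.log(decay), x + math.log1p(-decay))
        count += 1
    # Assumed correction: divide by (1 - decay**count), i.e. subtract its log.
    return m, m - math.log1p(-(decay ** count))


biased, debiased = debiased_log_mean_exp([2.0] * 5, decay=0.9)
print(biased, debiased)
```

For a constant stream the debiased value recovers the constant up to rounding, while the biased value is pulled down toward the all-zeros starting point.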

Returns:
  moving_log_mean_exp: A reference to the input Variable tensor holding the log_value-updated log of the exponentially weighted moving mean of the exp.

Raises:
  TypeError: if moving_log_mean_exp does not have a floating-point dtype.
  TypeError: if moving_log_mean_exp, log_value, and decay have different base_dtype.