Compute the log of the exponentially weighted moving mean of the exp.
tfp.substrates.jax.stats.assign_log_moving_mean_exp(
    log_value,
    moving_log_mean_exp,
    zero_debias_count=None,
    decay=0.99,
    name=None
)
If log_value is a draw from a stationary random variable, this function approximates log(E[exp(log_value)]), i.e., a weighted log-sum-exp. More precisely, the tf.Variable moving_log_mean_exp is updated with log_value using the following identity:
$$
\begin{aligned}
m' &= \log\left(d\,e^{m} + (1 - d)\,e^{v}\right) \\
   &= \log\left(e^{m + \log d} + e^{v + \operatorname{log1p}(-d)}\right) \\
   &= m + \log\left(e^{m - m + \log d} + e^{v - m + \operatorname{log1p}(-d)}\right) \\
   &= m + \operatorname{logsumexp}\left(\left[\log d,\; v - m + \operatorname{log1p}(-d)\right]\right),
\end{aligned}
$$

where m is moving_log_mean_exp, v is log_value, d is decay, and m' is the updated moving_log_mean_exp.
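To see why the last line of the identity matters numerically, here is a minimal pure-Python sketch (stdlib only, not the library implementation; naive_update and stable_update are illustrative names). It evaluates the first line of the identity literally and the last line with a stable two-term log-sum-exp, then compares them on moderate inputs and on inputs large enough that exp() overflows.

```python
import math

def naive_update(m, v, decay=0.99):
    # First line of the identity, evaluated literally; math.exp can overflow.
    return math.log(decay * math.exp(m) + (1.0 - decay) * math.exp(v))

def stable_update(m, v, decay=0.99):
    # Last line of the identity: m + log_sum_exp([log(decay), v - m + log1p(-decay)]).
    a = math.log(decay)
    b = v - m + math.log1p(-decay)
    hi = max(a, b)  # factor out the larger term so exp() only sees values <= 0
    return m + hi + math.log(math.exp(a - hi) + math.exp(b - hi))

# Moderate inputs: both forms give the same value up to rounding.
print(naive_update(1.0, 2.0), stable_update(1.0, 2.0))

# Large inputs: the literal form overflows, the rearranged form does not.
try:
    naive_update(1000.0, 1001.0)
except OverflowError:
    print("naive_update overflowed")
print(stable_update(1000.0, 1001.0))
```

Only differences such as v - m ever reach exp(), which is why the rearranged form stays finite for large log values.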
In addition to numerical stability, this formulation is advantageous because moving_log_mean_exp can be updated in a lock-free manner, i.e., using assign_add, since the new value is the old value plus an increment. (Note: the updates are not thread-safe; the point is only that an assign_add update to the tf.Variable is presumed to be efficient because it is lock-free.)
Args:
  log_value: float-like Tensor representing a new (streaming) observation. Same shape as moving_log_mean_exp.
  moving_log_mean_exp: float-like Variable representing the log of the exponentially weighted moving mean of the exp. Same shape as log_value.
  zero_debias_count: int-like tf.Variable representing the number of times this function has been called on streaming input (not the number of reduced values used in this function's computation). When not None, the returned value of moving_log_mean_exp is "zero debiased", i.e., corrected for its presumed all-zeros initialization. Note: the tf.Variable moving_log_mean_exp always stores the non-debiased calculation, regardless of this argument; debiasing only affects the returned value. (For the analogous mean and variance statistics, see tfp.stats.moving_mean_variance_zero_debiased.)
    Default value: None (i.e., no zero debiasing calculation is made).
  decay: A float-like Tensor representing the moving mean decay. Typically close to 1, e.g., 0.99.
    Default value: 0.99.
  name: Python str prepended to the names of ops created by this function.
    Default value: None (i.e., 'assign_log_moving_mean_exp').
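The zero_debias_count description does not spell out the correction itself. The pure-Python sketch below assumes the standard bias correction for exponential moving averages started at zero; this is an assumption, not a statement of what the library computes. With decay d, after t updates from an exp-scale mean of 0 (log value -inf), the observations' weights (1 - d) * d**k sum to 1 - d**t, so dividing by that total, i.e. subtracting log(1 - d**t) in log space, removes the pull toward the initial zero. logaddexp, update, and weight are illustrative helpers, not library functions.

```python
import math

def logaddexp(a, b):
    # log(exp(a) + exp(b)); also handles a == -inf, i.e. exp(a) == 0.
    hi = max(a, b)
    if hi == -math.inf:
        return -math.inf
    return hi + math.log1p(math.exp(-abs(a - b)))

def update(m, v, decay):
    # Second line of the identity: log(exp(m + log(decay)) + exp(v + log1p(-decay))).
    return logaddexp(m + math.log(decay), v + math.log1p(-decay))

def weight(k, decay):
    # Exp-scale weight of the observation that arrived k updates ago.
    return (1.0 - decay) * decay ** k

decay, c, t = 0.99, 2.0, 10

# After t updates, the observations' weights sum to 1 - decay**t.
total = sum(weight(k, decay) for k in range(t))

m = -math.inf  # ASSUMPTION: "all zeros" means exp(m) == 0 before any update
for _ in range(t):
    m = update(m, c, decay)  # constant stream: every log_value equals c

# ASSUMED correction: divide the exp-scale mean by (1 - decay**t).
debiased = m - math.log1p(-decay ** t)
print(total, 1 - decay ** t)
print(m, debiased)  # m sits well below c; debiased recovers c
```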
Returns:
  moving_log_mean_exp: A reference to the input Variable tensor holding the log_value-updated log of the exponentially weighted moving mean of the exp.
Raises:
  TypeError: If moving_log_mean_exp does not have a floating-point dtype.
  TypeError: If moving_log_mean_exp, log_value, and decay have different base_dtypes.