tfp.substrates.jax.stats.log_soosum_exp

Computes the log-swap-one-out-sum of exp(logx).

The swapped-out element logx[i] is replaced with the log of the leave-i-out geometric mean of x, that is, the mean of logx[k] over k != i.
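As a minimal sketch of the swap itself, the snippet below works in plain Python and in linear space for readability. The helper name `soosum_naive` is hypothetical and this is not the TFP implementation. For each i, the code drops x[i], replaces it with the geometric mean of the remaining values, and sums the result. Taking the log of each entry gives the swap-one-out log-sum.

```python
import math


def soosum_naive(x):
    """Swap-one-out sums of a 1-D list of positive numbers (linear space).

    Entry i is sum(x) with x[i] replaced by the geometric mean of the other
    n - 1 elements. Requires len(x) >= 2.
    """
    n = len(x)
    out = []
    for i in range(n):
        others = [x[k] for k in range(n) if k != i]
        # Geometric mean of the remaining elements, computed via logs.
        geo = math.exp(sum(math.log(v) for v in others) / (n - 1))
        out.append(sum(others) + geo)
    return out


print(soosum_naive([1.0, 2.0, 4.0]))  # ≈ [8.828, 7.0, 4.414]
```

For example, the first entry is 2 + 4 + sqrt(2 * 4) ≈ 8.828.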

Args:
logx: Floating-type Tensor representing log(x), where x is some positive value.
axis: The dimensions to sum across. Must be in the range [-rank(logx), rank(logx)). Default value: None (i.e., reduce over all dims).
keepdims: If True, retains reduced dimensions with length 1 in the reduced output log_sum_x. Default value: False (i.e., reduced dimensions are dropped).
name: Python str name prefixed to Ops created by this function. Default value: None (i.e., "log_soosum_exp").
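The axis and keepdims arguments follow the usual NumPy-style reduction rules for the reduced output. The sketch below uses a NumPy stand-in for that output. The name `log_sum_exp` is hypothetical and this is not the TFP function. It shows which shapes result from different argument values.

```python
import numpy as np


def log_sum_exp(logx, axis=None, keepdims=False):
    """Stable log(sum(exp(logx))) with NumPy-style axis/keepdims rules."""
    logx = np.asarray(logx, dtype=np.float64)
    m = np.max(logx, axis=axis, keepdims=True)
    out = m + np.log(np.sum(np.exp(logx - m), axis=axis, keepdims=True))
    # Drop the reduced (length-1) dims unless keepdims=True.
    return out if keepdims else np.squeeze(out, axis=axis)


logx = np.zeros((2, 3, 4))  # x = 1 everywhere
print(log_sum_exp(logx).shape)                              # ()
print(log_sum_exp(logx, axis=1).shape)                      # (2, 4)
print(log_sum_exp(logx, axis=-1).shape)                     # (2, 3)
print(log_sum_exp(logx, axis=(0, 2), keepdims=True).shape)  # (1, 3, 1)
```

With axis=None, every entry is reduced. Here the result is log(24) because all 24 entries have x = 1.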

Returns:
log_soosum_x: logx.dtype Tensor with the same shape as logx, holding the natural log of the sum of x, except that the element logx[i] is replaced with the log of the leave-i-out geometric average. The sum of the gradient of log_soosum_x is n, i.e., the number of reduced elements. Mathematically, log_soosum_x is

    log_soosum_x[i] = log(Sum{h[j ; i] : j=0, ..., n-1})
    h[j ; i] = { u[j]                              j!=i
               { GeometricAverage{u[k] : k != i}   j==i

where u[j] = exp(logx[j]) = x[j] and n is the number of reduced elements.
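The swap-one-out sum described above can be evaluated directly in log space. The NumPy sketch below does this under a few assumptions. The names `log_soosum_exp_ref` and `_logsumexp` are hypothetical, this is not TFP's implementation, and only a single integer axis is supported. For each i, the code builds the row h[· ; i] in log space and applies a max-shifted log-sum-exp. This costs O(n^2) memory along the axis. In exchange, it avoids the cancellation that can occur when x[i] is subtracted out of a precomputed total.

```python
import numpy as np


def _logsumexp(a, axis, keepdims=False):
    """Max-shifted, numerically stable log(sum(exp(a))) along `axis`."""
    m = np.max(a, axis=axis, keepdims=True)
    out = m + np.log(np.sum(np.exp(a - m), axis=axis, keepdims=True))
    return out if keepdims else np.squeeze(out, axis=axis)


def log_soosum_exp_ref(logx, axis=-1, keepdims=False):
    """Sketch of the swap-one-out log-sum along one integer axis.

    Returns (log_soosum_x, log_sum_x): log_soosum_x has the shape of logx,
    log_sum_x has `axis` reduced (kept with length 1 if keepdims=True).
    Needs at least 2 elements along `axis`; uses O(n^2) memory there.
    """
    logx = np.asarray(logx, dtype=np.float64)
    lx = np.moveaxis(logx, axis, -1)  # move the reduced axis last
    n = lx.shape[-1]
    # log GeometricAverage{x[k] : k != i} is the mean of logx[k] over k != i.
    log_loo_geo = (lx.sum(axis=-1, keepdims=True) - lx) / (n - 1)
    # rows[..., i, j] = logx[j] if j != i, else the leave-i-out log mean.
    rows = np.where(np.eye(n, dtype=bool),
                    log_loo_geo[..., :, None], lx[..., None, :])
    log_soosum_x = np.moveaxis(_logsumexp(rows, axis=-1), -1, axis)
    log_sum_x = _logsumexp(logx, axis=axis, keepdims=keepdims)
    return log_soosum_x, log_sum_x


logx = np.log([[1.0, 2.0, 4.0],
               [3.0, 3.0, 3.0]])
soo, s = log_soosum_exp_ref(logx, axis=-1)
print(np.exp(soo))  # ≈ [[8.828, 7.0, 4.414], [9.0, 9.0, 9.0]]
print(np.exp(s))    # ≈ [7.0, 9.0]
```

When all elements along the axis are equal, as in the second row, every swap is a no-op. In that case, each entry of log_soosum_x equals log_sum_x.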

log_sum_x: logx.dtype Tensor corresponding to the natural log of the sum of x. The sum of the gradient of log_sum_x is 1. Has the reduced shape of logx (per axis and keepdims).
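The gradient-sum statements follow from shift invariance. Adding a constant c to every logx multiplies every h[j ; i] by exp(c), so each output shifts by exactly c. As a result, each row of the Jacobian sums to 1, and the swap-one-out output (log_soosum_x) has one such row per reduced element. The sketch below checks this numerically with central finite differences in NumPy rather than autodiff. With JAX, one could instead differentiate the real function, for example with jax.jacfwd. The helpers `_lse`, `log_soosum_1d`, and `jacobian_fd` are hypothetical.

```python
import numpy as np


def _lse(a, axis=-1):
    """Max-shifted log(sum(exp(a))) along `axis` (axis is removed)."""
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))


def log_soosum_1d(logx):
    """Swap-one-out log-sums of a 1-D array (hypothetical helper)."""
    n = logx.size
    log_loo_geo = (logx.sum() - logx) / (n - 1)
    rows = np.where(np.eye(n, dtype=bool), log_loo_geo[:, None], logx[None, :])
    return _lse(rows, axis=-1)


def jacobian_fd(f, v, eps=1e-6):
    """Central-difference Jacobian J[i, j] ~= d f(v)[i] / d v[j]."""
    cols = []
    for j in range(v.size):
        e = np.zeros_like(v)
        e[j] = eps
        diff = np.atleast_1d(f(v + e)) - np.atleast_1d(f(v - e))
        cols.append(diff / (2 * eps))
    return np.stack(cols, axis=-1)


logx = np.log(np.array([0.5, 1.0, 3.0, 7.0]))
J_soo = jacobian_fd(log_soosum_1d, logx)
J_sum = jacobian_fd(_lse, logx)
print(J_soo.sum())     # ≈ 4.0, i.e. n
print(J_sum.sum())     # ≈ 1.0
print(np.diag(J_soo))  # ≈ [0, 0, 0, 0]
```

The diagonal of the first Jacobian vanishes because logx[i] is swapped out of entry i, so entry i does not depend on it.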