tfp.substrates.jax.stats.log_loomean_exp

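Computes the log of the leave-one-out mean of exp(logx). Each output element is the natural log of the mean of exp(logx) over the reduced dimensions, with that element itself excluded. The log of the full mean is also returned.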
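For a 1-D logx of length n (n >= 2), output element i is log((1 / (n - 1)) * sum over j != i of exp(logx[j])). The snippet below is a minimal pure-JAX sketch written directly from that definition. The helper name loo_log_mean_exp_1d is hypothetical and is not the TFP implementation. It builds an n x n mask, so it is only suitable for small inputs.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def loo_log_mean_exp_1d(logx):
    """Hypothetical helper: leave-one-out log-mean-exp of a 1-D array.

    Written from the definition for illustration; not the TFP implementation.
    """
    n = logx.shape[0]
    # Row i of `keep` is 1 everywhere except column i, which drops logx[i].
    keep = 1.0 - jnp.eye(n, dtype=logx.dtype)
    # Give every row its own copy of logx: shape (n, n).
    tiled = jnp.broadcast_to(logx, (n, n))
    # With weights b, logsumexp returns log(sum_j b_j * exp(a_j)) stably.
    log_loo_sum = logsumexp(tiled, axis=1, b=keep)
    # Divide the sum of the remaining n - 1 terms by n - 1, in log space.
    return log_loo_sum - jnp.log(n - 1.0)


logx = jnp.log(jnp.array([1.0, 2.0, 3.0, 6.0]))
# Expected means: (2+3+6)/3, (1+3+6)/3, (1+2+6)/3, (1+2+3)/3.
print(jnp.exp(loo_log_mean_exp_1d(logx)))
```

Working in log space through logsumexp avoids overflow when logx contains large values, which is the usual reason to take logx rather than x as input.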

Args:
logx: Floating-type Tensor representing log(x), where x is some positive value.
axis: The dimensions to reduce (average) over. If None (the default), reduces over all dimensions. Must be in the range [-rank(logx), rank(logx)). Default value: None (i.e., reduce over all dims).
keepdims: If true, retains reduced dimensions with length 1 in log_mean_x. Default value: False (i.e., reduced dims are removed from log_mean_x). The leave-one-out output always has the shape of logx, regardless of keepdims.
name: Python str name prefixed to Ops created by this function. Default value: None (i.e., "log_loomean_exp").
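To make the axis and keepdims semantics concrete, here is a hypothetical pure-JAX re-implementation, log_loomean_exp_sketch. It is written from the documented behaviour rather than taken from TFP, assumes axis is None or a single int, and may differ from the library in numerical details.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def log_loomean_exp_sketch(logx, axis=None, keepdims=False):
    """Hypothetical sketch returning (log_loomean_exp, log_mean_x).

    Assumes `axis` is None or a single int; not the TFP implementation.
    """
    logx = jnp.asarray(logx)
    if axis is None:
        # Reduce over all dims: treat the input as one flat vector.
        loo, log_mean = log_loomean_exp_sketch(logx.reshape(-1), axis=0)
        if keepdims:
            log_mean = log_mean.reshape((1,) * logx.ndim)
        return loo.reshape(logx.shape), log_mean
    n = logx.shape[axis]
    # Move the reduced axis last: shape (..., n).
    moved = jnp.moveaxis(logx, axis, -1)
    # Give each position i its own copy of the row: shape (..., n, n).
    tiled = jnp.broadcast_to(moved[..., None, :], moved.shape[:-1] + (n, n))
    # Zero weight on the diagonal drops element i from row i.
    keep = jnp.broadcast_to(1.0 - jnp.eye(n, dtype=moved.dtype), tiled.shape)
    loo = logsumexp(tiled, axis=-1, b=keep) - jnp.log(n - 1.0)
    # Move the axis back so the leave-one-out output matches logx's shape.
    loo = jnp.moveaxis(loo, -1, axis)
    # keepdims only affects the full-mean output.
    log_mean = logsumexp(logx, axis=axis, keepdims=keepdims) - jnp.log(float(n))
    return loo, log_mean


logx = jnp.log(jnp.arange(1.0, 7.0).reshape(2, 3))
for kwargs in ({"axis": -1}, {"axis": -1, "keepdims": True}, {}):
    loo, log_mean = log_loomean_exp_sketch(logx, **kwargs)
    print(kwargs, loo.shape, log_mean.shape)
```

In every case the first output keeps the (2, 3) shape of logx, while the second shrinks to (2,), (2, 1), or () depending on axis and keepdims.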

Returns:
A tuple (log_loomean_exp, log_mean_x):
log_loomean_exp: Tensor with the same shape and dtype as logx. Element i is the natural log of the mean of exp(logx) taken with logx[i] removed.
log_mean_x: logx.dtype Tensor holding the natural log of the arithmetic mean of x. Its shape is that of logx reduced over axis, with the reduced dimensions kept as length 1 when keepdims is true.
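The two outputs are linked by a simple identity. With n elements in the reduction, n * mean(x) = (n - 1) * loo_mean_i + x_i for every i, and averaging the leave-one-out means over i recovers the full mean. The sketch below checks both identities numerically on a 1-D input, using a hypothetical helper both_outputs_1d that is written from the definitions and is not the TFP implementation.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def both_outputs_1d(logx):
    """Hypothetical helper returning (log_loomean_exp, log_mean_x) for 1-D logx."""
    n = logx.shape[0]
    # Zero weight on the diagonal excludes element i from row i's sum.
    keep = 1.0 - jnp.eye(n, dtype=logx.dtype)
    log_loo = logsumexp(jnp.broadcast_to(logx, (n, n)), axis=1, b=keep) - jnp.log(n - 1.0)
    log_mean = logsumexp(logx) - jnp.log(float(n))
    return log_loo, log_mean


logx = jnp.log(jnp.array([0.5, 1.5, 2.0, 4.0]))
log_loo, log_mean = both_outputs_1d(logx)
n = logx.shape[0]

# Identity 1: n * mean(x) equals (n - 1) * loo_mean_i + x_i for each i.
lhs = n * jnp.exp(log_mean)
rhs = (n - 1) * jnp.exp(log_loo) + jnp.exp(logx)
print(jnp.allclose(lhs, rhs))

# Identity 2: the average of the leave-one-out means is the full mean.
print(jnp.allclose(jnp.mean(jnp.exp(log_loo)), jnp.exp(log_mean)))
```

Identity 1 also suggests recovering log_loomean_exp from log_mean_x as log((n * mean - x_i) / (n - 1)), but that subtraction can lose precision when x_i dominates the sum, which is one reason to compute the leave-one-out term directly.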