tfp.experimental.substrates.jax.math.reduce_logmeanexp

Computes log(mean(exp(input_tensor))).
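Written out, with n denoting the number of elements being reduced (a symbol introduced here for illustration), the quantity is a log-sum-exp minus log n. This is one way to see how such a reduction can be computed:

```latex
\log\left(\frac{1}{n}\sum_{i=1}^{n} e^{x_i}\right)
  = \log\left(\sum_{i=1}^{n} e^{x_i}\right) - \log n
```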

Reduces input_tensor along the dimensions given in axis. Unless keepdims is true, the rank of the tensor is reduced by 1 for each entry in axis. If keepdims is true, the reduced dimensions are retained with length 1.

If axis is None (the default), all dimensions are reduced, and a tensor with a single element is returned.
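The axis and keepdims semantics can be illustrated with a minimal NumPy sketch. It is not the JAX-substrate implementation. The helper name reduce_logmeanexp_sketch is hypothetical, and it accepts axis as None, an int, or a tuple of ints, as NumPy reductions do:

```python
import numpy as np

def reduce_logmeanexp_sketch(x, axis=None, keepdims=False):
    """Hypothetical NumPy sketch of log(mean(exp(x))) along `axis`."""
    x = np.asarray(x, dtype=np.float64)
    # Shift by the max over the reduced axes so that exp() cannot overflow.
    m = np.max(x, axis=axis, keepdims=True)
    # mean(exp(x - m)) lies in (0, 1], so its log is finite.
    out = m + np.log(np.mean(np.exp(x - m), axis=axis, keepdims=True))
    if not keepdims:
        # Drop only the reduced (now length-1) axes; axis=None drops all of them.
        out = np.squeeze(out, axis=axis)
    return out

x = np.log(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]))
print(reduce_logmeanexp_sketch(x, axis=1).shape)                 # (2,)
print(reduce_logmeanexp_sketch(x, axis=1, keepdims=True).shape)  # (2, 1)
print(reduce_logmeanexp_sketch(x).shape)                         # ()
```

Because x holds logarithms, reducing along axis=1 gives log(2) and log(5), the logs of the row means.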

This function is more numerically stable than log(reduce_mean(exp(input))). It avoids overflow when exp is applied to large inputs, and it avoids the -inf that results when exp of very negative inputs underflows to zero before the log is taken.
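The difference can be seen with a plain-Python sketch. The helper names are hypothetical, and the max-shift trick shown is the standard technique for this problem, not necessarily the exact code used internally:

```python
import math

def naive_logmeanexp(xs):
    # Direct formula. math.exp raises OverflowError for arguments above ~709.78;
    # for very negative inputs every exp() rounds to 0.0 and math.log(0.0)
    # raises ValueError.
    return math.log(sum(math.exp(x) for x in xs) / len(xs))

def stable_logmeanexp(xs):
    # Shift by the maximum m: every exp() argument is <= 0, so nothing overflows,
    # and the largest term is exp(0) = 1, so the mean is >= 1/len(xs) > 0.
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs) / len(xs))

for xs in ([1000.0, 1000.0], [-1000.0, -1000.0]):
    try:
        print("naive:", naive_logmeanexp(xs))
    except (OverflowError, ValueError) as err:
        print("naive failed:", type(err).__name__)
    print("stable:", stable_logmeanexp(xs))
```

For [1000.0, 1000.0] the direct formula raises OverflowError, and for [-1000.0, -1000.0] it raises ValueError, while the shifted version returns 1000.0 and -1000.0 respectively.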

Args:
  input_tensor: The tensor to reduce. Should have numeric type.
  axis: The dimensions to reduce. If None (the default), reduces all dimensions. Must be in the range [-rank(input_tensor), rank(input_tensor)).
  keepdims: Boolean. Whether to keep the reduced axes as singleton dimensions. Default value: False (i.e., squeeze the reduced dimensions).
  name: Python str name prefixed to Ops created by this function. Default value: None (i.e., 'reduce_logmeanexp').

Returns:
  log_mean_exp: The reduced tensor.