Reduces the input tensor along the given axis using Kahan summation.
tfp.substrates.jax.math.reduce_kahan_sum(input_tensor, axis=None, keepdims=False, name=None)
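Returns both the total and the correction term as a namedtuple with fields total and correction. The correction accumulates the low-order bits lost to rounding as terms are added, so a more accurate sum may be written as total - correction.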
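The total/correction convention can be illustrated with a minimal pure-Python sketch of classic Kahan summation over a flat list of floats. This is not TFP's implementation and does no axis reduction: `kahan_sum`, `naive_sum`, and the `Kahan` namedtuple are hypothetical names chosen for illustration, though the fields mirror the documented total and correction and the sign convention matches total - correction.

```python
import collections
import math

# Hypothetical stand-in for the returned namedtuple; field names follow the docs.
Kahan = collections.namedtuple("Kahan", ["total", "correction"])


def kahan_sum(values):
    """Sums `values` left to right with Kahan compensation (1-D sketch only)."""
    total = 0.0
    correction = 0.0  # rounding error carried over from the previous step
    for x in values:
        y = x - correction            # fold the previous rounding error back in
        t = total + y                 # low-order bits of y may be rounded away here
        correction = (t - total) - y  # what was actually added minus what we meant to add
        total = t
    return Kahan(total, correction)


def naive_sum(values):
    """Plain left-to-right float accumulation, for comparison."""
    total = 0.0
    for x in values:
        total += x
    return total


vals = [1.0] + [1e-16] * 10
naive = naive_sum(vals)  # each 1e-16 is below half an ulp of 1.0, so this stays 1.0
k = kahan_sum(vals)
estimate = k.total - k.correction
reference = math.fsum(vals)  # correctly rounded sum, used as ground truth
print(naive, estimate, reference)
```

The explicit loop in `naive_sum` is deliberate: recent Python versions compensate float rounding inside the built-in `sum()`, which would hide the error being demonstrated.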
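A practical use case is computing the difference of two sums that are large in magnitude and expected to be nearly equal, where s0 and s1 are the results of two calls to reduce_kahan_sum. Subtracting only the totals, s0.total - s1.total, discards the low-order information held in the corrections; taking the difference instead as (s0.total - s1.total) - (s0.correction - s1.correction) retains more precision.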
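The effect on the difference can be seen in a small pure-Python sketch. It assumes a hypothetical `kahan_sum` helper implementing classic Kahan summation (not TFP's code) and uses two lists that share a large offset of 2**53, where doubles are spaced 2.0 apart, so every small addend is rounded away from the totals and survives only in the corrections.

```python
import collections
import math

# Hypothetical stand-in for the returned namedtuple; field names follow the docs.
Kahan = collections.namedtuple("Kahan", ["total", "correction"])


def kahan_sum(values):
    """Kahan summation; the accurate sum is approximately total - correction."""
    total = 0.0
    correction = 0.0
    for x in values:
        y = x - correction            # re-inject the previously lost low-order bits
        t = total + y                 # may round away part of y
        correction = (t - total) - y  # record what was rounded away
        total = t
    return Kahan(total, correction)


big = 2.0 ** 53            # from here up, adjacent doubles are 2.0 apart
a = [big] + [0.25] * 3     # exact sum: big + 0.75
b = [big] + [0.25]         # exact sum: big + 0.25
s0, s1 = kahan_sum(a), kahan_sum(b)

naive_diff = s0.total - s1.total  # both totals round to `big`, so this is 0.0
kahan_diff = (s0.total - s1.total) - (s0.correction - s1.correction)  # 0.5
exact_diff = math.fsum(a + [-v for v in b])  # correctly rounded reference: 0.5
```

Note that summing each list exactly and then subtracting would not help here: big + 0.75 itself rounds to big in double precision, so the information must be kept in the corrections until after the subtraction.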
| Argument | Description |
|---|---|
| input_tensor | The tensor to sum. |
| name | Optional name for ops in scope. |