Computes `log(abs(sum(weight * exp(elements across tensor dimensions))))`.
```python
tfp.substrates.jax.math.reduce_weighted_logsumexp(
    logx,
    w=None,
    axis=None,
    keep_dims=False,
    return_sign=False,
    experimental_named_axis=None,
    experimental_allow_all_gather=False,
    name=None
)
```
If all weights `w` are known to be positive, it is more efficient to use `reduce_logsumexp` directly, i.e., `tf.reduce_logsumexp(logx + tf.math.log(w))` is more efficient than `du.reduce_weighted_logsumexp(logx, w)`.
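A minimal sketch of this equivalence in the JAX substrate, assuming illustrative values for `logx` and `w` (here `jnp.log` and `jax.scipy.special.logsumexp` stand in for their TF counterparts):

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from tensorflow_probability.substrates import jax as tfp

logx = jnp.array([1., 2., 3.])   # illustrative log-values
w = jnp.array([0.5, 1.0, 2.0])   # all weights positive

# Positive weights can be folded into the exponent...
direct = logsumexp(logx + jnp.log(w))

# ...which matches the weighted reduction, minus the sign bookkeeping.
weighted = tfp.math.reduce_weighted_logsumexp(logx, w)
```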
Reduces `logx` along the dimensions given in `axis`. Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in `axis`. If `keep_dims` is true, the reduced dimensions are retained with length 1.

If `axis` has no entries, all dimensions are reduced, and a tensor with a single element is returned.
This function is more numerically stable than `log(sum(w * exp(input)))`. It avoids overflows caused by taking the `exp` of large inputs and underflows caused by taking the `log` of small inputs.
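A short sketch of the stability claim in the JAX substrate (the values are illustrative; `exp(1000.)` overflows in float32):

```python
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

logx = jnp.array([1000., 1000.])  # exp(1000.) overflows in float32
w = jnp.array([1., 1.])

naive = jnp.log(jnp.sum(w * jnp.exp(logx)))           # inf: exp overflowed
stable = tfp.math.reduce_weighted_logsumexp(logx, w)  # ~ 1000. + log(2.)
```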
For example:
```python
x = tf.constant([[0., 0, 0],
                 [0, 0, 0]])
w = tf.constant([[-1., 1, 1],
                 [1, 1, 1]])

du.reduce_weighted_logsumexp(x, w)
# ==> log(-1*1 + 1*1 + 1*1 + 1*1 + 1*1 + 1*1) = log(4)

du.reduce_weighted_logsumexp(x, w, axis=0)
# ==> [log(-1+1), log(1+1), log(1+1)]

du.reduce_weighted_logsumexp(x, w, axis=1)
# ==> [log(-1+1+1), log(1+1+1)]

du.reduce_weighted_logsumexp(x, w, axis=1, keep_dims=True)
# ==> [[log(-1+1+1)], [log(1+1+1)]]

du.reduce_weighted_logsumexp(x, w, axis=[0, 1])
# ==> log(-1+5)
```
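With `return_sign=True`, a negative weighted sum comes back as a magnitude plus a sign; a sketch in the JAX substrate using illustrative inputs:

```python
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

x = jnp.zeros([2, 3])
w = jnp.array([[-1., -1., -1.],
               [1., 1., 1.]])

lswe, sign = tfp.math.reduce_weighted_logsumexp(
    x, w, axis=1, return_sign=True)
# lswe ==> [log(3), log(3)], sign ==> [-1., 1.]
```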
| Returns | |
|---|---|
| `lswe` | The `log(abs(sum(weight * exp(x))))` reduced tensor. |
| `sign` | (Optional) The sign of `sum(weight * exp(x))`. |
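For intuition, reductions like this are usually computed with the max-shift trick. A minimal sketch of that idea, not TFP's actual implementation, handling only a single integer `axis` (or `None`); the real code takes more care with non-finite inputs and gradients:

```python
import jax.numpy as jnp

def weighted_logsumexp_sketch(logx, w, axis=None):
    # Shift by the max so exp never overflows, then add the shift back.
    m = jnp.max(logx, axis=axis, keepdims=True)
    total = jnp.sum(w * jnp.exp(logx - m), axis=axis, keepdims=True)
    lswe = jnp.squeeze(m + jnp.log(jnp.abs(total)), axis=axis)
    return lswe, jnp.squeeze(jnp.sign(total), axis=axis)
```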