
tfp.substrates.jax.math.reduce_weighted_logsumexp

Computes log(abs(sum(weight * exp(elements across tensor dimensions)))).

If all weights w are known to be positive, it is more efficient to use reduce_logsumexp directly: tf.reduce_logsumexp(logx + tf.math.log(w)) is more efficient than du.reduce_weighted_logsumexp(logx, w).
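The positive-weight shortcut rests on the identity w * exp(logx) = exp(logx + log(w)). The following is a minimal NumPy sketch of that identity, not the library's code path; the sample values are made up for illustration, and np.logaddexp.reduce stands in for a plain log-sum-exp.

```python
import numpy as np

# Illustrative values only; the weights must be strictly positive.
logx = np.array([0.5, -1.0, 2.0])
w = np.array([0.2, 3.0, 1.5])

# Direct form: log(sum(w * exp(logx))).
direct = np.log(np.sum(w * np.exp(logx)))

# Folded form: an ordinary log-sum-exp of logx + log(w).
folded = np.logaddexp.reduce(logx + np.log(w))

print(direct, folded)
```

The two values should agree up to floating-point rounding. The folded form does not apply when any weight is zero or negative, because log(w) is then -inf or undefined.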

Reduces logx along the dimensions given in axis. Unless keep_dims is true, the rank of the tensor is reduced by 1 for each entry in axis. If keep_dims is true, the reduced dimensions are retained with length 1.

If axis has no entries, all dimensions are reduced, and a tensor with a single element is returned.

This function is more numerically stable than log(sum(w * exp(input))). It avoids overflows caused by taking the exp of large inputs and underflows caused by taking the log of small inputs.
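One common way to get this kind of stability is to subtract the maximum of the inputs before exponentiating. The sketch below shows that max-shift technique in NumPy under that assumption; it is not taken from the library's source, and the helper name weighted_logsumexp_sketch is hypothetical. Here axis may be None, an int, or a tuple of ints.

```python
import numpy as np

def weighted_logsumexp_sketch(logx, w, axis=None, keepdims=False):
    """Hypothetical helper: max-shifted log(abs(sum(w * exp(logx))))."""
    logx = np.asarray(logx, dtype=np.float64)
    w = np.asarray(w, dtype=np.float64)
    # Shift by the maximum so the largest exponent becomes exp(0) = 1.
    m = np.max(logx, axis=axis, keepdims=True)
    # If the maximum is not finite (e.g. all inputs are -inf), skip the shift.
    m = np.where(np.isfinite(m), m, 0.0)
    total = np.sum(w * np.exp(logx - m), axis=axis, keepdims=True)
    sign = np.sign(total)
    # log(0) = -inf is a legitimate result when the weighted sum cancels.
    with np.errstate(divide="ignore"):
        out = m + np.log(np.abs(total))
    if not keepdims:
        out = np.squeeze(out, axis=axis)
        sign = np.squeeze(sign, axis=axis)
    return out, sign

logx = np.array([1000.0, 1000.0, 1000.0])
w = np.array([-1.0, 1.0, 1.0])

# Naive evaluation: exp(1000) overflows to inf, and inf - inf gives nan.
with np.errstate(over="ignore", invalid="ignore"):
    naive = np.log(np.abs(np.sum(w * np.exp(logx))))

# Shifted evaluation: 1000 + log(abs(-1 + 1 + 1)).
stable, sign = weighted_logsumexp_sketch(logx, w)
print(naive, stable, sign)
```

With these inputs the naive expression cannot produce a finite answer, while the shifted version only ever exponentiates values that are at most 0.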

For example:

x = tf.constant([[0., 0, 0],
                 [0, 0, 0]])

w = tf.constant([[-1., 1, 1],
                 [1, 1, 1]])

du.reduce_weighted_logsumexp(x, w)
# ==> log(-1*1 + 1*1 + 1*1 + 1*1 + 1*1 + 1*1) = log(4)

du.reduce_weighted_logsumexp(x, w, axis=0)
# ==> [log(-1+1), log(1+1), log(1+1)]

du.reduce_weighted_logsumexp(x, w, axis=1)
# ==> [log(-1+1+1), log(1+1+1)]

du.reduce_weighted_logsumexp(x, w, axis=1, keep_dims=True)
# ==> [[log(-1+1+1)], [log(1+1+1)]]

du.reduce_weighted_logsumexp(x, w, axis=[0, 1])
# ==> log(-1+5)

logx The tensor to reduce. Should have numeric type.
w The weight tensor. Should have numeric type identical to logx.
axis The dimensions to reduce. If None (the default), reduces all dimensions. Must be in the range [-rank(logx), rank(logx)).
keep_dims If true, retains reduced dimensions with length 1.
return_sign If True, returns the sign of the result.
experimental_named_axis A str or list of str axis names to additionally reduce over. Providing None will not reduce over any axes.
name A name for the operation (optional).

lswe The log(abs(sum(weight * exp(x)))) reduced tensor.
sign (Optional) The sign of sum(weight * exp(x)).
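Because lswe is the log of an absolute value, the signed sum can only be recovered by combining both outputs. The NumPy sketch below computes the two quantities directly from their definitions rather than by calling the library; the weights are illustrative and chosen so that one row sums to a negative number.

```python
import numpy as np

logx = np.array([[0.0, 0.0, 0.0],
                 [0.0, 0.0, 0.0]])
w = np.array([[-3.0, 1.0, 1.0],
              [1.0, 1.0, 1.0]])

# Weighted sums along axis 1: -3 + 1 + 1 and 1 + 1 + 1.
total = np.sum(w * np.exp(logx), axis=1)

# What the two return values represent, per their definitions.
lswe = np.log(np.abs(total))
sign = np.sign(total)

# The signed sum is sign * exp(lswe).
recovered = sign * np.exp(lswe)
print(lswe, sign, recovered)
```

Dropping the sign would make the first row indistinguishable from a row whose weighted sum is +1.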
