Computes log(cumsum(exp(x))).
tfp.substrates.jax.math.log_cumsum_exp(
x, axis=-1, name=None
)
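Spelled out for a single slice x_1, ..., x_n along the chosen axis, the output is a running log-sum-exp. The recurrence on the second line is a standard identity shown here for reference; it is not taken from the TFP source. It combines two terms without exponentiating either one directly:

```latex
y_i = \log \sum_{j=1}^{i} \exp(x_j), \qquad i = 1, \dots, n
y_1 = x_1, \qquad y_i = \operatorname{logaddexp}(y_{i-1}, x_i) = \max(y_{i-1}, x_i) + \log\bigl(1 + \exp(-\lvert y_{i-1} - x_i \rvert)\bigr)
```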
This is a pure-TF implementation of tf.math.cumulative_logsumexp; unlike
the built-in op, it supports XLA compilation. It uses the same algorithmic
technique as the built-in op (a parallel prefix sum), so it has similar numerics
and asymptotic performance. However, this implementation currently has higher
overhead, so it is significantly slower on smaller inputs (n < 10000).
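The parallel-prefix-sum idea can be sketched in plain JAX. This sketch assumes `jax.lax.associative_scan` with `jnp.logaddexp` as the combining operator, which is valid because logaddexp is associative. The helper name `log_cumsum_exp_sketch` is hypothetical. This is not TFP's actual implementation, whose internals may differ.

```python
import jax
import jax.numpy as jnp


def log_cumsum_exp_sketch(x, axis=-1):
    """Hypothetical helper: log(cumsum(exp(x))) along `axis` via a parallel prefix scan.

    logaddexp(a, b) = log(exp(a) + exp(b)) is associative, so
    jax.lax.associative_scan can combine elements in O(log n) depth,
    and exp(x) is never materialized on its own (so it cannot overflow).
    """
    x = jnp.asarray(x)
    axis = axis % x.ndim  # normalize negative axes before scanning
    return jax.lax.associative_scan(jnp.logaddexp, x, axis=axis)


# Small inputs: matches the naive formula.
x = jnp.array([0.0, 1.0, 2.0, 3.0])
print(log_cumsum_exp_sketch(x))
print(jnp.log(jnp.cumsum(jnp.exp(x))))

# Large inputs: the naive formula overflows to inf in float32,
# while the log-space scan stays finite (second entry is about 1000 + log(2)).
big = jnp.array([1000.0, 1000.0])
print(log_cumsum_exp_sketch(big))
print(jnp.log(jnp.cumsum(jnp.exp(big))))
```

Working in log space at every step is what keeps the result finite for large inputs, and the associative combine is what allows the scan to run with logarithmic depth rather than as a sequential loop.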
Args:

| Argument | Description |
|---|---|
| x | The Tensor to sum over. |
| axis | int Tensor axis to sum over. |
| name | Python str name prefixed to Ops created by this function. Default value: None (i.e., 'cumulative_logsumexp'). |
Returns:

| Return value | Description |
|---|---|
| cumulative_logsumexp | Tensor of the same shape as x. |