tfp.substrates.jax.math.log_gamma_difference

Returns lgamma(y) - lgamma(x + y), accurately.

This is more accurate than subtracting lgammas directly. Because lgamma(z) grows as z log(z) - z + o(z), evaluating lgamma at two large, close arguments (such as y and x + y with x much smaller than y) yields two large, nearly equal values, and subtracting them incurs catastrophic cancellation.
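The cancellation is easy to observe with Python's standard `math` module (this does not call the TFP function). For x = 1, the identity Gamma(y + 1) = y * Gamma(y) gives the exact answer lgamma(y) - lgamma(y + 1) = -log(y), so the naive subtraction can be checked against it. The value y = 1e15 is an arbitrary illustrative choice.

```python
import math

# For x = 1, Gamma(y + 1) = y * Gamma(y), so the exact difference is -log(y).
y = 1e15
exact = -math.log(y)

# Naive approach: both lgamma values are about 3.35e16, where adjacent
# doubles are 4 apart, so their difference is forced to a multiple of 4
# and cannot represent a value near -34.54.
naive = math.lgamma(y) - math.lgamma(y + 1.0)

print(f"exact = {exact:.6f}")
print(f"naive = {naive:.6f}")
print(f"absolute error = {abs(naive - exact):.3f}")
```

The naive result is off by at least 1.46, an error in the first significant digit, even though each lgamma value on its own is accurate to within a few units in the last place.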

When y >= 8, the method is to split lgamma into the Stirling approximation and the correction computed by log_gamma_correction, cancel the Stirling terms symbolically, and compute and subtract the corrections.
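The symbolic cancellation can be written out as follows, assuming the standard Stirling split lgamma(z) = (z - 1/2) log z - z + (1/2) log(2π) + c(z), where c(z) denotes the correction term, and using log(x + y) = log y + log1p(x / y):

$$
\begin{aligned}
\operatorname{lgamma}(y) - \operatorname{lgamma}(x + y)
&= \left(y - \tfrac{1}{2}\right)\log y - y - \left(x + y - \tfrac{1}{2}\right)\log(x + y) + (x + y) + c(y) - c(x + y) \\
&= -x \log y - \left(x + y - \tfrac{1}{2}\right)\operatorname{log1p}\!\left(\tfrac{x}{y}\right) + x + c(y) - c(x + y).
\end{aligned}
$$

The (1/2) log(2π) terms and the y log y-sized pieces cancel exactly in the algebra, so they are never formed in floating point. The corrections c(z) behave like 1/(12z) for large z, so their difference is small.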

Args:
x: Floating-point Tensor. x should be non-negative and, elementwise, no greater than y.
y: Floating-point Tensor. y should be positive.
name: Optional Python str naming the operation.

Returns:
lgamma_diff: Floating-point Tensor, the difference lgamma(y) - lgamma(x + y), computed elementwise.
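The following is a minimal scalar sketch of the y >= 8 path in plain Python, not the TFP implementation. Assumptions: the helper names `_stirling_correction` and `log_gamma_difference_big_y` are hypothetical; the correction is approximated by the first five terms of the Stirling asymptotic series (log_gamma_correction may use a different approximation); and inputs are Python floats rather than Tensors, so there is no elementwise broadcasting and no handling of y < 8.

```python
import math


def _stirling_correction(z):
    """Hypothetical stand-in for log_gamma_correction, intended for z >= 8.

    Approximates lgamma(z) - [(z - 0.5) * log(z) - z + 0.5 * log(2 * pi)]
    with the first five terms of the Stirling asymptotic series.
    """
    inv_z = 1.0 / z
    inv_z2 = inv_z * inv_z
    # 1/(12z) - 1/(360z^3) + 1/(1260z^5) - 1/(1680z^7) + 1/(1188z^9),
    # evaluated in Horner form in 1/z^2.
    return inv_z * (
        1.0 / 12.0
        + inv_z2 * (
            -1.0 / 360.0
            + inv_z2 * (1.0 / 1260.0 + inv_z2 * (-1.0 / 1680.0 + inv_z2 / 1188.0))
        )
    )


def log_gamma_difference_big_y(x, y):
    """Sketch of lgamma(y) - lgamma(x + y) for 0 <= x <= y and y >= 8."""
    if not (0.0 <= x <= y and y >= 8.0):
        raise ValueError("sketch assumes 0 <= x <= y and y >= 8")
    # Stirling terms with the large (z - 0.5) * log(z) pieces cancelled
    # symbolically; log1p keeps log(1 + x / y) accurate when x / y is small.
    cancelled_stirling = -x * math.log(y) - (x + y - 0.5) * math.log1p(x / y) + x
    # Only the small correction terms are evaluated and subtracted numerically.
    correction = _stirling_correction(y) - _stirling_correction(x + y)
    return cancelled_stirling + correction


# Usage: compare against exact values from Gamma(y + 1) = y * Gamma(y).
print(log_gamma_difference_big_y(1.0, 1e15))  # close to -log(1e15)
print(log_gamma_difference_big_y(3.0, 10.0))  # close to -log(10 * 11 * 12)
```

Unlike the naive subtraction, the first call stays accurate to roughly machine precision, because no quantity of size y log y is ever computed. The truncated series is a convenience for the sketch; its truncation error shrinks rapidly as z grows past 8.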