Module: tfp.experimental.substrates.jax.math

JAX math.

Modules

generic module: Functions for generic calculations.

gradient module: Functions for computing gradients.

linalg module: Functions for common linear algebra operations.

numeric module: Numerically stable variants of common mathematical expressions.

Functions

fill_triangular(...): Creates a triangular matrix (or a batch of triangular matrices) from a vector of inputs.

fill_triangular_inverse(...): Creates a vector (or a batch of vectors) from a triangular matrix (or a batch of triangular matrices).

log1psquare(...): Numerically stable calculation of log(1 + x**2) for small or large |x|.

log_add_exp(...): Computes log(exp(x) + exp(y)) in a numerically stable way.

log_combinations(...): Computes the log of the multinomial coefficient n! / prod_i counts_i!.

reduce_weighted_logsumexp(...): Computes log(abs(sum(weight * exp(elements)))) across tensor dimensions in a numerically stable way.

softplus_inverse(...): Computes the inverse softplus, i.e., x = softplus_inverse(softplus(x)).

value_and_gradient(...): Computes f(*xs) and its gradients with respect to *xs.
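The vector passed to fill_triangular must have a triangular-number length m = n(n + 1)/2, so the matrix size follows from n = (sqrt(8m + 1) - 1)/2. The plain-Python sketch below shows that size check and a simple row-major lower-triangular packing plus its inverse. The helper names are hypothetical, and the sketch deliberately does not reproduce the element ordering tfp's fill_triangular uses; consult that function's own docstring for its exact layout.

```python
import math

def triangular_size(m):
    """Return n such that m == n * (n + 1) // 2, or raise ValueError."""
    n = (math.isqrt(8 * m + 1) - 1) // 2
    if n * (n + 1) // 2 != m:
        raise ValueError(f"{m} is not a triangular number")
    return n

def fill_lower_row_major(vec):
    # Hypothetical helper: fills the lower triangle row by row (row-major).
    # NOTE: this is NOT the ordering tfp's fill_triangular uses.
    n = triangular_size(len(vec))
    it = iter(vec)
    # next(it) is only evaluated for entries on or below the diagonal.
    return [[next(it) if j <= i else 0 for j in range(n)] for i in range(n)]

def lower_to_vector_row_major(mat):
    # Inverse of fill_lower_row_major: read the lower triangle back row by row.
    return [mat[i][j] for i in range(len(mat)) for j in range(i + 1)]
```

For example, `fill_lower_row_major([1, 2, 3, 4, 5, 6])` yields `[[1, 0, 0], [2, 3, 0], [4, 5, 6]]`, and a length-5 input is rejected because 5 is not triangular.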
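log1psquare and softplus_inverse both rely on rewriting a formula so that no intermediate value overflows or loses precision. The scalar sketch below, written against Python's standard math module rather than the JAX substrate, assumes two identities: log(1 + x^2) = 2 log|x| + log1p(1/x^2) for large |x|, and log(exp(y) - 1) = y + log(1 - exp(-y)) for y > 0. It illustrates the technique and is not the library's implementation.

```python
import math

def softplus(x):
    # log(1 + exp(x)) written as max(x, 0) + log1p(exp(-|x|)) so exp never overflows.
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def softplus_inverse(y):
    # log(exp(y) - 1) == y + log(1 - exp(-y)). -expm1(-y) computes 1 - exp(-y)
    # accurately for small y and never overflows for large y. Defined for y > 0.
    if y <= 0:
        raise ValueError("softplus_inverse is only defined for y > 0")
    return y + math.log(-math.expm1(-y))

def log1psquare(x):
    ax = abs(x)
    if ax > 1.0:
        # log(1 + x^2) == 2*log|x| + log1p((1/x)^2); squaring 1/|x| cannot overflow.
        inv = 1.0 / ax
        return 2.0 * math.log(ax) + math.log1p(inv * inv)
    # For |x| <= 1, log1p keeps full precision even when x**2 is tiny.
    return math.log1p(ax * ax)
```

A naive `math.log(1 + x * x)` returns `inf` for `x = 1e200` and `0.0` for `x = 1e-10`; the rewritten forms return roughly 921.03 and 1e-20 respectively.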
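log_add_exp and reduce_weighted_logsumexp are both instances of the log-sum-exp trick: factor out the largest exponent m so that every remaining exp() has a non-positive argument. The scalar sketch below uses plain Python lists instead of tensors and a hypothetical weighted_logsumexp helper that returns the sign of the weighted sum alongside its log-magnitude. It illustrates the formula only and does not mirror the library's arguments.

```python
import math

def log_add_exp(x, y):
    # log(e^x + e^y) == m + log1p(e^-|x - y|), where m = max(x, y).
    m = max(x, y)
    if m == -math.inf:
        return -math.inf  # both inputs are log(0)
    return m + math.log1p(math.exp(-abs(x - y)))

def weighted_logsumexp(logx, w):
    """Hypothetical helper: (log|sum_i w_i * exp(logx_i)|, sign of the sum)."""
    m = max(logx)
    if m == -math.inf:
        return -math.inf, 0
    # Every exp() argument is <= 0 after subtracting m, so nothing overflows.
    s = sum(wi * math.exp(li - m) for li, wi in zip(logx, w))
    if s == 0:
        return -math.inf, 0
    return m + math.log(abs(s)), (1 if s > 0 else -1)
```

Here `log_add_exp(1000.0, 1000.0)` returns about 1000.693, whereas the direct `math.log(math.exp(1000.0) + math.exp(1000.0))` raises `OverflowError`. Negative weights are what make the abs() in the formula necessary: with weights `[1.0, -1.0]` the sum can be negative or cancel exactly.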
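Because log(k!) = lgamma(k + 1), the log multinomial coefficient can be computed without ever forming a factorial in floating point, where 171! already exceeds the largest double. The scalar sketch below assumes n equals the sum of counts and uses Python's math.lgamma. It is an illustration of the identity, not the library's batched implementation.

```python
import math

def log_combinations(n, counts):
    # log(n! / prod_i counts_i!) == lgamma(n + 1) - sum_i lgamma(counts_i + 1).
    # Assumes n == sum(counts); no factorial is materialized as a float.
    return math.lgamma(n + 1) - sum(math.lgamma(k + 1) for k in counts)
```

For example, `math.exp(log_combinations(4, [2, 2]))` is approximately 6.0 (that is, 4! / (2! 2!)), and `log_combinations(1000, [500, 500])` stays finite even though 1000! is far outside the double range.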
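value_and_gradient returns the value f(*xs) together with one partial derivative per input. To illustrate that return shape without depending on a particular autodiff backend, the sketch below uses forward-mode automatic differentiation with dual numbers, running one pass per input. The Dual class, the sin helper, and value_and_gradient_fwd are hypothetical names, and this is not the library's implementation; f must be built only from operations the Dual class supports.

```python
import math

class Dual:
    """Number val + der*eps with eps**2 == 0; der carries a directional derivative."""
    def __init__(self, val, der=0.0):
        self.val, self.der = val, der

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val + other.val, self.der + other.der)
    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # Product rule: (a + a'eps)(b + b'eps) = ab + (a'b + ab')eps.
        return Dual(self.val * other.val, self.der * other.val + self.val * other.der)
    __rmul__ = __mul__

def sin(x):
    # Chain rule: d/dt sin(x(t)) = cos(x) * x'.
    return Dual(math.sin(x.val), math.cos(x.val) * x.der)

def value_and_gradient_fwd(f, *xs):
    """Return (f(*xs), [df/dx_0, df/dx_1, ...]) using one forward pass per input."""
    value, grads = None, []
    for i in range(len(xs)):
        # Seed derivative 1 on input i and 0 on all others.
        args = [Dual(x, 1.0 if j == i else 0.0) for j, x in enumerate(xs)]
        out = f(*args)
        value = out.val  # identical on every pass
        grads.append(out.der)
    return value, grads
```

For f(x, y) = x*y + sin(x) at (2, 3), this returns the value 6 + sin(2) and the gradient [3 + cos(2), 2]. Forward mode costs one pass per input; libraries typically use reverse mode so that a scalar-valued f needs a single backward pass regardless of the number of inputs.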