TensorFlow Probability math functions.
Modules
hypergeometric
module: Implements hypergeometric functions in TensorFlow.
ode
module: TensorFlow Probability ODE solvers.
psd_kernels
module: Positive-semidefinite kernels package.
Classes
class MinimizeTraceableQuantities
: Namedtuple of quantities that may be traced from tfp.math.minimize.
Functions
atan_difference(...)
: Difference of arctan(x) and arctan(y).
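Subtracting two arctangents directly loses precision when both arguments are large, because both values sit near pi/2. A minimal pure-Python sketch of one stable formulation, assuming the identity atan(x) - atan(y) = atan2(x - y, 1 + x*y), which holds because the difference always lies in (-pi, pi). This illustrates the idea only and is not TFP's implementation; the name `atan_difference_sketch` is hypothetical.

```python
import math


def atan_difference_sketch(x, y):
    """Hypothetical scalar sketch of atan(x) - atan(y).

    With a = atan(x) and b = atan(y), cos(a) and cos(b) are both positive, so
    (cos(a - b), sin(a - b)) is a positive multiple of (1 + x*y, x - y).
    atan2 therefore recovers a - b on the correct branch without subtracting
    two nearly equal numbers.
    """
    return math.atan2(x - y, 1.0 + x * y)


# Large, nearby arguments: the naive difference suffers from cancellation.
print(atan_difference_sketch(1e8, 1e8 + 1))
print(math.atan(1e8) - math.atan(1e8 + 1))
```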
batch_interp_rectilinear_nd_grid(...)
: Multi-linear interpolation on a rectilinear grid.
batch_interp_regular_1d_grid(...)
: Linear 1-D interpolation on a regular (constant spacing) grid.
batch_interp_regular_nd_grid(...)
: Multi-linear interpolation on a regular (constant spacing) grid.
bessel_iv_ratio(...)
: Computes I_{v}(z) / I_{v - 1}(z) in a numerically stable way.
bessel_ive(...)
: Computes exponentially scaled modified Bessel function of the first kind.
bessel_kve(...)
: Computes exponentially scaled modified Bessel function of the 2nd kind.
betainc(...)
: Computes the regularized incomplete beta function element-wise.
betaincinv(...)
: Computes the inverse of tfp.math.betainc with respect to x.
bracket_root(...)
: Finds bounds that bracket a root of the objective function.
cholesky_concat(...)
: Concatenates chol @ chol.T with additional rows and columns.
cholesky_update(...)
: Returns the Cholesky factor of chol @ chol.T + multiplier * u @ u.T.
clip_by_value_preserve_gradient(...)
: Clips values to a specified min and max while leaving gradient unaltered.
custom_gradient(...)
: Embeds a custom gradient into a Tensor.
dawsn(...)
: Computes Dawson's integral element-wise.
diag_jacobian(...)
: Computes diagonal of the Jacobian matrix of ys=fn(xs) wrt xs.
erfcinv(...)
: Computes the inverse of tf.math.erfc of z element-wise.
erfcx(...)
: Computes the scaled complementary error function exp(x**2) * erfc(x).
fill_triangular(...)
: Creates a (batch of) triangular matrix from a vector of inputs.
fill_triangular_inverse(...)
: Creates a vector from a (batch of) triangular matrix.
find_root_chandrupatla(...)
: Finds root(s) of a scalar function using Chandrupatla's method.
find_root_secant(...)
: Finds root(s) of a function of a single variable using the secant method.
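The secant method replaces the derivative in Newton's method with the slope of the line through the two most recent iterates. The following is a hedged pure-Python sketch for a scalar function; the name `secant_sketch` and its parameters (`x0`, `x1`, `tol`, `max_iter`) are hypothetical and do not mirror the TFP signature, which works on batched tensors.

```python
def secant_sketch(f, x0, x1, tol=1e-12, max_iter=50):
    """Hypothetical scalar secant iteration; returns the last iterate."""
    f0, f1 = f(x0), f(x1)
    for _ in range(max_iter):
        if f1 == f0:
            # Flat secant line: the update would divide by zero, so stop.
            break
        # Root of the straight line through (x0, f0) and (x1, f1).
        x2 = x1 - f1 * (x1 - x0) / (f1 - f0)
        x0, f0 = x1, f1
        x1, f1 = x2, f(x2)
        if abs(x1 - x0) < tol:
            break
    return x1


root = secant_sketch(lambda x: x * x - 2.0, 1.0, 2.0)
print(root)
```

Unlike bracketing methods, the secant iteration is not guaranteed to stay inside an interval that contains a root, so it works best from starting points already near the root.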
gram_schmidt(...)
: Implementation of the modified Gram-Schmidt orthonormalization algorithm.
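In the modified variant, each vector is orthogonalized against the already accepted basis vectors one at a time, and each projection uses the partially updated vector rather than the original one; this makes it numerically more robust than classical Gram-Schmidt. A pure-Python sketch on lists of floats, assuming linearly independent inputs. The helper name is hypothetical, and TFP's function operates on tensors rather than Python lists.

```python
import math


def modified_gram_schmidt_sketch(vectors):
    """Hypothetical sketch: orthonormalize a list of equal-length vectors."""
    basis = []
    for v in vectors:
        w = list(v)
        for q in basis:
            # Project the *current* w (not the original v) onto q and remove it.
            coeff = sum(wi * qi for wi, qi in zip(w, q))
            w = [wi - coeff * qi for wi, qi in zip(w, q)]
        # Assumes linear independence, so the remaining norm is nonzero.
        norm = math.sqrt(sum(wi * wi for wi in w))
        basis.append([wi / norm for wi in w])
    return basis


qs = modified_gram_schmidt_sketch([[1.0, 1.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, 1.0]])
```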
hpsd_logdet(...)
: Computes log|det(matrix)|, where matrix is a HPSD matrix.
hpsd_quadratic_form_solve(...)
: Computes rhs^T matrix^-1 rhs, where matrix is HPSD.
hpsd_quadratic_form_solvevec(...)
: Computes rhs^T matrix^-1 rhs, where matrix is HPSD.
hpsd_solve(...)
: Computes matrix^-1 rhs, where matrix is HPSD.
hpsd_solvevec(...)
: Computes matrix^-1 rhs, where matrix is HPSD.
igammacinv(...)
: Computes the inverse to tf.math.igammac with respect to p.
igammainv(...)
: Computes the inverse to tf.math.igamma with respect to p.
interp_regular_1d_grid(...)
: Linear 1-D interpolation on a regular (constant spacing) grid.
lambertw(...)
: Computes Lambert W of z element-wise.
lambertw_winitzki_approx(...)
: Computes Winitzki approximation to Lambert W function at z >= -1/exp(1).
lbeta(...)
: Returns log(Beta(x, y)).
log1mexp(...)
: Compute log(1 - exp(-|x|)) elementwise in a numerically stable way.
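Neither log(1 - exp(-a)) nor log1p(-exp(-a)) is accurate over the whole range a > 0. A common two-branch remedy switches at a = log(2): use log(-expm1(-a)) for small a, where exp(-a) is close to 1, and log1p(-exp(-a)) otherwise. This pure-Python sketch illustrates that rule; whether TFP uses the same cutoff is an assumption, and the function name is hypothetical.

```python
import math


def log1mexp_sketch(x):
    """Hypothetical scalar sketch of log(1 - exp(-|x|)) for x != 0."""
    a = abs(x)
    if a < math.log(2.0):
        # exp(-a) is close to 1: expm1 keeps the small difference 1 - exp(-a) accurate.
        return math.log(-math.expm1(-a))
    # exp(-a) <= 1/2: log1p keeps log(1 - small) accurate.
    return math.log1p(-math.exp(-a))


print(log1mexp_sketch(1e-20))  # the naive form would evaluate log(0.0)
```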
log1psquare(...)
: Numerically stable calculation of log(1 + x**2) for small or large |x|.
log_add_exp(...)
: Computes log(exp(x) + exp(y)) in a numerically stable way.
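The usual trick factors out the larger argument, log(exp(x) + exp(y)) = max(x, y) + log1p(exp(-|x - y|)), so the only exponential evaluated is at most 1. A scalar pure-Python sketch with a hypothetical name; unlike a library routine, it does not handle infinities or NaNs.

```python
import math


def log_add_exp_sketch(x, y):
    """Hypothetical scalar sketch of log(exp(x) + exp(y)) for finite x, y."""
    hi = max(x, y)
    # exp(-|x - y|) lies in (0, 1], so it cannot overflow.
    return hi + math.log1p(math.exp(-abs(x - y)))


print(log_add_exp_sketch(1000.0, 1000.0))  # exp(1000.0) alone would overflow
```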
log_bessel_ive(...)
: Computes log(tfp.math.bessel_ive(v, z)).
log_bessel_kve(...)
: Computes log(tfp.math.bessel_kve(v, z)).
log_combinations(...)
: Log multinomial coefficient.
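The multinomial coefficient n! / (k_1! ... k_m!) with n = k_1 + ... + k_m overflows quickly, but its logarithm is a sum of log-gamma terms because log(k!) = lgamma(k + 1). A sketch using `math.lgamma`; the function name is hypothetical, and deriving n from the counts is an assumption of this sketch rather than a description of TFP's arguments.

```python
import math


def log_combinations_sketch(counts):
    """Hypothetical sketch: log of n! / prod(k_i!) with n = sum(counts)."""
    n = sum(counts)
    # log(k!) == lgamma(k + 1), so no factorial is ever formed explicitly.
    return math.lgamma(n + 1) - sum(math.lgamma(k + 1) for k in counts)


print(math.exp(log_combinations_sketch([3, 2, 1])))  # 6! / (3! * 2! * 1!)
```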
log_cosh(...)
: Compute log(cosh(x)) in a numerically stable way.
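For large |x|, cosh(x) overflows even though log(cosh(x)) is close to |x| - log(2). Writing cosh(x) = exp(|x|) * (1 + exp(-2|x|)) / 2 gives log(cosh(x)) = |x| + log1p(exp(-2|x|)) - log(2), which only exponentiates non-positive numbers. A pure-Python sketch with a hypothetical name, not TFP's implementation:

```python
import math


def log_cosh_sketch(x):
    """Hypothetical scalar sketch of log(cosh(x))."""
    a = abs(x)
    # cosh(x) = exp(a) * (1 + exp(-2a)) / 2, and exp(-2a) <= 1 never overflows.
    return a + math.log1p(math.exp(-2.0 * a)) - math.log(2.0)


print(log_cosh_sketch(1000.0))  # math.cosh(1000.0) raises OverflowError
```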
log_cumsum_exp(...)
: Computes log(cumsum(exp(x))).
log_gamma_correction(...)
: Returns the error of the Stirling approximation to lgamma(x) for x >= 8.
log_gamma_difference(...)
: Returns lgamma(y) - lgamma(x + y), accurately.
log_sub_exp(...)
: Compute log(exp(max(x, y)) - exp(min(x, y))) in a numerically stable way.
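Factoring out the larger argument gives log(exp(hi) - exp(lo)) = hi + log(1 - exp(-(hi - lo))), which reduces the problem to the log1mexp case above. The sketch below assumes the same two-branch rule at log(2) for the remaining term and returns -inf when both arguments are equal; the name is hypothetical and this is not TFP's implementation.

```python
import math


def log_sub_exp_sketch(x, y):
    """Hypothetical scalar sketch of log(exp(max(x, y)) - exp(min(x, y)))."""
    hi, lo = max(x, y), min(x, y)
    d = hi - lo
    if d == 0.0:
        return -math.inf  # exp(hi) - exp(lo) == 0
    # log(exp(hi) - exp(lo)) = hi + log(1 - exp(-d)) with d > 0.
    if d < math.log(2.0):
        return hi + math.log(-math.expm1(-d))
    return hi + math.log1p(-math.exp(-d))
```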
logerfc(...)
: Computes the logarithm of tf.math.erfc of x element-wise.
logerfcx(...)
: Computes the logarithm of tfp.math.erfcx of x element-wise.
low_rank_cholesky(...)
: Computes a low-rank approximation to the Cholesky decomposition.
lu_matrix_inverse(...)
: Computes a matrix inverse given the matrix's LU decomposition.
lu_reconstruct(...)
: The inverse LU decomposition, X == lu_reconstruct(*tf.linalg.lu(X)).
lu_solve(...)
: Solves systems of linear equations A X = RHS, given LU factorizations.
minimize(...)
: Minimize a loss function using a provided optimizer.
minimize_stateless(...)
: Minimize a loss expressed as a pure function of its parameters.
owens_t(...)
: Computes Owen's T function of h and a element-wise.
pivoted_cholesky(...)
: Computes the (partial) pivoted Cholesky decomposition of matrix.
reduce_kahan_sum(...)
: Reduces the input tensor along the given axis using Kahan summation.
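Kahan summation carries a running compensation term that captures the low-order bits lost in each floating-point addition and feeds them back into the next one. A scalar pure-Python sketch over a flat list; the name is hypothetical, and TFP's version reduces a tensor along an axis rather than a Python list.

```python
import math


def kahan_sum_sketch(values):
    """Hypothetical sketch of compensated (Kahan) summation."""
    total = 0.0
    comp = 0.0  # running estimate of the rounding error lost so far
    for v in values:
        y = v - comp            # correct the incoming term
        t = total + y           # low-order bits of y may be lost here...
        comp = (t - total) - y  # ...and are recovered algebraically here
        total = t
    return total


vals = [1.0] + [1e-16] * 10
naive = 0.0
for v in vals:
    naive += v  # each 1e-16 is below half an ulp of 1.0 and is rounded away
print(naive, kahan_sum_sketch(vals), math.fsum(vals))
```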
reduce_log_harmonic_mean_exp(...)
: Computes log(1 / mean(1 / exp(input_tensor))).
reduce_logmeanexp(...)
: Computes log(mean(exp(input_tensor))).
reduce_weighted_logsumexp(...)
: Computes log(abs(sum(weight * exp(elements across tensor dimensions)))).
round_exponential_bump_function(...)
: Function supported on [-1, 1], smooth on the real line, with a round top.
scan_associative(...)
: Perform a scan with an associative binary operation, in parallel.
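One classic parallel formulation of an associative scan is the Hillis-Steele doubling scan: in the round with offset `step`, every position i >= step combines with position i - step, and all combinations within a round are independent. The sketch below runs those rounds sequentially in pure Python to show the data flow. The name is hypothetical, and TFP's own parallel scheme may differ.

```python
def scan_associative_sketch(fn, elems):
    """Hypothetical sketch of an inclusive scan via Hillis-Steele doubling."""
    out = list(elems)
    n = len(out)
    step = 1
    while step < n:
        # Each update reads only the previous round's values, so the updates
        # in a round are independent and could run in parallel.
        # fn(left, right) keeps operand order, so fn need not be commutative.
        out = [out[i] if i < step else fn(out[i - step], out[i]) for i in range(n)]
        step *= 2
    return out


print(scan_associative_sketch(lambda a, b: a + b, [1, 2, 3, 4, 5]))
print(scan_associative_sketch(lambda a, b: a + b, ["a", "b", "c", "d"]))
```

The result is correct only if `fn` is associative: the doubling rounds regroup the operands differently from a left-to-right loop.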
secant_root(...)
: Finds root(s) of a function of a single variable using the secant method.
smootherstep(...)
: Computes a sigmoid-like interpolation function on the unit-interval.
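A minimal sketch assuming the classic quintic "smootherstep" polynomial 6x^5 - 15x^4 + 10x^3 applied to inputs clipped to [0, 1]; its first and second derivatives vanish at both ends of the interval. That TFP uses exactly this polynomial is an assumption here (check its docstring), and the function name is hypothetical.

```python
def smootherstep_sketch(x):
    """Hypothetical sketch: assumed quintic smootherstep on [0, 1]."""
    x = min(max(x, 0.0), 1.0)  # clip to the unit interval
    # Horner-style form of 6x^5 - 15x^4 + 10x^3.
    return x * x * x * (x * (6.0 * x - 15.0) + 10.0)
```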
soft_sorting_matrix(...)
: Computes a matrix representing a continuous relaxation of sorting.
soft_threshold(...)
: Soft Thresholding operator.
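The soft-thresholding operator, which is the proximal operator of threshold * |x|, shrinks x toward zero by `threshold` and returns zero whenever |x| <= threshold: sign(x) * max(|x| - threshold, 0). A scalar pure-Python sketch with a hypothetical name:

```python
import math


def soft_threshold_sketch(x, threshold):
    """Hypothetical scalar sketch: sign(x) * max(|x| - threshold, 0)."""
    return math.copysign(max(abs(x) - threshold, 0.0), x)
```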
softplus_inverse(...)
: Computes the inverse softplus, i.e., x = softplus_inverse(softplus(x)).
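With softplus(x) = log(1 + exp(x)), the inverse is log(exp(y) - 1) for y > 0. Evaluating exp(y) directly overflows for large y, so the sketch factors it out: log(exp(y) - 1) = y + log(-expm1(-y)). Both helpers are hypothetical pure-Python sketches, not TFP's implementation.

```python
import math


def softplus_sketch(x):
    # log(1 + exp(x)), written so that exp never receives a positive argument.
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def softplus_inverse_sketch(y):
    """Hypothetical scalar sketch of log(exp(y) - 1) for y > 0."""
    # Factor out exp(y): log(exp(y) - 1) = y + log(1 - exp(-y)).
    return y + math.log(-math.expm1(-y))
```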
sparse_or_dense_matmul(...)
: Returns (batched) matmul of a SparseTensor (or Tensor) with a Tensor.
sparse_or_dense_matvecmul(...)
: Returns (batched) matmul of a (sparse) matrix with a column vector.
sqrt1pm1(...)
: Compute sqrt(x + 1) - 1 elementwise in a numerically stable way.
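For small x, sqrt(x + 1) rounds to 1 and the subtraction returns 0. Multiplying by the conjugate gives an equivalent form with no cancellation: sqrt(x + 1) - 1 = x / (sqrt(x + 1) + 1), valid for x >= -1. A scalar pure-Python sketch with a hypothetical name:

```python
import math


def sqrt1pm1_sketch(x):
    """Hypothetical scalar sketch of sqrt(x + 1) - 1 for x >= -1."""
    # The denominator is >= 1, so no catastrophic cancellation occurs.
    return x / (math.sqrt(x + 1.0) + 1.0)


print(sqrt1pm1_sketch(1e-20), math.sqrt(1e-20 + 1.0) - 1.0)
```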
trapz(...)
: Integrate y(x) on the specified axis using the trapezoidal rule.
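The trapezoidal rule approximates the integral by summing, over each interval [x_{i-1}, x_i], the area (y_{i-1} + y_i) / 2 * (x_i - x_{i-1}); it is exact when y is linear in x. A one-dimensional pure-Python sketch with a hypothetical name; TFP's version integrates along a specified axis of a tensor.

```python
def trapz_sketch(y, x):
    """Hypothetical sketch: trapezoidal integral of samples y at points x."""
    if len(y) != len(x):
        raise ValueError("y and x must have the same length")
    total = 0.0
    for i in range(1, len(x)):
        # Area of the trapezoid between consecutive sample points.
        total += 0.5 * (y[i] + y[i - 1]) * (x[i] - x[i - 1])
    return total


xs = [i / 100 for i in range(101)]
print(trapz_sketch([t * t for t in xs], xs))  # close to 1/3
```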
value_and_gradient(...)
: Computes f(*args) and its gradients with respect to *args.