tfp.experimental.substrates.jax.math.linalg.lu_matrix_inverse

Computes a matrix inverse given the matrix's LU decomposition.

tfp.experimental.substrates.jax.math.linalg.lu_matrix_inverse(
    lower_upper,
    perm,
    validate_args=False,
    name=None
)

This op is conceptually identical to:

inv_X = tfp.math.lu_matrix_inverse(*tf.linalg.lu(X))
tf.debugging.assert_near(tf.linalg.inv(X), inv_X)
# ==> passes: inv_X matches tf.linalg.inv(X) up to numerical tolerance

Args:

  • lower_upper: lu as returned by tf.linalg.lu, i.e., if matmul(P, matmul(L, U)) = X then lower_upper = L + U - eye.
  • perm: p as returned by tf.linalg.lu, i.e., if matmul(P, matmul(L, U)) = X then perm = argmax(P).
  • validate_args: Python bool indicating whether arguments should be checked for correctness. Note: this function does not verify the implied matrix is actually invertible, even when validate_args=True. Default value: False (i.e., don't validate arguments).
  • name: Python str name given to ops managed by this object. Default value: None (i.e., 'lu_matrix_inverse').

Returns:

  • inv_x: The matrix inverse, i.e., tf.linalg.inv(tfp.math.lu_reconstruct(lower_upper, perm)).

Examples

import numpy as np
from tensorflow_probability.python.internal.backend import jax as tf
import tensorflow_probability as tfp; tfp = tfp.experimental.substrates.jax

x = [[[3., 4], [1, 2]],
     [[7., 8], [3, 4]]]
inv_x = tfp.math.lu_matrix_inverse(*tf.linalg.lu(x))
tf.debugging.assert_near(tf.linalg.inv(x), inv_x)
# ==> passes: inv_x matches tf.linalg.inv(x) up to numerical tolerance
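
To make the packed lower_upper / perm convention from the Args section concrete, here is a minimal NumPy sketch of how an inverse can be recovered from the two outputs. It is not the library's implementation. The helper name lu_matrix_inverse_sketch is hypothetical, it handles a single (non-batched) matrix only, and it assumes the tf.linalg.lu convention that the rows of X gathered by perm equal matmul(L, U). The factors are written out by hand for X = [[1, 2], [3, 4]] rather than computed.

```python
import numpy as np


def lu_matrix_inverse_sketch(lower_upper, perm):
    """Hypothetical helper: invert X from its packed LU factors (single matrix)."""
    lower_upper = np.asarray(lower_upper, dtype=np.float64)
    perm = np.asarray(perm)
    eye = np.eye(lower_upper.shape[-1])

    # Unpack lower_upper = L + U - eye: L has a unit diagonal and the strictly
    # lower part; U has the diagonal and everything above it.
    lower = np.tril(lower_upper, k=-1) + eye
    upper = np.triu(lower_upper)

    # inv(L @ U) = inv(U) @ inv(L), obtained with two solves.
    inv_lu = np.linalg.solve(upper, np.linalg.solve(lower, eye))

    # Assumed convention: X[perm] == L @ U, so the columns of inv(L @ U)
    # must be reordered by the inverse permutation to get inv(X).
    return inv_lu[:, np.argsort(perm)]


# Hand-written partial-pivoting factors of X = [[1, 2], [3, 4]]:
# the rows are swapped (perm = [1, 0]) so that the pivot is 3.
x = np.array([[1.0, 2.0], [3.0, 4.0]])
lower_upper = np.array([[3.0, 4.0], [1.0 / 3.0, 2.0 / 3.0]])
perm = np.array([1, 0])

inv_x = lu_matrix_inverse_sketch(lower_upper, perm)
np.testing.assert_allclose(inv_x, np.linalg.inv(x))
```

Because the permutation only reorders columns of inv(L @ U), the cost is dominated by the two solves against L and U (a general solver is used here for brevity; triangular solves would exploit the structure).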