
tfp.experimental.substrates.jax.math.linalg.cholesky_concat


Concatenates chol @ chol.T with additional rows and columns.

tfp.experimental.substrates.jax.math.linalg.cholesky_concat(
    chol,
    cols,
    name=None
)

This operation is conceptually identical to:

import tensorflow as tf

def cholesky_concat_slow(chol, cols):  # cols shaped (n + m) x m = z x m
  mat = tf.matmul(chol, chol, adjoint_b=True)  # batch of n x n
  # Concat columns: append the top n rows of cols on the right.
  mat = tf.concat([mat, cols[..., :tf.shape(mat)[-2], :]], axis=-1)  # n x z
  # Concat rows: append conj(cols.T) along the bottom.
  mat = tf.concat([mat, tf.linalg.adjoint(cols)], axis=-2)  # z x z
  return tf.linalg.cholesky(mat)

but whereas cholesky_concat_slow would cost O(z**3) work, cholesky_concat only costs O(z**2 + m**3) work.
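To make the cost difference concrete, here is a minimal NumPy sketch of the standard block-Cholesky update, assuming unbatched inputs and a lower-triangular chol. It is not the TFP source: cholesky_concat_sketch and forward_substitute are hypothetical names, and the JAX substrate may organize the computation differently. The sketch needs only an n x n triangular solve against m right-hand sides and an m x m Cholesky, rather than a fresh z x z factorization.

```python
import numpy as np


def forward_substitute(lower, rhs):
    """Solve lower @ x = rhs for x, where `lower` is lower-triangular.

    Row-by-row forward substitution: O(n**2 * m) for an n x n `lower`
    and an n x m `rhs`.
    """
    n = lower.shape[-1]
    x = np.zeros(rhs.shape, dtype=np.result_type(lower, rhs))
    for i in range(n):
        x[i] = (rhs[i] - lower[i, :i] @ x[:i]) / lower[i, i]
    return x


def cholesky_concat_sketch(chol, cols):
    """Hypothetical unbatched block-Cholesky update (not the TFP source).

    chol: (n, n) lower-triangular factor of mat = chol @ chol^H.
    cols: (n + m, m); the top n rows are the new off-diagonal block and
          the bottom m x m block is assumed self-adjoint.
    """
    n = chol.shape[-1]
    top, bottom = cols[:n], cols[n:]
    # chol @ x^H = top, so x = (chol^{-1} @ top)^H is the new bottom-left block.
    x = forward_substitute(chol, top).conj().T  # (m, n)
    # Schur complement of mat; only this m x m matrix is factored anew.
    schur = bottom - x @ x.conj().T
    y = np.linalg.cholesky(schur)  # (m, m), lower-triangular
    zeros = np.zeros((n, bottom.shape[-1]), dtype=y.dtype)
    return np.block([[chol, zeros], [x, y]])


# Usage: compare against factoring the full matrix from scratch.
rng = np.random.default_rng(0)
n, m = 4, 2
z = n + m
a = rng.normal(size=(z, z))
full = a @ a.T + z * np.eye(z)  # symmetric positive definite
chol = np.linalg.cholesky(full[:n, :n])
cols = full[:, n:]  # shape (n + m, m)
fast = cholesky_concat_sketch(chol, cols)
print(np.allclose(fast, np.linalg.cholesky(full)))  # True
```

The pure-Python loop in forward_substitute is only there to keep the sketch dependency-free; in practice a library triangular solver would be used instead.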

The resulting (implicit) matrix must be self-adjoint (symmetric in the real case) and positive definite. Thus, the bottom-right m x m block of cols must be self-adjoint, and no separate rows argument is required: the new rows are inferred as conj(cols.T).
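These preconditions can be checked by materializing the implicit matrix. The sketch below uses hypothetical NumPy helpers (implied_matrix and is_valid_input are illustrative names, not TFP functions). It mirrors cholesky_concat_slow by inferring the new rows as conj(cols.T), and it tests positive definiteness through eigenvalues, which costs O(z**3) and is therefore only suited to debugging.

```python
import numpy as np


def implied_matrix(chol, cols):
    """Hypothetical helper: build the z x z matrix that is factored,
    inferring the new rows as conj(cols.T)."""
    n = chol.shape[-1]
    mat = chol @ chol.conj().T                             # n x n
    top = np.concatenate([mat, cols[:n]], axis=-1)         # n x z
    return np.concatenate([top, cols.conj().T], axis=-2)   # z x z


def is_valid_input(chol, cols, tol=1e-10):
    """Hypothetical check: bottom m x m block of cols is self-adjoint and
    the implied matrix is positive definite."""
    n = chol.shape[-1]
    bottom = cols[n:]
    if not np.allclose(bottom, bottom.conj().T, atol=tol):
        return False
    # eigvalsh reads only one triangle, so self-adjointness is checked first.
    return bool(np.all(np.linalg.eigvalsh(implied_matrix(chol, cols)) > tol))


chol = np.array([[2.0, 0.0], [1.0, 1.0]])  # mat = [[4, 2], [2, 2]]
good = np.array([[1.0], [0.0], [3.0]])     # n + m = 3, m = 1
bad = np.array([[1.0], [0.0], [-3.0]])     # negative diagonal entry
print(is_valid_input(chol, good), is_valid_input(chol, bad))  # True False
```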

Args:

  • chol: Lower-triangular Cholesky factor of mat = chol @ chol.T. A Tensor with final dims (n, n).
  • cols: The new columns whose first n rows we would like concatenated to the right of mat = chol @ chol.T, and whose conjugate transpose we would like concatenated to the bottom of concat(mat, cols[:n,:]). A Tensor with final dims (n+m, m). The first n rows are the top-right rectangle (their conjugate transpose forms the bottom-left rectangle), and the bottom m x m block must be self-adjoint.
  • name: Optional name for this op.
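For batched inputs, the partition of cols described under Args can be sketched as follows, assuming NumPy arrays whose leading dimensions are batch dimensions. split_cols is a hypothetical helper that only checks shapes; it does not verify self-adjointness or positive definiteness.

```python
import numpy as np


def split_cols(chol, cols):
    """Hypothetical helper: partition cols into its two blocks.

    chol: shape (..., n, n); cols: shape (..., n + m, m).
    Returns (top_right, bottom_right) with shapes (..., n, m) and (..., m, m).
    """
    n = chol.shape[-1]
    if chol.shape[-2] != n:
        raise ValueError(f"chol must be square, got {chol.shape}")
    z, m = cols.shape[-2:]
    if z != n + m:
        raise ValueError(f"cols must have final dims ({n} + m, m), got {cols.shape}")
    top_right = cols[..., :n, :]     # its conjugate transpose is the bottom-left block
    bottom_right = cols[..., n:, :]  # expected to be self-adjoint
    return top_right, bottom_right


chol = np.zeros((3, 4, 4))  # batch of 3, n = 4
cols = np.zeros((3, 6, 2))  # n + m = 6, m = 2
top_right, bottom_right = split_cols(chol, cols)
print(top_right.shape, bottom_right.shape)  # (3, 4, 2) (3, 2, 2)
```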

Returns:

  • chol_concat: The lower-triangular Cholesky factor, with final dims (n+m, n+m), of:
[ [ mat  cols[:n, :] ]
  [   conj(cols.T)   ] ]
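Because a positive definite matrix has a unique Cholesky factor with positive diagonal, the returned factor must have the block form below. The symbols L = chol, C = cols[:n, :], D = cols[n:, :], X and Y are introduced here for illustration; this is the standard block identity, not a description of TFP's exact code path.

```latex
\begin{bmatrix} L L^{H} & C \\ C^{H} & D \end{bmatrix}
= \begin{bmatrix} L & 0 \\ X & Y \end{bmatrix}
  \begin{bmatrix} L^{H} & X^{H} \\ 0 & Y^{H} \end{bmatrix}
\quad\Longrightarrow\quad
X = \left(L^{-1} C\right)^{H},
\qquad
Y = \operatorname{chol}\!\left(D - X X^{H}\right).
```

The top-left block of chol_concat is chol itself, so only the m x m Cholesky of the Schur complement D - X X^H has to be computed from scratch.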