Creates a vector from a triangular matrix (or a batch of vectors from a batch of triangular matrices).
```python
tfp.substrates.jax.math.fill_triangular_inverse(
    x, upper=False, name=None
)
```
The vector is created from the lower-triangular or upper-triangular portion of `x`, depending on the value of the parameter `upper`.
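To make the element order concrete, here is a minimal `jax.numpy` sketch that reproduces both examples on this page. It is an illustrative re-implementation, not the TFP source, and the name `fill_triangular_inverse_sketch` is hypothetical. The first n entries come from one edge row (the last row reversed for the lower case, the first row for the upper case); the remaining n(n - 1)/2 entries come from adding the other n - 1 rows to their 180-degree rotation and keeping the first half of the flattened sum. The explicit `jnp.tril`/`jnp.triu` mask is an added assumption so that entries outside the chosen triangle cannot leak into the result.

```python
import jax.numpy as jnp


def fill_triangular_inverse_sketch(x, upper=False):
    """Hypothetical helper: pack a (batch of) triangular matrix into a vector.

    Illustrative only (not the TFP source). Returns shape [..., n * (n + 1) // 2].
    """
    x = jnp.asarray(x)
    n = x.shape[-1]
    m = n * (n + 1) // 2
    # Assumption: zero out the unused triangle so stray entries cannot leak in.
    # jnp.tril / jnp.triu act on the last two axes, so batches are supported.
    x = jnp.triu(x) if upper else jnp.tril(x)
    if upper:
        # The first row of the upper triangle supplies the first n elements.
        initial = x[..., 0, :]
        rest = x[..., 1:, :]
    else:
        # The last row of the lower triangle, reversed, supplies the first n.
        initial = jnp.flip(x[..., -1, :], axis=-1)
        rest = x[..., :-1, :]
    # Adding the remaining (n - 1) x n block to its 180-degree rotation fills
    # each zero with an entry from the opposite corner, so the first half of
    # the flattened sum holds every remaining triangle entry exactly once.
    folded = rest + jnp.flip(rest, axis=(-2, -1))
    tail = folded.reshape(x.shape[:-2] + (n * (n - 1),))
    return jnp.concatenate([initial, tail[..., : m - n]], axis=-1)
```

Because every step indexes from the end (`...`), a leading batch of matrices is packed row by row without any explicit loop.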
If `x.shape` is `[b1, b2, ..., bB, n, n]`, then the output shape is `[b1, b2, ..., bB, d]`, where `d = n(n + 1) / 2`.
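The shape rule is plain arithmetic. The following sketch computes the output shape and also inverts the relation, recovering n from d via n = (sqrt(8d + 1) - 1) / 2, which is what a caller needs when going from a packed vector back to a matrix. The helper names are hypothetical and not part of TFP.

```python
import math


def packed_length(n):
    """Entries in an n x n triangle, diagonal included: n(n + 1) / 2."""
    return n * (n + 1) // 2


def output_shape(shape):
    """Output shape for an input of shape [b1, ..., bB, n, n]."""
    *batch, n, n2 = shape
    if n != n2:
        raise ValueError(f"expected square trailing dims, got {n} x {n2}")
    return (*batch, packed_length(n))


def matrix_size(d):
    """Invert d = n(n + 1) / 2; raise if d is not a triangular number."""
    # math.isqrt gives an exact integer square root (Python 3.8+).
    n = (math.isqrt(8 * d + 1) - 1) // 2
    if packed_length(n) != d:
        raise ValueError(f"{d} is not a triangular number")
    return n
```

For example, a batch of shape `[2, 5, 3, 3]` packs to `[2, 5, 6]`, and a trailing length of 6 corresponds to 3 x 3 matrices.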
Example:

```python
from tensorflow_probability.substrates import jax as tfp

tfp.math.fill_triangular_inverse(
    [[4, 0, 0],
     [6, 5, 0],
     [3, 2, 1]])
# ==> [1, 2, 3, 4, 5, 6]

tfp.math.fill_triangular_inverse(
    [[1, 2, 3],
     [0, 5, 6],
     [0, 0, 4]], upper=True)
# ==> [1, 2, 3, 4, 5, 6]
```
Returns | |
---|---|
`flat_tril` | (Batch of) vector-shaped `Tensor` representing the vectorized lower (or upper) triangular elements of `x`.