Computes the diagonal of the Jacobian matrix of ys = fn(xs) with respect to xs.
tfp.math.diag_jacobian(xs, ys=None, sample_shape=None, fn=None, parallel_iterations=10, name=None)
If ys is a tensor or a list of tensors of the form (ys_1, .., ys_n) and xs is of the form (xs_1, .., xs_n), the function computes the diagonal of the Jacobian matrix, i.e., the partial derivatives (dys_1/dxs_1, .., dys_n/dxs_n). For definition details, see the Wikipedia article "Jacobian matrix and determinant".
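To make the definition concrete, here is a minimal NumPy sketch. It is not part of the library: `full_jacobian_fd` and `diag_jacobian_fd` are hypothetical helpers, and central finite differences stand in for automatic differentiation. It approximates only the entries dys_i/dxs_i and compares them with the diagonal of the full Jacobian.

```python
import numpy as np

def full_jacobian_fd(fn, x, eps=1e-6):
    """Central-difference approximation of the full Jacobian d fn(x)_i / d x_j."""
    x = np.asarray(x, dtype=np.float64)
    n = x.size
    jac = np.zeros((n, n))
    for j in range(n):
        step = np.zeros(n)
        step[j] = eps
        # Column j holds the derivatives of every output w.r.t. input j.
        jac[:, j] = (fn(x + step) - fn(x - step)) / (2.0 * eps)
    return jac

def diag_jacobian_fd(fn, x, eps=1e-6):
    """Approximates only the diagonal entries d fn(x)_i / d x_i."""
    x = np.asarray(x, dtype=np.float64)
    diag = np.zeros(x.size)
    for i in range(x.size):
        step = np.zeros(x.size)
        step[i] = eps
        # Perturb input i and read back output i only.
        diag[i] = (fn(x + step)[i] - fn(x - step)[i]) / (2.0 * eps)
    return diag

# Hypothetical test function whose outputs mix the inputs:
# ys_i = xs_i**2 * sum(xs), so dys_i/dxs_i = 2 * xs_i * sum(xs) + xs_i**2.
def fn(x):
    return x ** 2 * np.sum(x)

x0 = np.array([1.0, 2.0, 3.0])
diag = diag_jacobian_fd(fn, x0)   # analytic values: [13, 28, 45]
full = full_jacobian_fd(fn, x0)   # np.diag(full) matches `diag`
```

The off-diagonal entries of `full` are nonzero here, which is the case where computing only the diagonal saves work.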
Diagonal Hessian of the log-density of a 3D Gaussian distribution
In this example we compute the diagonal of the Hessian of the log-density of a 3D Gaussian distribution by applying diag_jacobian to the gradient of the log-density with respect to the state.
import numpy as np
import tensorflow.compat.v1 as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
dtype = np.float32

# Passing `ys` without `fn` requires graph mode, hence the explicit session.
with tf.Session(graph=tf.Graph()) as sess:
  true_mean = dtype([0, 0, 0])
  true_cov = dtype([[1, 0.25, 0.25], [0.25, 2, 0.25], [0.25, 0.25, 3]])
  chol = tf.linalg.cholesky(true_cov)
  target = tfd.MultivariateNormalTriL(loc=true_mean, scale_tril=chol)

  # Assume that the state is passed as a list of tensors `x` and `y`.
  # Then the target function is defined as follows:
  def target_fn(x, y):
    # Concatenate the input tensors into a single 3D event.
    z = tf.concat([x, y], axis=-1) - true_mean
    return target.log_prob(z)

  sample_shape = [3, 5]
  # Event sizes 2 and 1 concatenate to the 3D event of `target`.
  state = [tf.ones(sample_shape + [2], dtype=dtype),
           tf.ones(sample_shape + [1], dtype=dtype)]
  fn_val, grads = tfp.math.value_and_gradient(target_fn, state)

  # We can either pass the `sample_shape` of the `state` or not, which impacts
  # the computational speed of `diag_jacobian`.
  _, diag_jacobian_shape_passed = tfp.math.diag_jacobian(
      xs=state, ys=grads, sample_shape=tf.shape(fn_val))
  _, diag_jacobian_shape_none = tfp.math.diag_jacobian(
      xs=state, ys=grads)

  diag_jacobian_shape_passed_ = sess.run(diag_jacobian_shape_passed)
  diag_jacobian_shape_none_ = sess.run(diag_jacobian_shape_none)

print('hessian computed through `diag_jacobian`, sample_shape passed: ',
      np.concatenate(diag_jacobian_shape_passed_, -1))
print('hessian computed through `diag_jacobian`, sample_shape skipped',
      np.concatenate(diag_jacobian_shape_none_, -1))
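As an independent sanity check on the example above, the Hessian of a Gaussian log-density is the negative precision matrix, so the printed diagonal should, up to floating-point error, equal -diag(inv(true_cov)) at every sample position. The NumPy sketch below is an assumption-laden illustration rather than library code: `log_density` and `diag_hessian_fd` are hypothetical helpers, and second-order finite differences replace automatic differentiation.

```python
import numpy as np

true_cov = np.array([[1.0, 0.25, 0.25],
                     [0.25, 2.0, 0.25],
                     [0.25, 0.25, 3.0]])
precision = np.linalg.inv(true_cov)

def log_density(z):
    """Zero-mean Gaussian log-density, up to an additive constant."""
    return -0.5 * z @ precision @ z

def diag_hessian_fd(f, z, eps=1e-3):
    """Second-order central differences for d^2 f / d z_i^2."""
    z = np.asarray(z, dtype=np.float64)
    diag = np.zeros(z.size)
    for i in range(z.size):
        step = np.zeros(z.size)
        step[i] = eps
        diag[i] = (f(z + step) - 2.0 * f(z) + f(z - step)) / eps ** 2
    return diag

# Closed form: the Hessian is -precision, independent of z.
expected = -np.diag(precision)
numeric = diag_hessian_fd(log_density, np.ones(3))

# The same diagonal repeats for each of the 3 x 5 sample positions.
expected_full = np.broadcast_to(expected, (3, 5, 3))
```

Because the log-density is quadratic, the result does not depend on the point at which the state is evaluated.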
Args:
  xs: Tensor or a Python list of Tensors of real-like dtypes and shapes
    sample_shape + event_shape_i, where event_shape_i can be different
    for different tensors.
  ys: Tensor or a Python list of Tensors of the same dtype as xs. Must
    broadcast with the shape of xs. Can be omitted if fn is provided.
  sample_shape: A common sample_shape of the input tensors of xs. If not
    provided, it is assumed to be [1], which may result in slow performance
    of diag_jacobian.
  fn: Python callable that takes xs as an argument (or *xs, if it is a
    list) and returns ys. May be omitted if ys is provided and eager
    execution is disabled.
  parallel_iterations: int that specifies the allowed number of coordinates
    of the input tensor xs for which the partial derivatives dys_i/dxs_i
    can be computed in parallel.
  name: Python str name prefixed to Ops created by this function.
    Default value: None (i.e., "diag_jacobian").
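The shape and calling conventions described in the arguments can be sketched with plain NumPy arrays. This is only an illustration under stated assumptions: the element-wise `fn` below is hypothetical, and no library call is made. It shows that each tensor in xs has shape sample_shape + event_shape_i, that a list-valued xs is unpacked as fn(*xs), and what the diagonal Jacobian looks like in closed form for this simple case.

```python
import numpy as np

sample_shape = [3, 5]
# Two state parts sharing `sample_shape`, with event shapes [2] and [1].
xs = [np.ones(sample_shape + [2]), np.ones(sample_shape + [1])]

def fn(x, y):
    """Hypothetical element-wise function, so ys_i has the shape of xs_i."""
    return [x ** 2, 3.0 * y]

# A list-valued `xs` is unpacked into positional arguments.
ys = fn(*xs)

# For this element-wise `fn`, the diagonal Jacobian is known in closed form:
# d(x**2)/dx = 2 * x and d(3 * y)/dy = 3.
expected_diag = [2.0 * xs[0], 3.0 * np.ones_like(xs[1])]
```

Each entry of `expected_diag` has the same shape as the corresponding entry of `xs`, which matches the shape contract stated for the returned Jacobian diagonal.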
Returns:
  ys: a list, which coincides with the input ys, when provided. If the
    input ys is None, fn(*xs) is computed and returned as a list.
  jacobians_diag_res: a Tensor or a Python list of Tensors of the same
    dtypes and shapes as the input xs. This is the diagonal of the Jacobian
    of ys with respect to xs.
Raises:
  ValueError: if the lists xs and ys have different lengths, if both ys and
    fn are None, or if fn is None in eager execution mode.
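The error conditions listed above can be summarized as a small validation routine. The sketch below is a hypothetical stand-alone helper (`check_diag_jacobian_args` and its `eager` flag are assumptions for illustration), not the library's actual implementation or its error messages.

```python
def check_diag_jacobian_args(xs, ys=None, fn=None, eager=False):
    """Raises ValueError for the argument combinations documented as invalid."""
    xs = list(xs) if isinstance(xs, (list, tuple)) else [xs]
    if eager and fn is None:
        # In eager mode the function itself is needed to differentiate.
        raise ValueError('`fn` can not be `None` in eager execution mode')
    if ys is None and fn is None:
        raise ValueError('Both `ys` and `fn` can not be `None`')
    if ys is not None:
        ys = list(ys) if isinstance(ys, (list, tuple)) else [ys]
        if len(xs) != len(ys):
            raise ValueError('`xs` and `ys` must have the same length')

def raises_value_error(**kwargs):
    """Returns True if the check rejects the given arguments."""
    try:
        check_diag_jacobian_args(**kwargs)
    except ValueError:
        return True
    return False
```

For example, `raises_value_error(xs=[1, 2], ys=[3])` reports a length mismatch, while `raises_value_error(xs=[1, 2], ys=[3, 4])` passes in graph mode.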