Computes the diagonal of the Jacobian matrix of ys = fn(xs) with respect to xs.

If ys is a tensor or a list of tensors of the form (ys_1, .., ys_n) and xs is of the form (xs_1, .., xs_n), the function diag_jacobian computes the diagonal of the Jacobian matrix, i.e., the partial derivatives (dys_1/dxs_1, .., dys_n/dxs_n). For definition details, see

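To make the definition concrete, here is a small plain-Python sketch that needs no TensorFlow. The helper diag_jacobian_fd is hypothetical and uses central finite differences instead of automatic differentiation. It treats each xs_i as a flat list and estimates only the entries dys_i[k]/dxs_i[k], ignoring cross terms such as dys_1/dxs_2.

```python
import math


def diag_jacobian_fd(fn, xs, h=1e-5):
    """Hypothetical helper: finite-difference estimate of the Jacobian diagonal.

    xs is a list of flat Python lists (one per input); fn(*xs) returns a list
    of flat lists with matching lengths. Entry [i][k] of the result
    approximates d ys_i[k] / d xs_i[k].
    """
    result = []
    for i, x in enumerate(xs):
        diag_i = []
        for k in range(len(x)):
            # Perturb only coordinate k of the i-th input, in both directions.
            plus = [list(v) for v in xs]
            minus = [list(v) for v in xs]
            plus[i][k] += h
            minus[i][k] -= h
            ys_plus = fn(*plus)
            ys_minus = fn(*minus)
            # Keep only the matching output coordinate ys_i[k].
            diag_i.append((ys_plus[i][k] - ys_minus[i][k]) / (2 * h))
        result.append(diag_i)
    return result


def example_fn(x, y):
    # ys_1 also depends on y, but d ys_1 / d y is off the diagonal and ignored.
    ys_1 = [xk ** 2 + y[0] for xk in x]
    ys_2 = [math.sin(yk) for yk in y]
    return [ys_1, ys_2]


diag = diag_jacobian_fd(example_fn, [[1.0, 2.0], [0.5]])
print(diag)  # approximately [[2.0, 4.0], [0.8776]]
```

Finite differences are only a reference for the definition; diag_jacobian itself relies on TensorFlow gradients.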

Diagonal Hessian of the log-density of a 3D Gaussian distribution

In this example we compute the diagonal of the Hessian of the log-density of a 3D Gaussian distribution by applying diag_jacobian to the gradients of that log-density.

import tensorflow as tf
import tensorflow_probability as tfp
import numpy as np

tfd = tfp.distributions
diag_jacobian = tfp.math.diag_jacobian

dtype = np.float32
with tf.Session(graph=tf.Graph()) as sess:
  true_mean = dtype([0, 0, 0])
  true_cov = dtype([[1, 0.25, 0.25], [0.25, 2, 0.25], [0.25, 0.25, 3]])
  chol = tf.linalg.cholesky(true_cov)
  target = tfd.MultivariateNormalTriL(loc=true_mean, scale_tril=chol)

  # Assume that the state is passed as a list of tensors `x` and `y`.
  # Then the target function is defined as follows:
  def target_fn(x, y):
    # Concatenate the input tensors along the last (event) axis.
    z = tf.concat([x, y], axis=-1) - true_mean
    return target.log_prob(z)

  sample_shape = [3, 5]
  state = [tf.ones(sample_shape + [2], dtype=dtype),
           tf.ones(sample_shape + [1], dtype=dtype)]
  fn_val = target_fn(*state)
  # This example runs in graph mode (inside a `tf.Session`), so the
  # gradients are built with `tf.gradients`.
  grads = tf.gradients(fn_val, state)

  # Passing the `sample_shape` of `state` is optional, but it affects the
  # computational speed of `diag_jacobian`.
  _, diag_jacobian_shape_passed = diag_jacobian(
      xs=state, ys=grads, sample_shape=tf.shape(fn_val))
  _, diag_jacobian_shape_none = diag_jacobian(
      xs=state, ys=grads)

  diag_jacobian_shape_passed_ = sess.run(diag_jacobian_shape_passed)
  diag_jacobian_shape_none_ = sess.run(diag_jacobian_shape_none)

print('hessian computed through `diag_jacobian`, sample_shape passed: ',
      np.concatenate(diag_jacobian_shape_passed_, -1))
print('hessian computed through `diag_jacobian`, sample_shape skipped: ',
      np.concatenate(diag_jacobian_shape_none_, -1))
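The printed values can be checked by hand. The log-density of a Gaussian is quadratic in z, so its Hessian is the constant matrix -inv(true_cov), and every sample's concatenated diagonal should be close to -diag(inv(true_cov)). Below is a plain-Python sketch of that reference value, using the cofactor formula for a 3x3 inverse; inverse_diag_3x3 is a hypothetical helper, not part of TFP.

```python
def inverse_diag_3x3(m):
    """Hypothetical helper: diagonal of inv(m) for a 3x3 matrix, via cofactors."""
    det = (m[0][0] * (m[1][1] * m[2][2] - m[1][2] * m[2][1])
           - m[0][1] * (m[1][0] * m[2][2] - m[1][2] * m[2][0])
           + m[0][2] * (m[1][0] * m[2][1] - m[1][1] * m[2][0]))
    # Diagonal cofactors C_00, C_11, C_22; inv[i][i] = C_ii / det.
    c00 = m[1][1] * m[2][2] - m[1][2] * m[2][1]
    c11 = m[0][0] * m[2][2] - m[0][2] * m[2][0]
    c22 = m[0][0] * m[1][1] - m[0][1] * m[1][0]
    return [c00 / det, c11 / det, c22 / det]


true_cov = [[1, 0.25, 0.25], [0.25, 2, 0.25], [0.25, 0.25, 3]]
# Hessian of log N(z; mean, cov) is -inv(cov), independent of z.
expected_hessian_diag = [-v for v in inverse_diag_3x3(true_cov)]
print(expected_hessian_diag)
```

This gives approximately [-1.0497, -0.5193, -0.3425], which should appear in each of the 3 x 5 sample rows of both printed arrays.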

Args:
  • xs: Tensor or a python list of Tensors of real-like dtypes and shapes sample_shape + event_shape_i, where event_shape_i can be different for different tensors.
  • ys: Tensor or a python list of Tensors of the same dtype as xs. Must broadcast with the shape of xs. Can be omitted if fn is provided.
  • sample_shape: A common sample_shape of the input tensors of xs. If not provided, it is assumed to be [1], which may make diag_jacobian slow.
  • fn: Python callable that takes xs as an argument (or *xs, if it is a list) and returns ys. Can be omitted if ys is provided and eager execution is disabled (i.e., in graph mode).
  • parallel_iterations: int that specifies the allowed number of coordinates of the input tensor xs, for which the partial derivatives dys_i/dxs_i can be computed in parallel.
  • name: Python str name prefixed to Ops created by this function. Default value: None (i.e., "diag_jacobian").
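One plausible reason a known sample_shape helps is sketched below; this is an assumption for illustration, not a description of how TFP implements diag_jacobian. If samples are independent, a perturbation of event coordinate j can be applied to every sample at once, so the number of function evaluations grows with the event size instead of with the total number of coordinates. The helper batched_diag_fd is hypothetical and uses finite differences.

```python
def batched_diag_fd(fn, x, h=1e-5):
    """Hypothetical helper: diagonal derivatives for a batch of independent samples.

    x is a list of samples (each a list of event coordinates); fn maps such a
    batch to outputs of the same structure, treating samples independently.
    Coordinate j is perturbed in all samples at once, so fn is called
    2 * event_size times rather than 2 * num_samples * event_size times.
    """
    num_samples, event_size = len(x), len(x[0])
    calls = 0
    diag = [[0.0] * event_size for _ in range(num_samples)]
    for j in range(event_size):
        plus = [list(s) for s in x]
        minus = [list(s) for s in x]
        for s in range(num_samples):
            plus[s][j] += h
            minus[s][j] -= h
        y_plus, y_minus = fn(plus), fn(minus)
        calls += 2
        for s in range(num_samples):
            diag[s][j] = (y_plus[s][j] - y_minus[s][j]) / (2 * h)
    return diag, calls


def cube(batch):
    # Each output depends only on its own sample: y = x ** 3 elementwise.
    return [[v ** 3 for v in sample] for sample in batch]


samples = [[1.0, 2.0], [3.0, 0.5], [-1.0, 1.5]]
diag, calls = batched_diag_fd(cube, samples)
# diag[s][j] is approximately 3 * samples[s][j] ** 2, and calls == 4.
```

With three samples of two coordinates each, a per-coordinate loop would need 12 evaluations, while the batched version needs 4.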

Returns:
  • ys: A list that coincides with the input ys, if provided. If the input ys is None, fn(*xs) is computed and returned as a list.
  • jacobians_diag_res: a Tensor or a Python list of Tensors of the same dtypes and shapes as the input xs. This is the diagonal of the Jacobian of ys wrt xs.

Raises:
  • ValueError: if the lists xs and ys have different lengths, if both ys and fn are None, or if fn is None in eager execution mode.
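The documented behavior of the returned ys and the ValueError cases can be summarized in a hypothetical plain-Python helper. check_diag_jacobian_args is not part of TFP, and its eager flag is a stand-in for whether eager execution is enabled; the sketch only mirrors the contract described above.

```python
def check_diag_jacobian_args(xs, ys=None, fn=None, eager=False):
    """Hypothetical mirror of the documented contract; not TFP's actual code.

    Returns ys as a list, computing fn(*xs) when ys is None, and raises
    ValueError in the three documented cases.
    """
    xs = list(xs) if isinstance(xs, (list, tuple)) else [xs]
    # fn is required when ys is missing, and always in eager mode.
    if fn is None and (ys is None or eager):
        raise ValueError("`fn` must be provided if `ys` is None or in eager mode.")
    if ys is None:
        ys = fn(*xs)
    ys = list(ys) if isinstance(ys, (list, tuple)) else [ys]
    if len(xs) != len(ys):
        raise ValueError("`xs` and `ys` must have the same length.")
    return ys


print(check_diag_jacobian_args([1.0, 2.0], fn=lambda a, b: [a * b, a + b]))
# [2.0, 3.0]
```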