tfp.substrates.jax.math.value_and_gradient

Computes f(*args) and its gradients with respect to *args.
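The sketch below illustrates the idea in plain JAX: evaluate f once and return its value together with the gradient for every positional argument. It calls jax.value_and_grad, not TFP. The helper name value_and_gradient_sketch is hypothetical and is not the TFP implementation. With TFP installed, the analogous call would presumably be tfp.substrates.jax.math.value_and_gradient(f, x, y). The sketch requires f to return a scalar, because jax.value_and_grad imposes that restriction.

```python
import jax
import jax.numpy as jnp


def value_and_gradient_sketch(f, *args):
    """Return f(*args) and the gradient of f with respect to each positional arg.

    Hypothetical stand-in built on jax.value_and_grad; not the TFP source.
    f must return a scalar for jax.value_and_grad to apply.
    """
    # Differentiate with respect to every positional argument at once.
    argnums = tuple(range(len(args)))
    return jax.value_and_grad(f, argnums=argnums)(*args)


def f(x, y):
    # f(x, y) = x^2 * y, so df/dx = 2xy and df/dy = x^2.
    return x ** 2 * y


value, (dx, dy) = value_and_gradient_sketch(f, jnp.float32(3.0), jnp.float32(2.0))
print(float(value), float(dx), float(dy))  # 18.0 12.0 9.0
```

Because argnums is a tuple, the gradients come back as a tuple in the same order as the arguments.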