
tfp.experimental.distributions.marginal_fns.tfp_custom_gradient.custom_gradient


Decorates a function and adds custom derivatives.

TensorFlow only supports vector-Jacobian products (VJPs), so under TF we decorate with tf.custom_gradient, which uses vjp_fwd and vjp_bwd.

JAX supports custom rules for either Jacobian-vector products (JVPs) or VJPs. If a custom JVP is provided, JAX can transpose it to derive a VJP rule. Therefore we prefer jvp_fn when it is given, and fall back to vjp_fwd and vjp_bwd otherwise.
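A minimal sketch of the JVP path, written against JAX's public jax.custom_jvp directly rather than through this TFP wrapper. The function f, its non-differentiable argument scale, and the input values are hypothetical. The rule f_jvp has the (nondiff_args, primals, tangents) => (primal_out, tangent_out) shape described for jvp_fn, and only a JVP is defined; jax.grad still works because JAX transposes the tangent computation, which is linear in x_dot.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical example: f(scale, x) = scale * sin(x). `scale` (index 0) is a
# plain Python float marked non-differentiable via nondiff_argnums.
@partial(jax.custom_jvp, nondiff_argnums=(0,))
def f(scale, x):
    return scale * jnp.sin(x)


# Custom JVP rule: (nondiff args..., primals, tangents) -> (primal_out, tangent_out).
@f.defjvp
def f_jvp(scale, primals, tangents):
    (x,) = primals
    (x_dot,) = tangents
    primal_out = scale * jnp.sin(x)
    # The tangent output is linear in x_dot, so JAX can transpose it.
    tangent_out = scale * jnp.cos(x) * x_dot
    return primal_out, tangent_out


# Forward mode calls the custom rule directly.
y, y_dot = jax.jvp(lambda x: f(2.0, x), (0.5,), (1.0,))

# Reverse mode: no VJP was defined, so JAX derives one by transposing f_jvp.
dfdx = jax.grad(f, argnums=1)(2.0, 0.0)
print(float(dfdx))  # 2.0
```

Defining only the JVP keeps both forward and reverse mode available, which is why the JVP is preferred; defining a custom VJP instead would leave forward-mode differentiation unsupported for that function.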

Args:
  vjp_fwd: A function (args) => (output, auxiliaries).
  vjp_bwd: A function (auxiliaries, output_gradient) => nondiff_args_gradients, that is, gradients with respect to the differentiable arguments only. None gradients are inserted into the correct positions for nondiff_argnums.
  jvp_fn: A function (nondiff_args, primals, tangents) => (primal_out, tangent_out).
  nondiff_argnums: Tuple of argument indices that are not differentiable.
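The None-insertion convention for vjp_bwd can be sketched in plain Python, with no autodiff framework. Everything here is hypothetical and only for illustration: the function f(power, x, y) = x**power * y, the helpers pad_nondiff_grads and vjp, and the argument values. Because power (index 0) is non-differentiable, vjp_bwd returns gradients for x and y only, and the padding step restores a full-length tuple with None at index 0.

```python
def pad_nondiff_grads(grads, nondiff_argnums):
    """Hypothetical helper: insert None at each non-differentiable position."""
    result = tuple(grads)
    # Ascending order keeps each index valid as earlier Nones shift later ones.
    for i in sorted(nondiff_argnums):
        result = result[:i] + (None,) + result[i:]
    return result


# Hypothetical f(power, x, y) = x**power * y with `power` non-differentiable.
def vjp_fwd(power, x, y):
    output = x ** power * y
    auxiliaries = (power, x, y)  # residuals saved for the backward pass
    return output, auxiliaries


def vjp_bwd(auxiliaries, output_gradient):
    power, x, y = auxiliaries
    # Gradients for the differentiable arguments only (x, then y).
    grad_x = output_gradient * power * x ** (power - 1) * y
    grad_y = output_gradient * x ** power
    return grad_x, grad_y


def vjp(args, output_gradient, nondiff_argnums):
    """Hypothetical driver: forward pass, backward pass, then None padding."""
    output, auxiliaries = vjp_fwd(*args)
    grads = vjp_bwd(auxiliaries, output_gradient)
    return output, pad_nondiff_grads(grads, nondiff_argnums)


out, grads = vjp((3, 2.0, 5.0), 1.0, nondiff_argnums=(0,))
print(out, grads)  # 40.0 (None, 60.0, 8.0)
```

With several non-differentiable positions, for example nondiff_argnums=(0, 2) and two gradients (g1, g3), the padded result is (None, g1, None, g3).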

Returns:
  A decorator to be applied to a function f(*args) => output.