Embeds a custom gradient into a `Tensor`.

``````
tfp.experimental.substrates.jax.math.custom_gradient(
    fx,
    gx,
    x,
    fx_gx_manually_stopped=False,
    name=None
)
``````

This function works by a clever application of `stop_gradient`. Specifically, observe that:

``````
h(x) = stop_gradient(f(x)) + stop_gradient(g(x)) * (x - stop_gradient(x))
``````

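is such that `h(x) == stop_gradient(f(x))` and `grad[h(x), x] == stop_gradient(g(x))`. The factor `x - stop_gradient(x)` is always zero in value, yet its derivative with respect to `x` is one, so the second term leaves the value unchanged and contributes exactly `stop_gradient(g(x))` to the gradient.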
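The identity can be checked directly in JAX. The following is a minimal sketch, not the library's implementation: `f`, `g`, and `h` are hypothetical names, and `g` is deliberately chosen not to be the true derivative of `f` so that the override is visible.

```python
import jax
import jax.numpy as jnp
from jax.lax import stop_gradient

def f(x):
    # Hypothetical function whose value we want to keep.
    return jnp.sin(x)

def g(x):
    # Hypothetical custom gradient; intentionally not the derivative of f.
    return 2.0 * x

def h(x):
    # Value comes from f, gradient comes from g.
    return stop_gradient(f(x)) + stop_gradient(g(x)) * (x - stop_gradient(x))

x0 = jnp.float32(3.0)
h_value = h(x0)            # equals f(3.0) = sin(3.0)
h_grad = jax.grad(h)(x0)   # equals g(3.0) = 6.0, not cos(3.0)
```

The value of `h` matches `f`, while `jax.grad` reports `g(x0)` instead of the derivative of `f`.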

In addition to scalar-domain/scalar-range functions, this function also supports tensor-domain/scalar-range functions.
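For a tensor-domain/scalar-range function, the same trick can be sketched by summing the elementwise product over `x`. The reduction with `jnp.sum` is an assumption about how the scalar formula extends to tensors, and the choices of `fx` and `gx` below are purely illustrative.

```python
import jax
import jax.numpy as jnp
from jax.lax import stop_gradient

def h(x):
    fx = jnp.sum(x ** 2)   # scalar-valued function of a vector
    gx = 3.0 * x           # hypothetical custom gradient, same shape as x
    # Each summand is zero in value but has derivative gx[i] w.r.t. x[i].
    zeros_like_x = x - stop_gradient(x)
    return stop_gradient(fx) + jnp.sum(stop_gradient(gx) * zeros_like_x)

x0 = jnp.array([1.0, 2.0, 3.0])
value = h(x0)              # 1 + 4 + 9 = 14
grad = jax.grad(h)(x0)     # 3 * x0, rather than the true gradient 2 * x0
```

Note that `gx` must have the same shape as `x`, since the gradient of a scalar with respect to `x` has the shape of `x`.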

Suppose `h(x) = htilde(x, y)`, i.e., the function also depends on a second input `y`. Note that `dh/dx = stop_gradient(g(x))` but `dh/dy = None`. This is because a `Tensor` cannot have only a portion of its gradient stopped. To circumvent this issue, manually apply `stop_gradient` to the relevant portions of `f` and `g`, and pass `fx_gx_manually_stopped=True`. For an example, see the unit test `test_works_correctly_fx_gx_manually_stopped`.
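One possible reading of "manually stopping the relevant portions" is sketched below in plain JAX. This is not taken from the unit test: `f`, `g`, `h_fully_stopped`, and `h_manually_stopped` are hypothetical names, and the approach (stopping only `x` before evaluating `f` and `g`) is an assumption. Also note that `jax.grad` reports a zero gradient where the text above says `None`.

```python
import jax
import jax.numpy as jnp
from jax.lax import stop_gradient

def f(x, y):
    return x * y ** 2

def g(x, y):
    # Hypothetical custom gradient with respect to x.
    return 5.0 * y

def h_fully_stopped(x, y):
    # Stopping fx and gx wholesale also blocks every gradient path to y.
    fx, gx = stop_gradient(f(x, y)), stop_gradient(g(x, y))
    return fx + gx * (x - stop_gradient(x))

def h_manually_stopped(x, y):
    # Stop only x before evaluating f and g, so y stays differentiable.
    x_stopped = stop_gradient(x)
    fx, gx = f(x_stopped, y), g(x_stopped, y)
    return fx + gx * (x - x_stopped)

x0, y0 = jnp.float32(2.0), jnp.float32(3.0)
dx_full, dy_full = jax.grad(h_fully_stopped, argnums=(0, 1))(x0, y0)
dx_manual, dy_manual = jax.grad(h_manually_stopped, argnums=(0, 1))(x0, y0)
```

In both variants the gradient with respect to `x` is `g(x0, y0)`. Only the manually stopped variant keeps the ordinary gradient of `f` with respect to `y`.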

#### Args:

• `fx`: `Tensor`. Output of function evaluated at `x`.
• `gx`: `Tensor` or list of `Tensor`s. Gradient of function at (each) `x`.
• `x`: `Tensor` or list of `Tensor`s. Argument(s) at which `f` is evaluated.
• `fx_gx_manually_stopped`: Python `bool` indicating that `fx` and `gx` already have `stop_gradient` applied manually. Default value: `False`.
• `name`: Python `str` name prefixed to Ops created by this function.

#### Returns:

• `fx`: Floating-type `Tensor` equal to `f(x)` but which has gradient `stop_gradient(g(x))`.
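To show how the arguments fit together, here is a hedged pure-JAX stand-in with a similar signature. `custom_gradient_sketch` is a hypothetical helper, not the library function: it assumes a scalar `fx`, omits `name` and any shape checks, and only illustrates the documented roles of `fx`, `gx`, `x`, and `fx_gx_manually_stopped`.

```python
import jax
import jax.numpy as jnp
from jax.lax import stop_gradient

def custom_gradient_sketch(fx, gx, x, fx_gx_manually_stopped=False):
    """Hypothetical stand-in mirroring the documented arguments (scalar fx)."""
    # Accept a single array or a list of arrays for x and gx.
    xs = list(x) if isinstance(x, (list, tuple)) else [x]
    gxs = list(gx) if isinstance(gx, (list, tuple)) else [gx]
    # Skip the automatic stop_gradient when the caller already applied it.
    maybe_stop = (lambda t: t) if fx_gx_manually_stopped else stop_gradient
    # Each term is zero in value but carries gradient gx_i w.r.t. x_i.
    override = sum(
        jnp.sum(maybe_stop(g_i) * (x_i - stop_gradient(x_i)))
        for x_i, g_i in zip(xs, gxs)
    )
    return maybe_stop(fx) + override

def loss(a, b):
    fx = jnp.sum(a) * jnp.sum(b)
    # One custom gradient per input, each with the shape of that input.
    gx = [jnp.full_like(a, 10.0), jnp.full_like(b, -1.0)]
    return custom_gradient_sketch(fx, gx, [a, b])

a0 = jnp.array([1.0, 2.0])
b0 = jnp.array([3.0])
value = loss(a0, b0)
grad_a, grad_b = jax.grad(loss, argnums=(0, 1))(a0, b0)
```

Here `value` equals the original `fx`, while `grad_a` and `grad_b` are the supplied entries of `gx` rather than the true partial derivatives.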