Decorator to define a function with a custom gradient.

This decorator allows fine-grained control over the gradients of a sequence of operations. This can be useful for several reasons, including providing a more efficient or numerically stable gradient for a sequence of operations.

For example, consider the following function that commonly occurs in the computation of cross entropy and log likelihoods:

def log1pexp(x):
  return tf.math.log(1 + tf.exp(x))

Due to numerical instability (tf.exp(100.) overflows to infinity in float32), the gradient of this function evaluated at x=100 is NaN. For example:

x = tf.constant(100.)
y = log1pexp(x)
dy_dx = tf.gradients(y, x) # Will be NaN when evaluated.
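tf.gradients is only supported in graph mode, so under eager execution (the TensorFlow 2 default) the same check has to be written differently. The following is a minimal sketch that assumes eager execution and uses tf.GradientTape instead; it is an illustration, not part of the original example.

```python
import tensorflow as tf

def log1pexp(x):
    return tf.math.log(1 + tf.exp(x))

x = tf.constant(100.)
with tf.GradientTape() as tape:
    tape.watch(x)  # constants are not watched automatically
    y = log1pexp(x)
dy_dx = tape.gradient(y, x)

# exp(100.) overflows to inf in float32, and the backward pass
# multiplies 0 by inf, which yields NaN.
print(bool(tf.math.is_nan(dy_dx)))  # True
```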

The gradient expression can be analytically simplified to provide numerical stability:

@tf.custom_gradient
def log1pexp(x):
  e = tf.exp(x)
  def grad(upstream):
    # Analytically simplified form of e / (1 + e).
    return upstream * (1 - 1 / (1 + e))
  return tf.math.log(1 + e), grad

With this definition, the gradient dy_dx at x = 100 will be correctly evaluated as 1.0.
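As a quick check of this claim, the sketch below evaluates the custom gradient at x = 100 under eager execution. It assumes tf.GradientTape is used in place of tf.gradients; the function body repeats the definition above with the @tf.custom_gradient decorator applied.

```python
import tensorflow as tf

@tf.custom_gradient
def log1pexp(x):
    e = tf.exp(x)
    def grad(upstream):
        # 1 - 1 / (1 + e) stays finite even when e overflows to inf.
        return upstream * (1 - 1 / (1 + e))
    return tf.math.log(1 + e), grad

x = tf.constant(100.)
with tf.GradientTape() as tape:
    tape.watch(x)
    y = log1pexp(x)
dy_dx = tape.gradient(y, x)
print(float(dy_dx))  # 1.0
```

Note that only the gradient is stabilized here: the forward value y still overflows at x = 100, because the forward expression is unchanged.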

The variable upstream is the upstream gradient, i.e. the gradient of the final result with respect to this function's output, accumulated from all the layers or functions that consume that output. In the above example nothing consumes y, therefore upstream = dy/dy = 1.0.

Assume that x_i is log1pexp in the forward pass x_1 = x_1(x_0), x_2 = x_2(x_1), ..., x_i = x_i(x_i-1), ..., x_n = x_n(x_n-1). By the chain rule, dx_n/dx_0 = dx_n/dx_n-1 * dx_n-1/dx_n-2 * ... * dx_i/dx_i-1 * ... * dx_1/dx_0.

In this case the gradient of our current function is dx_i/dx_i-1 = (1 - 1 / (1 + e)). The upstream gradient upstream is dx_n/dx_n-1 * dx_n-1/dx_n-2 * ... * dx_i+1/dx_i. The upstream gradient multiplied by the current gradient is then passed downstream.
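To make the role of upstream concrete, here is a small sketch in which log1pexp is followed by one more operation, a multiplication by 3. The list seen_upstream is a hypothetical helper added only to inspect the value passed to grad; recording it this way assumes eager execution.

```python
import tensorflow as tf

seen_upstream = []  # hypothetical helper, used only for inspection

@tf.custom_gradient
def log1pexp(x):
    e = tf.exp(x)
    def grad(upstream):
        seen_upstream.append(upstream)  # record what arrives from later ops
        return upstream * (1 - 1 / (1 + e))
    return tf.math.log(1 + e), grad

x = tf.constant(0.)
with tf.GradientTape() as tape:
    tape.watch(x)
    y = log1pexp(x)
    loss = 3.0 * y  # the operation that consumes y
dloss_dx = tape.gradient(loss, x)

print(float(seen_upstream[0]))  # 3.0, i.e. dloss/dy
print(float(dloss_dx))          # 1.5, i.e. 3.0 * (1 - 1 / (1 + 1))
```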

If the function takes multiple variables as input, the grad function must return the same number of gradients, one per input and in the same order. We take the function z = x * y as an example.

>>> @tf.custom_gradient
... def bar(x, y):
...   def grad(upstream):
...     dz_dx = y
...     dz_dy = x
...     return upstream * dz_dx, upstream * dz_dy
...   z = x * y
...   return z, grad
>>> x = tf.constant(2.0, dtype=tf.float32)
>>> y = tf.constant(3.0, dtype=tf.float32)
>>> with tf.GradientTape(persistent=True) as tape:
...   tape.watch(x)  # constants are not watched automatically
...   tape.watch(y)
...   z = bar(x, y)
>>> z
<tf.Tensor: shape=(), dtype=float32, numpy=6.0>
>>> tape.gradient(z, x)
<tf.Tensor: shape=(), dtype=float32, numpy=3.0>
>>> tape.gradient(z, y)
<tf.Tensor: shape=(), dtype=float32, numpy=2.0>

Nesting custom gradients can lead to unintuitive results. The default behavior does not correspond to n-th order derivatives. For example:

@tf.custom_gradient
def op(x):
  y = op1(x)
  @tf.custom_gradient
  def grad_fn(dy):
    gdy = op2(x, y, dy)
    def grad_grad_fn(ddy):  # Not the 2nd order gradient of op w.r.t. x.
      return op3(x, y, dy, ddy)
    return gdy, grad_grad_fn
  return y, grad_fn

The function grad_grad_fn will be calculating the first order gradient of grad_fn with respect to dy, which is used to generate forward-mode gradient graphs from backward-mode gradient graphs, but is not the same as the second order gradient of op with respect to x.

Instead, wrap nested @tf.custom_gradient functions in another function: