tf.recompute_grad

Defines a function as a recompute-checkpoint for the tape auto-diff.

Tape checkpointing is a technique to reduce the memory consumption of the auto-diff tape:

  • Without tape checkpointing, operations and intermediate values are recorded to the tape for use in the backward pass.

  • With tape checkpointing, only the function call and its inputs are recorded. During backpropagation the recompute_grad custom gradient (tf.custom_gradient) recomputes the function under a localized Tape object. Recomputing the function during backpropagation performs redundant calculation, but reduces the overall memory usage of the tape; a conceptual sketch follows this list.
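As a rough mental model, the wrapper can be sketched on top of tf.custom_gradient: run the function normally in the forward pass, then rerun it under a local tape when gradients are requested. This is a simplified, hypothetical sketch (the helper name naive_recompute_grad is invented here), not TensorFlow's actual implementation:

import tensorflow as tf

def naive_recompute_grad(f):
  # Hypothetical sketch only; tf.recompute_grad's real implementation is
  # more involved.
  @tf.custom_gradient
  def inner(*args):
    result = f(*args)  # Forward pass: intermediates are not kept around.

    def grad(*dys, variables=None):
      # Backward pass: rerun f under a localized tape to rebuild the
      # intermediates, then differentiate through the recomputation.
      with tf.GradientTape() as local_tape:
        local_tape.watch(args)
        recomputed = f(*args)
      sources = list(args) + list(variables or [])
      grads = local_tape.gradient(
          recomputed, sources, output_gradients=list(dys))
      if variables is None:
        return grads[:len(args)]
      return grads[:len(args)], grads[len(args):]

    return result, grad

  return inner

The example below shows the observable effect: with the wrapper returned by tf.recompute_grad, 'running' prints again during tape.gradient because the function is re-executed.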

>>> y = tf.Variable(1.0)
>>> def my_function(x):
...   tf.print('running')
...   z = x*y
...   return z
>>> my_function_recompute = tf.recompute_grad(my_function)
>>> with tf.GradientTape() as tape:
...   r = tf.constant(1.0)
...   for i in range(4):
...     r = my_function_recompute(r)
running
running
running
running
>>> grad = tape.gradient(r, [y])
running
running
running
running

Without recompute_grad, the tape records all intermediate steps, and no recomputation is performed.

>>> with tf.GradientTape() as tape:
...   r = tf.constant(1.0)
...   for i in range(4):
...     r = my_function(r)
running
running
running
running
>>> grad = tape.gradient(r, [y])
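Checkpointing changes where the work happens, not the result. A quick sanity check, reusing my_function and my_function_recompute from above (a brief sketch, not part of the API docs):

with tf.GradientTape() as tape_a:
  ra = tf.constant(1.0)
  for i in range(4):
    ra = my_function_recompute(ra)
ga = tape_a.gradient(ra, [y])[0]

with tf.GradientTape() as tape_b:
  rb = tf.constant(1.0)
  for i in range(4):
    rb = my_function(rb)
gb = tape_b.gradient(rb, [y])[0]

# Both gradients are d(y**4)/dy = 4*y**3 = 4.0 at y = 1.0; checkpointing
# trades extra compute for lower tape memory without changing values.
tf.debugging.assert_near(ga, gb)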

If f is a tf.keras Model or Layer object, methods and attributes such as f.variables are not available on the returned function g. Either keep a reference to f, or use g.__wrapped__ to access these variables and methods.

>>> def print_running_and_return(x):
...   tf.print("running")
...   return x
>>> model = tf.keras.Sequential([
...   tf.keras.layers.Lambda(print_running_and_return),
...   tf.keras.layers.Dense(2)
... ])
>>> model_recompute = tf.recompute_grad(model)
>>> with tf.GradientTape(persistent=True) as tape:
...   r = tf.constant([[1,2]])
...   for i in range(4):
...     r = model_recompute(r)
running
running
running
running
>>> grad = tape.gradient(r, model.variables)
running
running
running
running

Alternatively, use the __wrapped__ attribute to access the original model object.

>>> grad = tape.gradient(r, model_recompute.__wrapped__.variables)
running
running
running
running
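From here, the gradients can be applied as in any training step; a minimal follow-on sketch (the SGD optimizer and learning rate are arbitrary choices):

optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
optimizer.apply_gradients(zip(grad, model_recompute.__wrapped__.variables))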

Args:
  f: A function f(*x) that returns a Tensor or sequence of Tensor outputs.

Returns:
  A function g wrapping f that defines a custom gradient, which recomputes f on the backward pass of a gradient call.
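Because the returned g is an ordinary callable, tf.recompute_grad also works as a decorator. A brief usage sketch (the weight shape and loop depth are arbitrary):

w = tf.Variable(tf.random.normal([64, 64]))

@tf.recompute_grad
def deep_block(x):
  # Only the call and its inputs are taped; the intermediate activations
  # below are recomputed during the backward pass rather than stored.
  for _ in range(10):
    x = tf.nn.relu(tf.matmul(x, w))
  return x

with tf.GradientTape() as tape:
  out = deep_block(tf.random.normal([8, 64]))
grads = tape.gradient(out, [w])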