Defines a function as a recompute checkpoint for the tape autodiff.
tf.recompute_grad(
    f
)
Tape checkpointing is a technique to reduce the memory consumption of the autodiff tape:
Without tape checkpointing, operations and intermediate values are recorded to the tape for use in the backward pass.
With tape checkpointing, only the function call and its inputs are recorded. During backpropagation, the recompute_grad custom gradient (tf.custom_gradient) recomputes the function under a localized Tape object. Recomputing the function during backpropagation performs redundant calculation, but reduces the overall memory usage of the Tape.
import tensorflow as tf

y = tf.Variable(1.0)
def my_function(x):
    tf.print('running')
    z = x*y
    return z
my_function_recompute = tf.recompute_grad(my_function)
with tf.GradientTape() as tape:
    r = tf.constant(1.0)
    for i in range(4):
        r = my_function_recompute(r)
running
running
running
running
grad = tape.gradient(r, [y])
running
running
running
running
Without recompute_grad, the tape contains all intermediate steps, and no recomputation is performed.
with tf.GradientTape() as tape:
    r = tf.constant(1.0)
    for i in range(4):
        r = my_function(r)
running
running
running
running
grad = tape.gradient(r, [y])
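As a sanity check, the two approaches above should yield the same gradient, since recomputation only changes when intermediate values are produced, not what they are. The following is a minimal sketch under that assumption; the names `scale` and `chain_gradient`, and the variable value 2.0, are hypothetical and chosen only so the result is easy to verify by hand.

```python
import tensorflow as tf

# Illustrative value (an assumption, not from the text above): with y = 2,
# the expected gradient is easy to work out by hand.
y = tf.Variable(2.0)

def scale(x):
    return x * y

def chain_gradient(fn, steps=4):
    """Applies fn `steps` times starting from 1.0 and returns d(result)/dy."""
    with tf.GradientTape() as tape:
        r = tf.constant(1.0)
        for _ in range(steps):
            r = fn(r)
    return tape.gradient(r, y)

grad_plain = chain_gradient(scale)
grad_recompute = chain_gradient(tf.recompute_grad(scale))

# r = y**4, so dr/dy = 4 * y**3, which is 32 for y = 2.
print(float(grad_plain), float(grad_recompute))
```

Both calls compute the same quantity; the wrapped version trades extra forward computation during the backward pass for a smaller tape.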
If f is a tf.keras Model or Layer object, methods and attributes such as f.variables are not available on the returned function g.
Either keep a reference to f, or use g.__wrapped__ to access these variables and methods.
def print_running_and_return(x):
    tf.print("running")
    return x
model = tf.keras.Sequential([
    tf.keras.layers.Lambda(print_running_and_return),
    tf.keras.layers.Dense(2)
])
model_recompute = tf.recompute_grad(model)
with tf.GradientTape(persistent=True) as tape:
    r = tf.constant([[1,2]])
    for i in range(4):
        r = model_recompute(r)
running
running
running
running
grad = tape.gradient(r, model.variables)
running
running
running
running
Alternatively, use the __wrapped__ attribute to access the original model object.
grad = tape.gradient(r, model_recompute.__wrapped__.variables)
running
running
running
running
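The same pattern can be sketched with any callable object that owns variables, not only Keras models. The example below assumes a small `tf.Module` subclass named `Scale` (a hypothetical class introduced purely for illustration) and checks that `__wrapped__` hands back the original object, whose `variables` can then be passed to the tape.

```python
import tensorflow as tf

class Scale(tf.Module):
    """Hypothetical callable with a single weight, used only for illustration."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(3.0)

    def __call__(self, x):
        return x * self.w

layer = Scale()
g = tf.recompute_grad(layer)

# The wrapper keeps a handle to the original callable.
assert g.__wrapped__ is layer

with tf.GradientTape() as tape:
    out = g(tf.constant(2.0))

# Reach the variables through the wrapper rather than through `layer`.
(grad_w,) = tape.gradient(out, g.__wrapped__.variables)

# d(x * w)/dw = x, which is 2.0 for the input used here.
print(float(grad_w))
```

Keeping a direct reference to the original object, as in the `model.variables` call earlier, is equivalent; `__wrapped__` is mainly useful when only the wrapped function is in scope.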
Args  

f: function f(*x) that returns a Tensor or a sequence of Tensor outputs.

Returns  

A function g wrapping f that defines a custom gradient, which recomputes
f on the backward pass of a gradient call.
