tf.clip_by_norm

Clips tensor values to a maximum L2-norm.

Given a tensor t, and a maximum clip value clip_norm, this operation normalizes t so that its L2-norm is less than or equal to clip_norm, along the dimensions given in axes. Specifically, in the default case where all dimensions are used for calculation, if the L2-norm of t is already less than or equal to clip_norm, then t is not modified. If the L2-norm is greater than clip_norm, then this operation returns a tensor of the same type and shape as t with its values set to:

t * clip_norm / l2norm(t)

In this case, the L2-norm of the output tensor is clip_norm.
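
The rule above can be checked by hand with a small plain-Python sketch. This is only a reference for the arithmetic in the default all-dimensions case, not TensorFlow's implementation; the helper name clip_by_l2_norm is hypothetical and works on a flat list of floats rather than a tensor.

```python
import math


def clip_by_l2_norm(values, clip_norm):
    """Hypothetical reference helper: clip a flat list of floats to a maximum L2-norm."""
    l2norm = math.sqrt(sum(v * v for v in values))
    if l2norm <= clip_norm:
        # Already within the limit, so the values are returned unchanged.
        return list(values)
    # Otherwise rescale every value by clip_norm / l2norm.
    scale = clip_norm / l2norm
    return [v * scale for v in values]


# L2-norm of [3, 4] is 5, which exceeds 2.5, so every value is scaled by 0.5.
clipped = clip_by_l2_norm([3.0, 4.0], 2.5)
print(clipped)  # [1.5, 2.0]

# With a limit of 10 the input is already small enough and is left as is.
unchanged = clip_by_l2_norm([3.0, 4.0], 10.0)
print(unchanged)  # [3.0, 4.0]
```

After clipping, the L2-norm of [1.5, 2.0] is exactly 2.5, matching the statement that the output norm equals clip_norm.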

As another example, if t is a matrix and axes == [1], then each row of the output will have L2-norm less than or equal to clip_norm. If axes == [0] instead, each column of the output will be clipped.
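
To make the row and column cases concrete, here is a plain-Python sketch of the same arithmetic applied per row and per column of a nested list. The helpers clip_rows and clip_columns are hypothetical names introduced for illustration and only mirror the description above; they are not part of TensorFlow.

```python
import math


def clip_rows(matrix, clip_norm):
    """Hypothetical helper: clip each row independently (the axes == [1] case)."""
    out = []
    for row in matrix:
        norm = math.sqrt(sum(v * v for v in row))
        # Rows already within the limit keep a scale of 1.0 and stay unchanged.
        scale = clip_norm / norm if norm > clip_norm else 1.0
        out.append([v * scale for v in row])
    return out


def clip_columns(matrix, clip_norm):
    """Hypothetical helper: clip each column independently (the axes == [0] case)."""
    transposed = [list(col) for col in zip(*matrix)]
    clipped = clip_rows(transposed, clip_norm)
    return [list(row) for row in zip(*clipped)]


m = [[3.0, 4.0],
     [0.0, 1.0]]

# Row norms are 5 and 1, so only the first row is rescaled (by 0.5).
rows = clip_rows(m, 2.5)
print(rows)  # [[1.5, 2.0], [0.0, 1.0]]

# Column norms are 3 and sqrt(17), so both columns are rescaled to norm 2.5.
cols = clip_columns(m, 2.5)
```

With TensorFlow itself, the corresponding calls would pass axes=[1] or axes=[0] to tf.clip_by_norm, as described above.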

Code example:

>>> import tensorflow as tf
>>> some_nums = tf.constant([[1, 2, 3, 4, 5]], dtype=tf.float32)
>>> tf.clip_by_norm(some_nums, 2.0).numpy()
array([[0.26967996, 0.5393599 , 0.80903983, 1.0787199 , 1.3483998 ]],
      dtype=float32)

This operation is typically used to clip gradients before applying them with an optimizer. Gradients usually arrive as a collection of differently shaped tensors, one for each part of the model, so each tensor is clipped on its own. A common usage looks like this:

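# Compute the loss and gradients for one training step.
# grad, model, features, labels, norm and optimizer are assumed to be defined
# elsewhere, with grads taken with respect to model.trainable_variables.
loss_value, grads = grad(model, features, labels)

# Clip each gradient tensor independently to a maximum L2-norm of norm
grads = [tf.clip_by_norm(g, norm) for g in grads]

# Continue with training; apply_gradients expects (gradient, variable) pairs
optimizer.apply_gradients(zip(grads, model.trainable_variables))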

Args:

  t: A Tensor or IndexedSlices. This must be a floating point type.
  clip_norm: A 0-D (scalar) Tensor > 0. A maximum clipping value, also floating point. Note: If a negative clip_norm is provided, it will be treated as zero.
  axes: A 1-D (vector) Tensor of type int32 containing the dimensions to use for computing the L2-norm. If None (the default), uses all dimensions.
  name: A name for the operation (optional).

Returns:

  A clipped Tensor or IndexedSlices.

Raises:

  ValueError: If the clip_norm tensor is not a 0-D scalar tensor.
  TypeError: If dtype of the input is not a floating point or complex type.