tf.train.experimental.enable_mixed_precision_graph_rewrite

Enable mixed precision via a graph rewrite.

Aliases:

  • tf.compat.v1.train.experimental.enable_mixed_precision_graph_rewrite

tf.train.experimental.enable_mixed_precision_graph_rewrite(
    opt,
    loss_scale='dynamic'
)

Mixed precision is the use of both float16 and float32 when training a model, and is used to make the model run faster. This function uses mixed precision to speed up your model's execution time on a GPU by changing the dtype of certain operations in the graph from float32 to float16.

This function additionally wraps the passed optimizer in a LossScaleOptimizer, which applies loss scaling to prevent underflow in the float16 tensors during the backwards pass. An optimizer must be passed to this function, and the wrapped version is returned.

When this function is used, gradients should only be computed and applied with the returned optimizer, either by calling opt.minimize() or by calling opt.compute_gradients() followed by opt.apply_gradients(). Gradients should not be computed with tf.gradients or tf.GradientTape, because the returned optimizer applies loss scaling while tf.gradients and tf.GradientTape do not. If you do use tf.gradients or tf.GradientTape directly, your model may train to a lower quality.
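
For example, a minimal sketch of the intended usage with a Keras optimizer (the model, layer sizes, optimizer choice, and loss below are illustrative placeholders, not part of this API):

import tensorflow as tf

# Placeholder model and optimizer; any Keras optimizer can be wrapped.
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(16,))])
opt = tf.keras.optimizers.SGD(learning_rate=0.01)

# Enable the graph rewrite and wrap the optimizer with loss scaling.
opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt)

# Pass the returned optimizer to Keras so that gradients are computed and
# applied through it; model.fit(...) would then train the model.
model.compile(optimizer=opt, loss='mse')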

When eager execution is enabled, the mixed precision graph rewrite is only enabled within tf.functions, as outside tf.functions, there is no graph.
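
For instance, a training step traced into a graph with tf.function will have the rewrite applied. The sketch below is illustrative; the variable, loss, and input are placeholders:

import tensorflow as tf

opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(
    tf.keras.optimizers.SGD(learning_rate=0.01))
w = tf.Variable([[1.0]])

@tf.function  # the rewrite is applied to the graph traced for this function
def train_step(x):
  # Compute and apply gradients only through the returned optimizer;
  # minimize() handles the loss scaling internally.
  opt.minimize(lambda: tf.reduce_mean(tf.matmul(x, w)), var_list=[w])

train_step(tf.ones([4, 1]))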

When enabled, mixed precision is only used on Volta GPUs and above. The parts of the graph on CPUs and TPUs are untouched by the graph rewrite.

Args:

  • opt: An instance of a tf.keras.optimizers.Optimizer; this is the optimizer that will be wrapped to use loss scaling.
  • loss_scale: The loss scale to use when wrapping opt. Either a numeric fixed loss scale, the string 'dynamic' (the default, which adjusts the loss scale automatically during training), or an instance of tf.train.experimental.LossScale.

Returns:

A version of opt that will use loss scaling to prevent underflow.