Enable mixed precision via a graph rewrite.

Mixed precision is the use of both float32 and float16 data types when training a model to improve performance. This is achieved via a graph rewrite operation and a loss-scale optimizer.

Performing arithmetic operations in float16 takes advantage of specialized processing units, such as NVIDIA Tensor Cores, for much higher arithmetic throughput. However, because float16 has a smaller representable range than float32, training entirely in float16 can cause gradient underflow, in which small gradient values become zero. Performing only select arithmetic operations in float16 instead yields higher throughput and shorter training time on compatible hardware accelerators, while also reducing memory usage, typically without sacrificing model accuracy.
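
The underflow described above, and the way loss scaling works around it, can be illustrated without TensorFlow. The following is a minimal sketch using only the Python standard library. The helper `to_float16`, the gradient value, and the loss-scale value are hypothetical choices made for illustration; this is not how the graph rewrite or the loss-scale optimizer is implemented internally.

```python
import struct


def to_float16(x: float) -> float:
    """Round a Python float to the nearest float16 value (hypothetical helper)."""
    # The struct format 'e' is IEEE 754 half precision (binary16).
    return struct.unpack('<e', struct.pack('<e', x))[0]


# Hypothetical small gradient, below the smallest positive float16 value.
grad = 1e-8

# Stored directly in float16, the gradient underflows to zero.
underflowed = to_float16(grad)

# Loss scaling: multiply by a large factor before the float16 step so the
# value lands in the representable range. The factor here is illustrative.
loss_scale = 2.0 ** 15
scaled = to_float16(grad * loss_scale)

# Unscale in higher precision to recover an approximation of the gradient.
recovered = scaled / loss_scale

print(underflowed)
print(scaled > 0.0)
print(abs(recovered - grad) / grad < 1e-3)
```

In this sketch the unscaled gradient is lost entirely, while the scaled gradient survives the round trip through float16 with only a small relative error.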


For example:

import tensorflow as tf

model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(64, activation='softmax'),
])

# Wrap the optimizer to enable the graph rewrite and loss scaling.
opt = tf.keras.optimizers.SGD()
opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt)