tf.keras.mixed_precision.experimental.Policy

A dtype policy for a Keras layer.

A dtype policy determines dtype-related aspects of a layer, such as its computation and variable dtypes. Each layer has a policy. Policies can be passed to the dtype argument of layer constructors, or a global policy can be set with tf.keras.mixed_precision.experimental.set_policy. A layer will default to the global policy if no policy is passed to its constructor.
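
For example, a policy can be set globally or passed to an individual layer. A minimal sketch, assuming TensorFlow 2.x with this experimental API (the Dense layers are illustrative):

  import tensorflow as tf

  # Set the global policy. Layers constructed afterwards default to it.
  tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
  layer = tf.keras.layers.Dense(10)  # Uses the global 'mixed_float16' policy.

  # Passing a dtype to the constructor overrides the global policy.
  float32_layer = tf.keras.layers.Dense(10, dtype='float32')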

For many models, each layer's policy will have the same compute dtype and variable dtype, which will typically be float32. In this case, we refer to the singular dtype as the layer's dtype, which can be queried by the property tf.keras.layers.Layer.dtype.
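
For example, a layer whose compute and variable dtypes are both float64 (an illustrative sketch):

  import tensorflow as tf

  layer = tf.keras.layers.Dense(10, dtype='float64')
  layer.dtype  # 'float64': the layer's single dtype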

When mixed precision training is used, most layers will instead have a float16 or bfloat16 compute dtype and a float32 variable dtype, so the layer does not have a single dtype. When the variable dtype does not match the compute dtype, variables are automatically cast to the compute dtype to avoid type errors. In this case, tf.keras.layers.Layer.dtype refers to the variable dtype, not the compute dtype. See the mixed precision guide for more information on how to use mixed precision.
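
A sketch of the mixed precision case (the layer and input shape are illustrative):

  import tensorflow as tf

  policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
  layer = tf.keras.layers.Dense(10, dtype=policy)
  layer.dtype  # 'float32': with a mixed policy, Layer.dtype is the variable dtype

  # The float32 variables and the inputs are automatically cast to float16
  # for the computation, so the output is float16.
  outputs = layer(tf.ones((2, 4)))
  outputs.dtype  # tf.float16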

Certain policies also have a tf.mixed_precision.experimental.LossScale instance, which is used by tf.keras.Model to perform loss scaling. Loss scaling is a technique used with mixed precision to avoid numerical underflow in float16 gradients. Loss scaling is only done by Models in Model.fit, Model.train_on_batch, and similar methods. Layers which are not Models ignore the loss scale.
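
For example, the loss scale attached to a policy can be inspected (a sketch; the exact repr of the loss scale varies by TensorFlow version):

  import tensorflow as tf

  policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
  policy.loss_scale  # A DynamicLossScale instance by default

  policy = tf.keras.mixed_precision.experimental.Policy('float32')
  policy.loss_scale  # None: no loss scaling by default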

Policies are constructed by passing a string to the constructor, e.g. tf.keras.mixed_precision.experimental.Policy('float32'). The string determines the compute and variable dtypes. It can be one of the following (see the example after this list):

  • Any dtype name, such as 'float32' or 'float64'. Both the variable and compute dtypes will be that dtype. No loss scaling is done by default.
  • 'mixed_float16' or 'mixed_bfloat16': The compute dtype is float16 or bfloat16, while the variable dtype is float32. These policies are used for mixed precision training. With 'mixed_float16', a dynamic loss scale is used by default. 'mixed_bfloat16' does no loss scaling by default, as loss scaling is unnecessary with bfloat16.
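
For example (a sketch showing the two kinds of strings):

  import tensorflow as tf

  Policy = tf.keras.mixed_precision.experimental.Policy

  # A plain dtype name: compute and variable dtypes are the same dtype.
  Policy('float64').compute_dtype   # 'float64'
  Policy('float64').variable_dtype  # 'float64'

  # A mixed policy: low-precision compute dtype, float32 variables.
  Policy('mixed_bfloat16').compute_dtype   # 'bfloat16'
  Policy('mixed_bfloat16').variable_dtype  # 'float32'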

How to use mixed precision in a Keras model

To use mixed precision in a Keras model, the 'mixed_float16' or