A deprecated dtype policy for a Keras layer.

Inherits From: Policy

The difference between this class and the non-experimental class is that this class has a loss_scale field and the non-experimental class does not. The loss scale is only used by tf.keras.Model.compile, which automatically wraps the optimizer with a LossScaleOptimizer if the optimizer is not already a LossScaleOptimizer. For the non-experimental Policy class, Model.compile instead wraps the optimizer with a LossScaleOptimizer if Policy.name is "mixed_float16".

When deserializing objects with an experimental policy using functions like tf.keras.utils.deserialize_keras_object, the policy will be deserialized as the non-experimental tf.keras.mixed_precision.Policy, and the loss scale will silently be dropped. This is so that SavedModels that are generated with an experimental policy can be restored after the experimental policy is removed.

name A string. Can be one of the following values:

  • Any dtype name, such as 'float32' or 'float64'. Both the variable and compute dtypes will be that dtype.
  • 'mixed_float16' or 'mixed_bfloat16': The compute dtype is float16 or bfloat16, while the variable dtype is float32. With 'mixed_float16', a dynamic loss scale is used. These policies are used for mixed precision training.
loss_scale A tf.compat.v1.mixed_precision.LossScale, an int (which uses a FixedLossScale), the string "dynamic" (which uses a DynamicLossScale), or None (which uses no loss scale). Defaults to "auto". In the "auto" case: 1) if name is "mixed_float16", then use loss_scale="dynamic". 2) otherwise, do not use a loss scale. Only tf.keras.Models, not layers, use the loss scale, and it is only used during Model.fit, Model.train_on_batch, and other similar methods.
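The rules above for name and loss_scale can be sketched in plain Python. This is only an illustration of the documented behavior, not TensorFlow's implementation: the helper names resolve_dtypes and resolve_loss_scale are hypothetical, and dtypes and loss scales are represented as plain strings and ints rather than real dtype or LossScale objects.

```python
def resolve_dtypes(name):
    """Return (compute_dtype, variable_dtype) for a policy name.

    Hypothetical helper that mirrors the documented rules; dtypes are
    returned as plain strings.
    """
    if name == "mixed_float16":
        return "float16", "float32"
    if name == "mixed_bfloat16":
        return "bfloat16", "float32"
    # Any plain dtype name: compute and variable dtypes are both that dtype.
    return name, name


def resolve_loss_scale(name, loss_scale="auto"):
    """Resolve the loss_scale argument as documented for the "auto" case.

    Hypothetical helper: returns "dynamic", an int, or None instead of
    building LossScale objects.
    """
    if loss_scale == "auto":
        # "auto" picks a dynamic loss scale only for mixed_float16.
        return "dynamic" if name == "mixed_float16" else None
    # Explicit values (an int, "dynamic", or None) are passed through.
    return loss_scale


print(resolve_dtypes("mixed_float16"))        # ('float16', 'float32')
print(resolve_dtypes("float64"))              # ('float64', 'float64')
print(resolve_loss_scale("mixed_float16"))    # dynamic
print(resolve_loss_scale("mixed_bfloat16"))   # None
```

Note that 'mixed_bfloat16' resolves to no loss scale in the "auto" case, since only 'mixed_float16' is documented as using a dynamic loss scale by default.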

compute_dtype The compute dtype of this policy.

This is the dtype layers will do their computations in. Typically layers output tensors with the compute dtype as well.

Note that even if the compute dtype is float16 or bfloat16, hardware devices may not do individual adds, multiplies, and other fundamental operations in float16 or bfloat16, but instead may do some of them in float32 for numeric stability. The compute dtype is the dtype of the inputs and outputs of the TensorFlow ops that the layer executes. Internally, many TensorFlow ops will do certain internal calculations in float32 or some other device-internal intermediate format with higher precision than float16/bfloat16, to increase numeric stability.
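To see why higher-precision intermediate math matters, the following sketch emulates a dot product whose inputs and output are float16, accumulated once in float16 and once in float32. It is a pure-Python emulation that uses the standard struct module to round values to IEEE half and single precision; the helper names are hypothetical, and it does not reproduce how TensorFlow or any particular device implements its ops.

```python
import struct


def round_to(fmt, x):
    """Round a Python float to the precision of a struct format.

    'e' is IEEE half precision (float16), 'f' is single precision (float32).
    """
    return struct.unpack(fmt, struct.pack(fmt, x))[0]


def dot_float16_io(a, b, accum_fmt):
    """Dot product with float16 inputs and output.

    accum_fmt selects the precision of the running sum: 'e' emulates
    float16 accumulation, 'f' emulates float32 accumulation.
    """
    acc = 0.0
    for x, y in zip(a, b):
        # Inputs are rounded to float16, matching a float16 compute dtype.
        x16 = round_to("e", x)
        y16 = round_to("e", y)
        # The running sum is rounded to the accumulator precision each step.
        acc = round_to(accum_fmt, acc + x16 * y16)
    # The output is cast back to float16 either way.
    return round_to("e", acc)


ones = [1.0] * 4096
print(dot_float16_io(ones, ones, "e"))  # 2048.0
print(dot_float16_io(ones, ones, "f"))  # 4096.0
```

With float16 accumulation the sum stalls at 2048, because 2049 is not representable in float16 and rounds back to 2048. With float32 accumulation the sum reaches 4096, which is itself exactly representable in float16, so the float16 output is correct.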

For example, a tf.keras.layers.Dense layer, when run on a GPU with a float16 compute dtype, will pass float16 inputs to tf.linalg.matmul. However, tf.linalg.matmul will use float32 intermediate math. The performance benefit of float16 is still apparent, due to increased memory bandwidth and the fact that modern GPUs have specialized hardware for computing matmuls on float16 inputs while still keeping intermediate computations