tf.keras.mixed_precision.Policy

A dtype policy for a Keras layer.

A dtype policy determines a layer's computation and variable dtypes. Each layer has a policy. Policies can be passed to the dtype argument of layer constructors, or a global policy can be set with tf.keras.mixed_precision.set_global_policy.

Args
name: The policy name, which determines the compute and variable dtypes. Can be any dtype name, such as 'float32' or 'float64', which causes both the compute and variable dtypes to be that dtype. Can also be the string 'mixed_float16' or 'mixed_bfloat16', which causes the compute dtype to be float16 or bfloat16, respectively, and the variable dtype to be float32.
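For example, constructing a 'mixed_float16' policy directly shows the two dtypes it carries:

policy = tf.keras.mixed_precision.Policy('mixed_float16')
policy.compute_dtype
'float16'
policy.variable_dtype
'float32'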

Typically you only need to interact with dtype policies when using mixed precision, which is the use of float16 or bfloat16 for computations and float32 for variables. This is why the term mixed_precision appears in the API name. Mixed precision can be enabled by passing 'mixed_float16' or 'mixed_bfloat16' to tf.keras.mixed_precision.set_global_policy. See the mixed precision guide for more information on how to use mixed precision.

tf.keras.mixed_precision.set_global_policy('mixed_float16')
layer1 = tf.keras.layers.Dense(10)
layer1.dtype_policy  # `layer1` will automatically use mixed precision
<Policy "mixed_float16">
# Can optionally override layer to use float32 instead of mixed precision.
layer2 = tf.keras.layers.Dense(10, dtype='float32')
layer2.dtype_policy
<Policy "float32">
# Set policy back to initial float32 for future examples.
tf.keras.mixed_precision.set_global_policy('float32')

In the example above, passing dtype='float32' to the layer is equivalent to passing dtype=tf.keras.mixed_precision.Policy('float32'). In general, passing a dtype policy name to a layer is equivalent to passing the corresponding policy, so it is never necessary to explicitly construct a Policy object.
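As a quick illustration of this equivalence, the two layers below end up with identical dtype policies ('float64' is used here only to make the policy visible):

layer_a = tf.keras.layers.Dense(10, dtype='float64')
layer_b = tf.keras.layers.Dense(10, dtype=tf.keras.mixed_precision.Policy('float64'))
layer_a.dtype_policy.name
'float64'
layer_b.dtype_policy.name
'float64'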

How a layer uses its policy's compute dtype

A layer casts its inputs to its compute dtype. This causes the layer's computations and output to also be in the compute dtype. For example:

x = tf.ones((4, 4, 4, 4), dtype='float64')
# `layer`'s policy defaults to float32.
layer = tf.keras.layers.Conv2D(filters=4, kernel_size=2)
layer.compute_dtype  # Equivalent to layer.dtype_policy.compute_dtype
'float32'
# `layer` casts its inputs to its compute dtype and does computations in
# that dtype.
y = layer(x)
y.dtype
tf.float32

Note that the base tf.keras.layers.Layer class inserts the casts, so when writing your own layer subclass you do not have to insert any casts yourself.
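As a minimal sketch of this behavior (the AddOne layer name here is hypothetical), the inputs a subclass's call method receives have already been cast:

class AddOne(tf.keras.layers.Layer):
  def call(self, inputs):
    # `inputs` was already cast to self.compute_dtype by the base
    # tf.keras.layers.Layer class, so no explicit tf.cast is needed here.
    return inputs + 1.

x = tf.ones((2, 2), dtype='float64')
layer = AddOne()  # Policy defaults to float32.
layer(x).dtype
tf.float32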

Currently, only tensors in the first argument to the layer's call method are cast (although this will likely be changed in a future minor release). For example:

class MyLayer(tf.keras.layers.Layer):
  # Bug! `b` will not be cast.
  def call(self, a, b):
    return a + 1., b + 1.
a = tf.constant(1., dtype="float32")
b = tf.constant(1., dtype="float32")
layer = MyLayer(dtype="float64")
x, y = layer(a, b)
x.dtype
tf.float64
y.dtype
tf.float32

If writing your own layer with multiple inputs, you should either explicitly cast other tensors to self.compute_dtype in call or accept all tensors in the first argument as a list.
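A sketch of the first option (using a hypothetical MyFixedLayer that corrects the MyLayer above) is to cast the extra tensor by hand:

class MyFixedLayer(tf.keras.layers.Layer):
  def call(self, a, b):
    # `a` is cast automatically by the base class; `b` must be cast explicitly.
    b = tf.cast(b, self.compute_dtype)
    return a + 1., b + 1.
a = tf.constant(1., dtype="float32")
b = tf.constant(1., dtype="float32")
layer = MyFixedLayer(dtype="float64")
x, y = layer(a, b)
y.dtype
tf.float64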

The casting only occurs in TensorFlow 2. If tf.compat.v1.disable_v2_behavior() has been called, you can enable the casting behavior with tf.compat.v1.keras.layers.enable_v2_dtype_behavior().