DP subclass of tf.keras.Model.

This can be used as a differentially private replacement for tf.keras.Model. This class implements DP-SGD using the standard Gaussian mechanism.

This class also uses a faster gradient clipping algorithm if the following two conditions hold: (i) the trainable layers of the model are keys in the input layer_registry, and (ii) the loss tf.Tensor for a given batch of examples is either a scalar or a 2D tf.Tensor with exactly one column (i.e., tf.shape(loss)[1] == 1) whose i-th row corresponds to the loss of the i-th example. This clipping algorithm computes clipped gradients at the per-example level, or at the per-microbatch level when num_microbatches is not None, using the layer registry functions in layer_registry.
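To make the clipping and noising step concrete, here is a minimal NumPy sketch of the standard Gaussian mechanism described above. The helper name clip_and_noise is hypothetical and this is an illustration only, not the optimized algorithm DPModel actually runs:

```python
import numpy as np

def clip_and_noise(per_example_grads, l2_norm_clip, noise_multiplier, rng):
    """Illustrative sketch: clip each per-example gradient, sum, noise, average."""
    clipped = []
    for g in per_example_grads:
        norm = np.linalg.norm(g)
        # Scale down any gradient whose L2 norm exceeds l2_norm_clip.
        clipped.append(g * min(1.0, l2_norm_clip / max(norm, 1e-12)))
    total = np.sum(clipped, axis=0)
    # The noise standard deviation is noise_multiplier * l2_norm_clip.
    noise = rng.normal(0.0, noise_multiplier * l2_norm_clip, size=total.shape)
    return (total + noise) / len(per_example_grads)

# With noise_multiplier=0, the result is just the average of clipped gradients.
grads = [np.array([3.0, 4.0])]  # L2 norm 5.0, clipped down to norm 1.0
avg = clip_and_noise(grads, l2_norm_clip=1.0, noise_multiplier=0.0,
                     rng=np.random.default_rng(0))
# avg is approximately [0.6, 0.8]
```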

When instantiating this class, you need to supply several DP-related arguments followed by the standard arguments for Model.


# Create Model instance.
model = DPModel(l2_norm_clip=1.0, noise_multiplier=0.5, use_xla=True,
                <standard arguments>)

You should use your DPModel instance with a standard instance of tf.keras.Optimizer as the optimizer, and a standard reduced loss. You do not need to use a differentially private optimizer.

# Use a standard (non-DP) optimizer.
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)

# Use a standard reduced loss.
loss = tf.keras.losses.MeanSquaredError()

model.compile(optimizer=optimizer, loss=loss)
model.fit(train_data, train_labels, epochs=1, batch_size=32)

Args:
  l2_norm_clip: Clipping norm (max L2 norm of per-microbatch gradients).
  noise_multiplier: Ratio of the noise standard deviation to the clipping
    norm.
  num_microbatches: Number of microbatches.
  use_xla: If True, compiles train_step to XLA.
  layer_registry: A LayerRegistry instance containing functions that help
    compute gradient norms quickly. See
    tensorflow_privacy.privacy.fast_gradient_clipping.layer_registry for more
    details.
  *args: These will be passed on to the base class __init__ method.
  **kwargs: These will be passed on to the base class __init__ method.
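To make the num_microbatches argument concrete, here is a hedged NumPy sketch (the helper name microbatch_mean_grads is hypothetical) of how a batch of per-example gradients is grouped before clipping when num_microbatches is set:

```python
import numpy as np

def microbatch_mean_grads(per_example_grads, num_microbatches):
    # Split the batch into num_microbatches groups and average each group.
    # With microbatching, clipping applies to these group averages rather
    # than to individual per-example gradients.
    groups = np.array_split(np.asarray(per_example_grads), num_microbatches)
    return [g.mean(axis=0) for g in groups]

# Four per-example gradients, grouped into two microbatches of two examples.
mb = microbatch_mean_grads([[1.0, 1.0], [3.0, 3.0], [5.0, 5.0], [7.0, 7.0]], 2)
# mb[0] is approximately [2., 2.] and mb[1] is approximately [6., 6.]
```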