
Returns a DPOptimizerClass cls using the TreeAggregationQuery.

Combined with an SGD optimizer, this query can be used to implement the DP-FTRL algorithm described in "Practical and Private (Deep) Learning without Sampling or Shuffling".

This function is a thin wrapper around make_keras_optimizer_class.<locals>.DPOptimizerClass which can be used to apply a TreeAggregationQuery to any DPOptimizerClass.

l2_norm_clip Clipping norm (max L2 norm of per-microbatch gradients).
noise_multiplier Ratio of the noise standard deviation to the clipping norm.
var_list_or_model Either a tf.keras.Model or a list of tf.Variable objects from which tf.TensorSpecs can be defined. These specify the structure and shapes of records (gradients).
num_microbatches Number of microbatches into which each minibatch is split. The default is None, which means that the number of microbatches is equal to the batch size (i.e. each microbatch contains exactly one example). If gradient_accumulation_steps is greater than 1 and num_microbatches is not None, then the effective number of microbatches is equal to num_microbatches * gradient_accumulation_steps.
gradient_accumulation_steps If greater than 1, the optimizer accumulates gradients for this number of optimizer steps before applying them to update the model weights. If this argument is set to 1, updates are applied on each optimizer step.
restart_period (Optional) A restart will occur after every restart_period steps. The default (None) means there will be no periodic restarts. Must be a positive integer. If restart_warmup is passed, this applies only to the second restart onwards, and restart_period must not be None.
restart_warmup (Optional) The first restart will occur after restart_warmup steps. The default (None) means no warmup. Must be an integer in the range [1, restart_period - 1].
noise_seed (Optional) Integer seed for the Gaussian noise generator. If None, a nondeterministic seed based on system time will be generated.
*args These will be passed on to the base class __init__ method.
**kwargs These will be passed on to the base class __init__ method.
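
The arithmetic behind noise_multiplier and the effective number of microbatches can be sketched in plain Python. This is a minimal illustration of the relationships stated in the argument descriptions above, not the library's implementation: the helper names noise_stddev and effective_num_microbatches are hypothetical, and the default of 1 for gradient_accumulation_steps is an assumption of this sketch.

```python
from typing import Optional


def noise_stddev(l2_norm_clip: float, noise_multiplier: float) -> float:
    """Noise standard deviation implied by the documented ratio.

    noise_multiplier is described as stddev / l2_norm_clip, so the
    standard deviation is the product of the two arguments.
    """
    return noise_multiplier * l2_norm_clip


def effective_num_microbatches(
    num_microbatches: Optional[int],
    gradient_accumulation_steps: int = 1,  # assumed default for this sketch
) -> Optional[int]:
    """Effective microbatch count for the documented non-None case.

    Returns None when num_microbatches is None, because the count then
    follows the batch size, which this hypothetical helper does not know.
    """
    if gradient_accumulation_steps < 1:
        raise ValueError("gradient_accumulation_steps must be at least 1")
    if num_microbatches is None:
        return None
    return num_microbatches * gradient_accumulation_steps


# Example: 4 microbatches accumulated over 8 optimizer steps.
print(effective_num_microbatches(4, 8))
print(noise_stddev(l2_norm_clip=1.0, noise_multiplier=0.5))
```

With gradient_accumulation_steps set to 1 the helper returns num_microbatches unchanged, which matches the description that updates are then applied on every optimizer step.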

ValueError If restart_warmup is not None and restart_period is None.
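
The interaction between restart_period and restart_warmup, including the ValueError condition above, can be illustrated with a small standalone function. This is a sketch of the documented semantics under stated assumptions, not code from the library: restart_steps and total_steps are hypothetical names, and treating "after N steps" as a restart at step N (counting from 1) is an interpretation made for this example.

```python
from typing import List, Optional


def restart_steps(
    total_steps: int,
    restart_period: Optional[int] = None,
    restart_warmup: Optional[int] = None,
) -> List[int]:
    """List the step counts at which a restart would occur.

    Hypothetical helper: mirrors the argument descriptions only.
    """
    # Documented error: a warmup without a period is invalid.
    if restart_warmup is not None and restart_period is None:
        raise ValueError("restart_period must be set when restart_warmup is set")
    # No period means no periodic restarts.
    if restart_period is None:
        return []
    if restart_period < 1:
        raise ValueError("restart_period must be a positive integer")
    # Documented range for the warmup: [1, restart_period - 1].
    if restart_warmup is not None and not 1 <= restart_warmup <= restart_period - 1:
        raise ValueError("restart_warmup must be in [1, restart_period - 1]")
    # The first restart comes after the warmup if given, else after one period.
    first = restart_warmup if restart_warmup is not None else restart_period
    # Every later restart follows restart_period steps after the previous one.
    return list(range(first, total_steps + 1, restart_period))


print(restart_steps(10, restart_period=4))
print(restart_steps(10, restart_period=4, restart_warmup=2))
```

In the second call the warmup moves only the first restart; the spacing between later restarts is still restart_period.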