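The tff.learning module contains a number of ways to aggregate model updates with recommended default configurations:
- tff.learning.robust_aggregator
- tff.learning.dp_aggregator
- tff.learning.compression_aggregator
- tff.learning.secure_aggregator
In this tutorial, we explain the underlying motivation and how they are implemented, and provide suggestions for how to customize their configuration.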
!pip install --quiet --upgrade tensorflow-federated-nightly
!pip install --quiet --upgrade nest-asyncio

import nest_asyncio
nest_asyncio.apply()
import math
import tensorflow_federated as tff

tff.federated_computation(lambda: 'Hello, World!')()
Aggregation methods are represented by objects that can be passed to tff.learning.build_federated_averaging_process as its model_update_aggregation_factory keyword argument. As such, the aggregators discussed here can be used directly to modify a previous tutorial on federated learning.
mean = tff.aggregators.MeanFactory()
iterative_process = tff.learning.build_federated_averaging_process(
    ...,
    model_update_aggregation_factory=mean)
The techniques that can be used to extend the weighted mean covered in this tutorial are:
- Zeroing
- Clipping
- Differential Privacy
- Compression
- Secure Aggregation
The extension is done using composition, in which the MeanFactory wraps an inner factory to which it delegates some part of the aggregation, or is itself wrapped by another aggregation factory. For more detail on the design, see the Implementing custom aggregators tutorial.
First, we will explain how to enable and configure these techniques individually, and then show how they can be combined together.
Before delving into the individual techniques, we first introduce the quantile matching algorithm, which will be useful for configuring the techniques below.
Quantile matching
Several of the aggregation techniques below need a norm bound that controls some aspect of the aggregation. Such bounds can be provided as a constant, but it is usually better to adapt the bound during the course of training. The recommended way is to use the quantile matching algorithm of Thakkar et al. (2019), initially proposed for its compatibility with differential privacy but useful more broadly. To estimate the value at a given quantile, you can use tff.aggregators.PrivateQuantileEstimationProcess. For example, to adapt to the median of a distribution, you can use:
median_estimate = tff.aggregators.PrivateQuantileEstimationProcess.no_noise(
    initial_estimate=1.0, target_quantile=0.5, learning_rate=0.2)
Different techniques that use the quantile estimation algorithm require different values of the algorithm parameters, as we will see. In general, increasing the learning_rate parameter means faster adaptation to the correct quantile, but with higher variance. The no_noise classmethod constructs a quantile matching process that does not add noise for differential privacy.
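To make the update concrete, here is a minimal pure-Python sketch of quantile matching without noise (in the spirit of no_noise). It assumes a geometric update rule: each round the estimate is multiplied by exp(-learning_rate * (fraction_below - target_quantile)), where fraction_below is the fraction of client values at or below the current estimate. The function name and the simulated data are hypothetical, and this is not TFF's implementation.

```python
import math
import random


def quantile_matching_step(estimate, values, target_quantile, learning_rate):
    """One geometric quantile-matching update (illustrative sketch, no noise)."""
    # Fraction of client values at or below the current estimate.
    frac_below = sum(1 for v in values if v <= estimate) / len(values)
    # Too many values below -> shrink the estimate; too few -> grow it.
    return estimate * math.exp(-learning_rate * (frac_below - target_quantile))


# Hypothetical simulation: each round, 1000 clients report norms drawn
# uniformly from [0, 4], whose median is 2.0.
rng = random.Random(0)
estimate = 1.0
for _ in range(200):
    norms = [rng.uniform(0.0, 4.0) for _ in range(1000)]
    estimate = quantile_matching_step(
        estimate, norms, target_quantile=0.5, learning_rate=0.2)

print(f"estimated median: {estimate:.2f}")
```

In this sketch, a larger learning_rate makes each multiplicative step bigger, so the estimate reaches the median in fewer rounds but fluctuates more from round to round. This is the trade-off described above.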
Zeroing
Zeroing refers to replacing unusually large values with zeros. Here, "unusually large" could mean larger than a predefined threshold, or large relative to values from previous rounds of the computation. Zeroing can make the system more robust to data corruption on faulty clients.
zeroing_mean = tff.aggregators.zeroing_factory(
    zeroing_norm=MY_ZEROING_CONSTANT,
    inner_agg_factory=tff.aggregators.MeanFactory())
Here we wrap a MeanFactory with a zeroing_factory because we want the (pre-aggregation) effects of the zeroing_factory to apply to the values at clients before they are passed to the inner MeanFactory for aggregation via averaging.
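The following pure-Python sketch illustrates the idea with a constant threshold. It assumes that zeroing acts on a whole client update: if the update's L2 norm exceeds the threshold, the entire update is replaced by zeros before a (for simplicity, unweighted) mean is taken. The helper names are hypothetical. The choice of the L2 norm is an assumption made for this sketch, not a statement about which norm zeroing_factory uses.

```python
import math


def l2_norm(update):
    return math.sqrt(sum(x * x for x in update))


def zero_if_large(update, zeroing_norm):
    # Replace the whole update with zeros if its norm is unusually large.
    if l2_norm(update) > zeroing_norm:
        return [0.0] * len(update)
    return list(update)


def mean(updates):
    n = len(updates)
    return [sum(coords) / n for coords in zip(*updates)]


# Hypothetical client updates; the last one simulates a corrupted client.
updates = [[0.1, 0.2], [0.3, -0.1], [1e6, -1e6]]
zeroed_mean = mean([zero_if_large(u, zeroing_norm=10.0) for u in updates])
print(zeroed_mean)
```

Without the zeroing step, the corrupted client would dominate the mean. With it, that client contributes zeros.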
However, for most applications we recommend adaptive zeroing with the quantile estimator. To do so, we use the quantile matching algorithm as follows:
zeroing_norm = tff.aggregators.PrivateQuantileEstimationProcess.no_noise(
    initial_estimate=10.0,
    target_quantile=0.98,
    learning_rate=math.log(10),
    multiplier=2.0,
    increment=1.0)
zeroing_mean = tff.aggregators.zeroing_factory(
    zeroing_norm=zeroing_norm,
    inner_agg_factory=tff.aggregators.MeanFactory())

# Equivalent to:
# zeroing_mean = tff.learning.robust_aggregator(clipping=False)
The parameters have been chosen so that the process adapts very quickly (relatively large learning_rate) to a value somewhat larger than the largest values seen so far. For a quantile estimate Q, the threshold used for zeroing will be Q * multiplier + increment.
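As a worked example of these parameters, assume a geometric update rule in which the estimate is multiplied each round by exp(-learning_rate * (fraction_below - target_quantile)). This update rule is our assumption, used here for illustration. With learning_rate=math.log(10) and target_quantile=0.98, the estimate can grow by nearly a factor of 10 in one round but shrinks by only about 4.5% per round. It therefore catches up with large values quickly and then decays slowly. The helper name is hypothetical.

```python
import math

learning_rate = math.log(10)
target_quantile = 0.98

# Largest per-round growth: no client value lies at or below the estimate.
max_growth = math.exp(learning_rate * (target_quantile - 0.0))
# Largest per-round shrinkage: every client value lies at or below the estimate.
max_shrink = math.exp(-learning_rate * (1.0 - target_quantile))


def zeroing_threshold(quantile_estimate, multiplier=2.0, increment=1.0):
    # Threshold actually used for zeroing: Q * multiplier + increment.
    return quantile_estimate * multiplier + increment


print(f"growth <= x{max_growth:.2f}, shrink >= x{max_shrink:.3f}")
print(zeroing_threshold(10.0))  # initial_estimate=10.0 -> 21.0
```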
Clipping to bound L2 norm
Clipping client updates (projecting onto an L2 ball) can improve robustness to outliers. A tff.aggregators.clipping_factory is structured exactly like the tff.aggregators.zeroing_factory discussed above, and can take either a constant or a tff.templates.EstimationProcess as its clipping_norm argument. The recommended best practice is to use clipping that adapts moderately quickly to a moderately high norm, as follows:
clipping_norm = tff.aggregators.PrivateQuantileEstimationProcess.no_noise(
    initial_estimate=1.0,
    target_quantile=0.8,
    learning_rate=0.2)
clipping_mean = tff.aggregators.clipping_factory(
    clipping_norm=clipping_norm,
    inner_agg_factory=tff.aggregators.MeanFactory())

# Equivalent to:
# clipping_mean = tff.learning.robust_aggregator(zeroing=False)
In our experience over many problems, the precise value of target_quantile does not seem to matter much, as long as learning rates are tuned appropriately. However, setting it very low may require increasing the server learning rate for best performance, relative to not using clipping, which is why we recommend 0.8 by default.
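Projecting onto an L2 ball means rescaling an update whose norm exceeds the bound so that its norm equals the bound, while leaving smaller updates unchanged. A minimal pure-Python sketch follows. The helper names are hypothetical, and this is not TFF's implementation.

```python
import math


def l2_norm(update):
    return math.sqrt(sum(x * x for x in update))


def clip_to_l2_ball(update, clipping_norm):
    norm = l2_norm(update)
    if norm <= clipping_norm:
        return list(update)  # already inside the ball
    # Scale down so the clipped update has norm exactly clipping_norm.
    scale = clipping_norm / norm
    return [x * scale for x in update]


print(clip_to_l2_ball([3.0, 4.0], clipping_norm=1.0))  # norm 5 -> norm 1
print(clip_to_l2_ball([0.3, 0.4], clipping_norm=1.0))  # unchanged
```

Unlike zeroing, clipping keeps the direction of a large update and only bounds its influence on the mean.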
Differential privacy
TODO(b/168583108): Write this section.
Compression
Compared to lossless compression such as gzip, lossy compression generally achieves a much higher compression ratio and can still be combined with lossless compression afterwards. Because less time is spent on client-to-server communication, training rounds complete faster. Due to the inherently randomized nature of learning algorithms, the inaccuracy introduced by lossy compression, up to some threshold, does not have a negative impact on overall performance.
The default recommendation is to use simple uniform quantization (see Suresh et al. for instance), parameterized by two values: the tensor size compression threshold and the number of quantization_bits. For every tensor t, if the number of elements of t is less than or equal to threshold, it is not compressed. If it is larger, the elements of t are quantized using randomized rounding to quantization_bits bits. That is, we apply the operation

t = round((t - min(t)) / (max(t) - min(t)) * (2**quantization_bits - 1)),

where round denotes randomized rounding, resulting in integer values in the range [0, 2**quantization_bits - 1]. The quantized values are packed directly into an integer type for transmission, and then the inverse transformation is applied.
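The operation above can be sketched in pure Python. This sketch assumes that randomized rounding rounds each scaled value up with probability equal to its fractional part, which makes the rounding unbiased. It also assumes that min(t) and max(t) are sent along with the integers so the inverse transformation can be applied. The function names are hypothetical, and this is not the EncodedSumFactory implementation.

```python
import random


def quantize(values, quantization_bits, rng):
    """Uniform quantization with randomized rounding (illustrative sketch)."""
    lo, hi = min(values), max(values)
    levels = 2 ** quantization_bits - 1
    if hi == lo:
        return [0] * len(values), lo, hi  # constant tensor: nothing to encode
    ints = []
    for v in values:
        x = (v - lo) / (hi - lo) * levels  # position in [0, levels]
        floor_x = int(x)  # x >= 0, so int() truncation is the floor
        # Round up with probability equal to the fractional part.
        ints.append(floor_x + (1 if rng.random() < x - floor_x else 0))
    return ints, lo, hi


def dequantize(ints, quantization_bits, lo, hi):
    """Inverse transformation back to floating point."""
    levels = 2 ** quantization_bits - 1
    if hi == lo:
        return [lo] * len(ints)
    return [lo + k * (hi - lo) / levels for k in ints]


rng = random.Random(0)
tensor = [rng.uniform(-1.0, 1.0) for _ in range(1000)]
ints, lo, hi = quantize(tensor, quantization_bits=8, rng=rng)
restored = dequantize(ints, 8, lo, hi)
max_error = max(abs(a - b) for a, b in zip(tensor, restored))
print(f"max error {max_error:.5f}, grid step {(hi - lo) / 255:.5f}")
```

Each restored value lies on one of the two grid points surrounding the original, so the per-element error stays below one grid step.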
We recommend setting quantization_bits to 8 and threshold to 20000:
compressed_mean = tff.aggregators.MeanFactory(
    tff.aggregators.EncodedSumFactory.quantize_above_threshold(
        quantization_bits=8, threshold=20000))

# Equivalent to:
# compressed_mean = tff.learning.compression_aggregator(zeroing=False, clipping=False)
Both quantization_bits and threshold can be adjusted, and the number of clients participating in each training round can also affect the effectiveness of compression.
Threshold. The default value of 20000 was chosen because we have observed that variables with a small number of elements, such as biases in common layer types, are much more sensitive to introduced noise. Moreover, in practice there is little to be gained from compressing variables with a small number of elements, as their uncompressed size is relatively small to begin with.
In some applications it may make sense to change the threshold. For instance, the biases of the output layer of a classification model may be more sensitive to noise. If you are training a language model with a vocabulary size of 20004, the output-layer bias has 20004 elements and would be compressed under the default threshold. In that case, you may want to set threshold to 20004 so that the bias is sent uncompressed.
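To see how the threshold interacts with model shapes, here is a back-of-the-envelope sketch. It assumes that uncompressed values are sent as 32-bit floats and ignores the small per-tensor overhead, such as min and max. It uses hypothetical output-layer shapes for a language model with a vocabulary size of 20004 and a hidden size of 96.

```python
import math

# Hypothetical output-layer variables: kernel (hidden_size x vocab) and bias.
shapes = {"output_kernel": (96, 20004), "output_bias": (20004,)}


def compressed_variables(shapes, threshold):
    # A tensor is quantized only if it has more than `threshold` elements.
    return {name for name, shape in shapes.items() if math.prod(shape) > threshold}


def payload_bits(shapes, threshold, quantization_bits, float_bits=32):
    total = 0
    for shape in shapes.values():
        n = math.prod(shape)
        total += n * (quantization_bits if n > threshold else float_bits)
    return total


for threshold in (20000, 20004):
    print(threshold, sorted(compressed_variables(shapes, threshold)),
          payload_bits(shapes, threshold, quantization_bits=8))
```

Under these assumptions, keeping the bias uncompressed costs roughly 3% extra payload, which is small compared with the 4x saving on the kernel.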
Quantization bits. The default value of 8 for quantization_bits should be fine for most users. If 8 is working well and you want to reduce communication a bit further, you could try lowering it to 7 or 6. If resources permit a small grid search, we recommend identifying the value at which training becomes unstable or final model quality starts to degrade, and then increasing that value by two. For example, if setting quantization_bits to 5 works but setting it to 4 degrades the model, we would recommend using 6 to be "on the safe side".
Clients per round. Note that significantly increasing the number of clients per round can allow a smaller value of quantization_bits to work well, because the randomized inaccuracy introduced by quantization may be evened out by averaging over more client updates.
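The averaging effect can be checked with a small simulation. The sketch makes the following assumptions: every hypothetical client sends the same vector, and each client quantizes it independently to 2 bits with unbiased randomized rounding. Under these assumptions, the squared error of the averaged result shrinks roughly in proportion to 1 / number_of_clients. The helper names are hypothetical.

```python
import random


def quantize_dequantize(values, bits, rng):
    """Randomized uniform quantization followed by its inverse (sketch).

    Assumes the values are not all equal.
    """
    lo, hi = min(values), max(values)
    levels = 2 ** bits - 1
    step = (hi - lo) / levels
    out = []
    for v in values:
        x = (v - lo) / (hi - lo) * levels
        k = int(x) + (1 if rng.random() < x - int(x) else 0)
        out.append(lo + k * step)
    return out


def mse_of_average(values, bits, num_clients, rng):
    decoded = [quantize_dequantize(values, bits, rng) for _ in range(num_clients)]
    average = [sum(col) / num_clients for col in zip(*decoded)]
    return sum((a - v) ** 2 for a, v in zip(average, values)) / len(values)


rng = random.Random(0)
update = [rng.uniform(-1.0, 1.0) for _ in range(50)]
mse_1 = mse_of_average(update, bits=2, num_clients=1, rng=rng)
mse_100 = mse_of_average(update, bits=2, num_clients=100, rng=rng)
print(f"1 client: {mse_1:.4f}, 100 clients: {mse_100:.6f}")
```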
Secure aggregation
By Secure Aggregation (SecAgg) we refer to a cryptographic protocol wherein client updates are encrypted in such a way that the server can only decrypt their sum. If the number of clients that report back is insufficient, the server will learn nothing at all -- and in no case will the server be able to inspect individual updates. This is realized using the
The model updates are floating point values, but SecAgg operates on integers. Therefore we need to clip any large values to some bound before discretization to an integer type. The clipping bound can be either a constant or determined adaptively (the recommended default). The integers are then securely summed, and the sum is mapped back to the floating point domain.
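The clip, discretize, sum, and map-back steps can be sketched in pure Python. To hint at why the server learns only the sum, the sketch also adds toy pairwise masks that cancel modulo 2**32, in the spirit of mask-based secure aggregation protocols. Everything here is an illustrative assumption: the 16-bit discretization, the modulus, the helper names, and the masking scheme. The masking also ignores key agreement and client dropouts. None of this is the SecureSumFactory implementation.

```python
import random

MODULUS = 2 ** 32     # hypothetical integer range used for summation
LEVELS = 2 ** 16 - 1  # hypothetical number of discretization levels


def discretize(x, bound):
    # Clip to [-bound, bound], then map linearly onto the integers [0, LEVELS].
    clipped = max(-bound, min(bound, x))
    return round((clipped + bound) / (2 * bound) * LEVELS)


def undiscretize_sum(int_sum, bound, num_clients):
    # Invert the affine map for a sum of num_clients discretized values.
    return int_sum * (2 * bound) / LEVELS - num_clients * bound


def pairwise_masks(num_clients, rng):
    # Client i adds m_ij and client j subtracts it, so all masks cancel in the sum.
    masks = [0] * num_clients
    for i in range(num_clients):
        for j in range(i + 1, num_clients):
            m = rng.randrange(MODULUS)
            masks[i] += m
            masks[j] -= m
    return masks


rng = random.Random(0)
bound = 1.0
client_values = [0.5, -1.2, 3.0, 0.1]  # -1.2 and 3.0 get clipped to +/-1.0

ints = [discretize(v, bound) for v in client_values]
masks = pairwise_masks(len(ints), rng)
masked = [(q + m) % MODULUS for q, m in zip(ints, masks)]  # what the server sees
int_sum = sum(masked) % MODULUS  # masks cancel; valid while sum(ints) < MODULUS
estimate = undiscretize_sum(int_sum, bound, len(client_values))
print(f"decoded sum of clipped values: {estimate:.4f}")
```

The decoded result approximates the sum of the clipped values (0.5 - 1.0 + 1.0 + 0.1), not the sum of the raw values. This is why the clipping bound matters.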
To compute a mean with the weighted values summed using SecAgg, with MY_SECAGG_BOUND as the clipping bound, pass a SecureSumFactory to the MeanFactory:
secure_mean = tff.aggregators.MeanFactory(
    tff.aggregators.SecureSumFactory(MY_SECAGG_BOUND))
To do the same while determining bounds adaptively:
secagg_bound = tff.aggregators.PrivateQuantileEstimationProcess.no_noise(
    initial_estimate=50.0,
    target_quantile=0.95,
    learning_rate=1.0,
    multiplier=2.0)
secure_mean = tff.aggregators.MeanFactory(
    tff.aggregators.SecureSumFactory(secagg_bound))

# Equivalent to:
# secure_mean = tff.learning.secure_aggregator(zeroing=False, clipping=False)
The adaptive parameters have been chosen so that the bounds are tight (we won't lose much precision in discretization) but clipping happens rarely.
If you tune these parameters, keep in mind that the SecAgg protocol sums the weighted model updates, that is, the values after the weighting performed by the mean. The weights are typically the number of data points processed locally, so the right bound may differ between tasks depending on this quantity.
We do not recommend using the increment keyword argument when creating an adaptive secagg_bound, as this could result in a large relative loss of precision if the actual estimate turns out to be small.
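A small worked example shows the concern. Assume, for illustration, that values are discretized on a uniform grid over [-bound, bound] with a fixed number of levels, so the grid step is proportional to the bound. This scheme and the helper name are hypothetical. With a small quantile estimate of 0.01 and multiplier=2.0, adding increment=1.0 inflates the bound, and hence the discretization step, by a factor of 51.

```python
def secagg_bound(quantile_estimate, multiplier=2.0, increment=0.0):
    # Bound derived from the quantile estimate: Q * multiplier + increment.
    return quantile_estimate * multiplier + increment


levels = 2 ** 16 - 1  # hypothetical number of discretization levels
q = 0.01              # hypothetical small quantile estimate

step_without = 2 * secagg_bound(q) / levels
step_with = 2 * secagg_bound(q, increment=1.0) / levels
print(f"step grows by x{step_with / step_without:.0f}")
```

Without the increment, the bound scales with the estimate, so the relative precision stays roughly constant regardless of the scale of the updates.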
The code snippet above uses SecAgg only for the weighted values. If SecAgg should also be used for the sum of weights, we recommend setting the bounds as constants, since in a common training setup the largest possible weight is known in advance:
secure_mean = tff.aggregators.MeanFactory(
    value_sum_factory=tff.aggregators.SecureSumFactory(secagg_bound),
    weight_sum_factory=tff.aggregators.SecureSumFactory(
        upper_bound_threshold=MAX_WEIGHT, lower_bound_threshold=0.0))
The individual techniques for extending a mean introduced above can be combined.
We recommend applying these techniques at clients in the following order:
- Zeroing
- Clipping
- Other techniques
The aggregators in the tff.aggregators module are composed by wrapping "inner aggregators" (whose pre-aggregation effects happen last and post-aggregation effects happen first) inside "outer aggregators". For example, to perform zeroing, clipping, and compression (in that order), one would write:
# Compression is innermost because its pre-aggregation effects are last.
compressed_mean = tff.aggregators.MeanFactory(
    tff.aggregators.EncodedSumFactory.quantize_above_threshold(
        quantization_bits=8, threshold=20000))
# Compressed mean is inner aggregator to clipping...
clipped_compressed_mean = tff.aggregators.clipping_factory(
    clipping_norm=MY_CLIPPING_CONSTANT,
    inner_agg_factory=compressed_mean)
# ...which is inner aggregator to zeroing, since zeroing happens first.
final_aggregator = tff.aggregators.zeroing_factory(
    zeroing_norm=MY_ZEROING_CONSTANT,
    inner_agg_factory=clipped_compressed_mean)
Note that this structure matches the default aggregators for learning algorithms.
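The reason zeroing should come before clipping can be seen in a tiny pure-Python sketch of the client-side (pre-aggregation) steps, using hypothetical helpers and thresholds. The sketch assumes the clipping bound is below the zeroing threshold. If clipping ran first, every update would already have norm at most the clipping bound, so zeroing could never trigger. A corrupted update would then still contribute a full-norm clipped update.

```python
import math


def l2_norm(update):
    return math.sqrt(sum(x * x for x in update))


def zero(update, zeroing_norm):
    return [0.0] * len(update) if l2_norm(update) > zeroing_norm else list(update)


def clip(update, clipping_norm):
    norm = l2_norm(update)
    scale = min(1.0, clipping_norm / norm) if norm > 0 else 1.0
    return [x * scale for x in update]


corrupted = [300.0, 400.0]  # norm 500, e.g. from a faulty client

# Recommended order: zeroing first (outer aggregator), then clipping (inner).
zero_then_clip = clip(zero(corrupted, zeroing_norm=21.0), clipping_norm=1.0)
# Reversed order: the clipped norm is 1.0, so zeroing at 21.0 never triggers.
clip_then_zero = zero(clip(corrupted, clipping_norm=1.0), zeroing_norm=21.0)

print(zero_then_clip)  # [0.0, 0.0]
print(clip_then_zero)
```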
Other compositions are possible, too. We will extend this document when we are confident that we can provide a default configuration that works across multiple applications. For implementing new ideas, see the Implementing custom aggregators tutorial.