Builds the TFF computations for optimization using federated averaging.
tff.learning.build_federated_averaging_process( model_fn, server_optimizer_fn=(lambda : tf.keras.optimizers.SGD(learning_rate=1.0)), client_weight_fn=None, stateful_delta_aggregate_fn=None, stateful_model_broadcast_fn=None )
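To make the algorithm concrete, here is a framework-free sketch of repeated federated averaging rounds on a toy scalar model. Everything in it (local_update, federated_averaging_round, the squared-error loss, the learning rates) is a hypothetical illustration, not the TFF API. The real process builds TFF computations from model_fn instead, but the round has the same shape: broadcast the server model, train locally on each client, average the client deltas weighted by example count, and apply the result on the server.

```python
# Toy assumptions (hypothetical, not TFF): the model is one float w, each
# client holds a list of floats, and the local loss is (w - x) ** 2.


def local_update(w, data, lr=0.1):
    """One local epoch of plain SGD; returns (model delta, num examples)."""
    w_local = w
    for x in data:
        grad = 2.0 * (w_local - x)  # derivative of (w - x) ** 2 w.r.t. w
        w_local -= lr * grad
    return w_local - w, len(data)


def federated_averaging_round(w_server, client_datasets, server_lr=1.0):
    """Broadcast, train locally, average deltas by example count, update server."""
    results = [local_update(w_server, data) for data in client_datasets]
    total = sum(n for _, n in results)
    avg_delta = sum(delta * n for delta, n in results) / total
    # Equivalent to an SGD step with learning rate server_lr on the negated
    # delta; with server_lr=1.0 this simply adds avg_delta.
    return w_server + server_lr * avg_delta


clients = [[1.0, 1.0], [3.0], [2.0, 2.0, 2.0]]
w = 0.0
for _ in range(20):
    w = federated_averaging_round(w, clients)
print(w)
```

After 20 rounds w settles near 1.78, close to but not exactly the pooled mean of all six examples (11/6 ≈ 1.83); the gap comes from the several sequential local SGD steps each client takes before its delta is averaged.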
Used in the tutorials:
- Federated Learning for Image Classification
- Federated Learning for Text Generation
- High-performance simulations with TFF
model_fn: A no-arg function that returns a model.
server_optimizer_fn: A no-arg function that returns an optimizer. The apply_gradients method of this optimizer is used to apply client updates to the server model. The default creates a tf.keras.optimizers.SGD with a learning rate of 1.0, which simply adds the average client delta to the server's model.
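Why does SGD with a learning rate of 1.0 "simply add" the average delta? The plain-Python sketch below assumes the server passes the negated average client delta to the optimizer as a pseudo-gradient; sgd_apply_gradients and server_update are hypothetical stand-ins for a momentum-free SGD step, not Keras or TFF functions.

```python
def sgd_apply_gradients(weights, grads, learning_rate):
    """Mimic a momentum-free SGD step: w <- w - learning_rate * g."""
    return [w - learning_rate * g for w, g in zip(weights, grads)]


def server_update(server_weights, avg_delta, learning_rate=1.0):
    # Assumption: the delta already points in the descent direction, so its
    # negation plays the role of a gradient for the optimizer.
    pseudo_grad = [-d for d in avg_delta]
    return sgd_apply_gradients(server_weights, pseudo_grad, learning_rate)


weights = [0.5, -1.0]
delta = [0.25, 0.5]
print(server_update(weights, delta))       # [0.75, -0.5]: lr 1.0 adds the delta
print(server_update(weights, delta, 0.5))  # [0.625, -0.75]: half the step
```

Swapping in a different server optimizer changes only the update rule applied to the same pseudo-gradient.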
client_weight_fn: Optional function that takes the output of model.report_local_outputs and returns a tensor that provides the weight in the federated average of model deltas. If not provided, the default is the total number of examples processed on device.
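The effect of client_weight_fn can be sketched in plain Python. Here each client's local outputs are a dict with a hypothetical "num_examples" key (an assumption for this sketch, not a guaranteed field of report_local_outputs), and the weight functions return floats rather than tensors. The sketch compares the example-count weighting described as the default with uniform weighting.

```python
# Hypothetical per-client local outputs (assumed structure, not TFF's).
local_outputs = [{"num_examples": 30}, {"num_examples": 10}]


def weight_by_examples(outputs):
    # Like the documented default: weight = examples processed on device.
    return float(outputs["num_examples"])


def uniform_weight(outputs):
    # An alternative weighting: every client counts equally.
    return 1.0


def normalized_weights(all_outputs, weight_fn):
    """Apply weight_fn per client, then scale so the weights sum to 1."""
    raw = [weight_fn(o) for o in all_outputs]
    total = sum(raw)
    return [w / total for w in raw]


print(normalized_weights(local_outputs, weight_by_examples))  # [0.75, 0.25]
print(normalized_weights(local_outputs, uniform_weight))      # [0.5, 0.5]
```

With example-count weighting, a client with three times as much data contributes three times as much to the averaged delta.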
stateful_delta_aggregate_fn: Optional stateful aggregation function whose next_fn performs a federated aggregation and updates state. That is, it has TFF type (state@SERVER, value@CLIENTS, weights@CLIENTS) -> (state@SERVER, aggregate@SERVER), where value has the type of tff.learning.framework.ModelWeights.trainable corresponding to the object returned by model_fn. By default performs an arithmetic mean aggregation, weighted by the client weights described under client_weight_fn.
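The aggregation contract (state@SERVER, value@CLIENTS, weights@CLIENTS) -> (state@SERVER, aggregate@SERVER) can be modeled in plain Python as a function that threads server state through each call. This is a sketch under stated assumptions: values are flat lists of floats, the state is a hypothetical round counter, and aggregate_initialize and aggregate_next are illustrative names, not TFF functions. The aggregate itself is the weighted arithmetic mean described as the default.

```python
from typing import List, Tuple

# Hypothetical: the aggregator state is just a count of rounds aggregated.
State = int


def aggregate_initialize() -> State:
    return 0


def aggregate_next(
    state: State, values: List[List[float]], weights: List[float]
) -> Tuple[State, List[float]]:
    """(state, value@CLIENTS, weights@CLIENTS) -> (state, aggregate)."""
    total = sum(weights)
    dim = len(values[0])
    mean = [
        sum(w * v[i] for w, v in zip(weights, values)) / total for i in range(dim)
    ]
    return state + 1, mean  # updated server state and the weighted mean


state = aggregate_initialize()
state, agg = aggregate_next(state, [[2.0, 0.0], [0.0, 4.0]], [1.0, 3.0])
print(state, agg)  # 1 [0.5, 3.0]
```

Returning the new state alongside the aggregate, instead of mutating it, mirrors the functional shape of the TFF type signature.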
stateful_model_broadcast_fn: Optional stateful broadcast function whose next_fn performs a federated broadcast and updates state. That is, it has TFF type (state@SERVER, value@SERVER) -> (state@SERVER, value@CLIENTS), where value has the type of tff.learning.framework.ModelWeights corresponding to the object returned by model_fn. By default performs an identity broadcast.
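The broadcast contract (state@SERVER, value@SERVER) -> (state@SERVER, value@CLIENTS) with the default identity behavior can be sketched the same way. In this plain-Python model, broadcast_next is a hypothetical name, the state is an assumed broadcast counter, and num_clients stands in for the client placement that TFF handles implicitly.

```python
from typing import List, Tuple

State = int  # hypothetical: number of broadcasts performed so far


def broadcast_next(
    state: State, value: List[float], num_clients: int
) -> Tuple[State, List[List[float]]]:
    """(state@SERVER, value@SERVER) -> (state@SERVER, value@CLIENTS)."""
    # Identity broadcast: every client gets an unmodified copy of the value.
    client_values = [list(value) for _ in range(num_clients)]
    return state + 1, client_values


state, received = broadcast_next(0, [0.5, -1.0], num_clients=3)
print(state, received)  # 1 [[0.5, -1.0], [0.5, -1.0], [0.5, -1.0]]
```

Each client receives its own copy, so a client modifying its local weights does not affect the server value or other clients.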