Builds the TFF computations for optimization using federated SGD.
tff.learning.build_federated_sgd_process(
    model_fn,
    server_optimizer_fn=(lambda: tf.keras.optimizers.SGD(learning_rate=0.1)),
    client_weight_fn=None,
    stateful_delta_aggregate_fn=None,
    stateful_model_broadcast_fn=None
)
model_fn: A no-arg function that returns a tff.learning.Model.
server_optimizer_fn: A no-arg function that returns a tf.keras.optimizers.Optimizer. The apply_gradients method of this optimizer is used to apply client updates to the server model.
client_weight_fn: Optional function that takes the output of model.report_local_outputs and returns a tensor that provides the weight in the federated average of model deltas. If not provided, the default is the total number of examples processed on device.
stateful_delta_aggregate_fn: A tff.utils.StatefulAggregateFn whose next_fn performs a federated aggregation and updates state. That is, it has TFF type (state@SERVER, value@CLIENTS, weights@CLIENTS) -> (state@SERVER, aggregate@SERVER), where the value type is tff.learning.framework.ModelWeights.trainable corresponding to the object returned by model_fn. By default performs arithmetic mean aggregation, weighted by client_weight_fn.
stateful_model_broadcast_fn: A tff.utils.StatefulBroadcastFn whose next_fn performs a federated broadcast and updates state. That is, it has TFF type (state@SERVER, value@SERVER) -> (state@SERVER, value@CLIENTS), where the value type is tff.learning.framework.ModelWeights corresponding to the object returned by model_fn. By default performs identity broadcast.
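The roles of the arguments above can be illustrated with a small NumPy sketch of one federated SGD round. This is not TFF code and does not use the TFF API: the linear model, the mean-squared-error loss, and every function name below (client_gradient, identity_broadcast, weighted_mean_aggregate, federated_sgd_round) are hypothetical stand-ins chosen for illustration. The sketch assumes the default behavior described above: identity broadcast, arithmetic mean aggregation weighted by the number of examples per client, and a plain SGD server step with learning rate 0.1.

```python
import numpy as np


def client_gradient(weights, features, labels):
    """Gradient of the mean squared error of a linear model on one client."""
    residual = features @ weights - labels
    return 2.0 * features.T @ residual / len(labels)


def identity_broadcast(state, value, num_clients):
    """Shape of (state@SERVER, value@SERVER) -> (state@SERVER, value@CLIENTS)."""
    # Identity broadcast: every client receives an unmodified copy.
    return state, [value.copy() for _ in range(num_clients)]


def weighted_mean_aggregate(state, values, weights):
    """Shape of (state, value@CLIENTS, weights@CLIENTS) -> (state, aggregate)."""
    total = float(sum(weights))
    aggregate = sum(w * v for w, v in zip(weights, values)) / total
    return state, aggregate


def federated_sgd_round(server_weights, client_data, learning_rate=0.1):
    """One round: broadcast, client gradients, weighted mean, server SGD step."""
    _, per_client = identity_broadcast((), server_weights, len(client_data))
    grads = [client_gradient(w, x, y) for w, (x, y) in zip(per_client, client_data)]
    # Assumed default weighting: number of examples processed on each client.
    example_counts = [len(y) for _, y in client_data]
    _, mean_grad = weighted_mean_aggregate((), grads, example_counts)
    # Stand-in for the server optimizer's apply_gradients with plain SGD.
    return server_weights - learning_rate * mean_grad


# Two clients holding samples of y = 2 * x, with unequal amounts of data.
client_data = [
    (np.array([[1.0], [2.0]]), np.array([2.0, 4.0])),
    (np.array([[3.0]]), np.array([6.0])),
]
weights = np.zeros(1)
for _ in range(50):
    weights = federated_sgd_round(weights, client_data)
```

Because the weights are example counts, the aggregated gradient equals the gradient over the pooled data, so in this toy setup the server weight approaches the true slope of 2. Both stateless helpers pass an empty tuple as state; a stateful variant would thread a real value through in the same position.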