tff.learning.build_federated_sgd_process

Builds the TFF computations for optimization using federated SGD.

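This function creates a tff.templates.IterativeProcess that performs federated SGD on client models. Like every tff.templates.IterativeProcess, it exposes an initialize method, which returns the initial server state, and a next method, which runs one round of training.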

Each time the next method is called, the server model is broadcast to each client using a broadcast function. Each client sums the gradients at each batch in the client's local dataset. These gradient sums are then aggregated at the server using an aggregation function. The aggregate gradients are applied at the server by using the tf.keras.optimizers.Optimizer.apply_gradients method of the server optimizer.

This implements the original FedSGD algorithm in McMahan et al., 2017.
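
The round described above can be sketched in plain Python. This is a minimal illustration, not the TFF implementation: it assumes a one-parameter linear model y = w * x with squared-error loss, uses the number of examples as the client weight, and substitutes a plain SGD step for the server optimizer's apply_gradients call. The names client_gradient_sum, fedsgd_round, and server_lr are hypothetical.

```python
# Pure-Python sketch of FedSGD rounds for a 1-D linear model y = w * x.
# All names here are illustrative; none of them are part of the TFF API.

def client_gradient_sum(w, batches):
    """Sum the gradients of 0.5 * (w*x - y)**2 over every batch on one client."""
    grad_sum = 0.0
    num_examples = 0
    for batch in batches:
        # Gradient of the summed batch loss with respect to w.
        grad_sum += sum((w * x - y) * x for x, y in batch)
        num_examples += len(batch)
    return grad_sum, num_examples


def fedsgd_round(w, client_datasets, server_lr=0.1):
    """Run one round: broadcast w, collect gradient sums, apply a server step."""
    # "Broadcast": every client receives the same server weight w.
    results = [client_gradient_sum(w, batches) for batches in client_datasets]
    total_examples = sum(n for _, n in results)
    # Example-weighted average of the per-client mean gradients, which equals
    # the total gradient sum divided by the total number of examples.
    avg_grad = sum(g for g, _ in results) / total_examples
    # Stand-in for the server optimizer: one plain SGD step.
    return w - server_lr * avg_grad


# Two clients whose data follow y = 2 * x.
clients = [
    [[(1.0, 2.0), (2.0, 4.0)]],    # client A: one batch of two examples
    [[(3.0, 6.0)], [(1.0, 2.0)]],  # client B: two batches of one example each
]

w = 0.0
for _ in range(50):
    w = fedsgd_round(w, clients)
print(w)  # approaches 2.0
```

Because the clients only report gradients, all of the optimizer state stays on the server, which is the main structural difference from federated averaging, where clients take local optimizer steps.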

model_fn A no-arg function that returns a tff.learning.Model. This function must not capture and use TensorFlow tensors or variables. The model must be constructed entirely from scratch on each invocation; returning the same pre-constructed model on each call will result in an error.
server_optimizer_fn A no-arg function that returns a tf.keras.optimizers.Optimizer. The apply_gradients method of this optimizer is used to apply client updates to the server model.
client_weight_fn Optional function that takes the output of model.report_local_outputs and returns a tensor that provides the weight in the federated average of the aggregated gradients. If not provided, the default is the total number of examples processed on device.
stateful_delta_aggregate_fn A tff.utils.StatefulAggregateFn where the next_fn performs a federated aggregation and updates state. It must have TFF type (<state@SERVER, value@CLIENTS, weights@CLIENTS> -> <state@SERVER, aggregate@SERVER>), where the value type is tff.learning.framework.ModelWeights.trainable corresponding to the object returned by model_fn. By default, it performs arithmetic mean aggregation, weighted by client_weight_fn. Must be None if aggregation_process is not None.
stateful_model_broadcast_fn A tff.utils.StatefulBroadcastFn where the next_fn performs a federated broadcast and updates state. It must have TFF type (<state@SERVER, value@SERVER> -> <state@SERVER, value@CLIENTS>), where the value type is tff.learning.framework.ModelWeights corresponding to the object returned by model_fn. The default is the identity broadcast. Must be None if broadcast_process is not None.
broadcast_process A tff.templates.MeasuredProcess that broadcasts the model weights on the server to the clients. It must support the signature (input_values@SERVER -> output_values@CLIENTS). Must be None if stateful_model_broadcast_fn is not None.
aggregation_process A tff.templates.MeasuredProcess that aggregates the model updates on the clients back to the server. It must support the signature ({input_values}@CLIENTS -> output_values@SERVER). Must be None if stateful_delta_aggregate_fn is not None.
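
To illustrate how client_weight_fn changes the aggregate, the sketch below compares the default example-count weighting with a uniform weighting. It is a plain-Python stand-in under stated assumptions: the dictionaries play the role of the output of model.report_local_outputs, the "num_examples" key is hypothetical, and the aggregate helper is not a TFF function.

```python
# Plain-Python sketch of weighted aggregation; not the TFF implementation.

def weighted_mean(values, weights):
    """Return the weighted arithmetic mean of scalar values."""
    return sum(v * w for v, w in zip(values, weights)) / sum(weights)


def default_weight(local_outputs):
    # Mirrors the documented default: weight by examples processed on device.
    return float(local_outputs["num_examples"])


def uniform_weight(local_outputs):
    # A custom client_weight_fn that treats every client equally.
    return 1.0


def aggregate(results, client_weight_fn=None):
    """Aggregate (gradient, local_outputs) pairs with the chosen weighting."""
    weight_fn = client_weight_fn or default_weight
    grads = [grad for grad, _ in results]
    weights = [weight_fn(outputs) for _, outputs in results]
    return weighted_mean(grads, weights)


# Hypothetical per-client results: (scalar gradient, local outputs).
client_results = [
    (1.0, {"num_examples": 10}),
    (4.0, {"num_examples": 30}),
]

by_examples = aggregate(client_results)             # (1*10 + 4*30) / 40 = 3.25
uniform = aggregate(client_results, uniform_weight)  # (1 + 4) / 2 = 2.5
print(by_examples, uniform)
```

With the default weighting, the client holding more data pulls the aggregate toward its gradient; a uniform weight removes that effect, which may be preferable when example counts are unreliable.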

Returns a tff.templates.IterativeProcess.