
tff.learning.algorithms.build_weighted_fed_prox

Builds a learning process that performs the FedProx algorithm.

This function creates a tff.learning.templates.LearningProcess that performs example-weighted FedProx on client models. This algorithm behaves the same as federated averaging, except that it adds a proximal regularization term to each client's local loss, which discourages clients from drifting too far from the server model.
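To make the proximal term concrete, here is a minimal plain-Python sketch of a FedProx-style local objective. It follows the formulation from the FedProx paper, loss(w) + (mu / 2) * ||w - w_server||^2, where mu plays the role of proximal_strength. The helper name `proximal_objective` and the list-of-floats weight representation are hypothetical illustrations, not TFF's internal implementation, and TFF's exact scaling of the term is an assumption here.

```python
from typing import Callable, List, Tuple

Vector = List[float]


def proximal_objective(
    loss_fn: Callable[[Vector], float],
    grad_fn: Callable[[Vector], Vector],
    weights: Vector,
    server_weights: Vector,
    proximal_strength: float,
) -> Tuple[float, Vector]:
    """Returns (loss, gradient) of loss_fn(w) + (mu / 2) * ||w - w_server||^2.

    Hypothetical helper for illustration; mu is proximal_strength.
    """
    diff = [w - s for w, s in zip(weights, server_weights)]
    # Squared Euclidean distance between client and server weights.
    sq_dist = sum(d * d for d in diff)
    loss = loss_fn(weights) + 0.5 * proximal_strength * sq_dist
    # The proximal term adds mu * (w - w_server) to the plain loss gradient.
    grad = [g + proximal_strength * d for g, d in zip(grad_fn(weights), diff)]
    return loss, grad


# Toy loss ||w||^2 with gradient 2w; the server model sits at the origin.
loss_fn = lambda w: sum(x * x for x in w)
grad_fn = lambda w: [2 * x for x in w]
print(proximal_objective(loss_fn, grad_fn, [1.0, 2.0], [0.0, 0.0], 0.5))
# (6.25, [2.5, 5.0])
```

With proximal_strength set to 0 the extra term vanishes and the objective is the plain local loss, which is why the algorithm reduces to FedAvg in that case.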

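The learning process has the following methods inherited from tff.learning.templates.LearningProcess:

initialize: A tff.Computation with type signature ( -> S@SERVER), where S is the initial state of the server.
next: A tff.Computation that takes the server state and the client datasets and returns the updated server state together with aggregated metrics.
get_model_weights: A tff.Computation that returns the model weights contained in a server state.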

Each time the next method is called, the server model is communicated to each client using the provided model_distributor. For each client, local training is performed using client_optimizer_fn, with the proximal term added to the client's local loss. Each client then computes the difference between its model after training and the initial model it received from the server. These model deltas are aggregated at the server by model_aggregator, with each client weighted according to client_weighting. The aggregate model delta is applied at the server using the server optimizer from server_optimizer_fn, as in the FedOpt framework proposed in Reddi et al., 2021.
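The round described above can be sketched in plain Python on a toy problem. This is a minimal illustration under stated assumptions, not TFF's implementation: each client's loss is a hypothetical quadratic 0.5 * ||w - client_optimum||^2, local training is a fixed number of full-gradient steps rather than a pass over a dataset, clients are weighted by example count (the default client_weighting), and the server applies plain SGD with a learning rate of 1.0. The function names `local_train` and `fed_prox_round` are hypothetical.

```python
from typing import List, Sequence

Vector = List[float]


def local_train(server_weights: Vector, client_optimum: Vector,
                proximal_strength: float, lr: float = 0.1,
                steps: int = 10) -> Vector:
    """Gradient descent from the server model on a toy client objective.

    Hypothetical toy loss: 0.5 * ||w - client_optimum||^2 plus the FedProx
    term (proximal_strength / 2) * ||w - server_weights||^2.
    """
    w = list(server_weights)
    for _ in range(steps):
        grad = [(wi - ci) + proximal_strength * (wi - si)
                for wi, ci, si in zip(w, client_optimum, server_weights)]
        w = [wi - lr * gi for wi, gi in zip(w, grad)]
    return w


def fed_prox_round(server_weights: Vector, client_optima: Sequence[Vector],
                   num_examples: Sequence[int], proximal_strength: float,
                   server_lr: float = 1.0) -> Vector:
    """One round: broadcast, local training, weighted mean delta, server SGD."""
    total = float(sum(num_examples))
    mean_delta = [0.0] * len(server_weights)
    for optimum, n in zip(client_optima, num_examples):
        trained = local_train(server_weights, optimum, proximal_strength)
        # Delta taken as (initial - trained) so the server can apply it like
        # a gradient; this sign convention is an assumption of the sketch.
        delta = [s - t for s, t in zip(server_weights, trained)]
        # Weight each client by its share of the total number of examples.
        mean_delta = [m + (n / total) * d for m, d in zip(mean_delta, delta)]
    # Server SGD step; with server_lr=1.0 the result is the example-weighted
    # mean of the trained client models.
    return [s - server_lr * d for s, d in zip(server_weights, mean_delta)]


server = [0.0]
optima = [[1.0], [3.0]]
counts = [1, 3]
print(fed_prox_round(server, optima, counts, proximal_strength=0.0))
print(fed_prox_round(server, optima, counts, proximal_strength=1.0))
```

Running both calls shows the effect of proximal_strength: with 0.0 the round is plain FedAvg, while with 1.0 each client's trained model stays closer to the server model, so the new server model moves less in a single round.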

Args

model_fn: A no-arg function that returns a tff.learning.Model. This function must not capture TensorFlow tensors or variables and use them. The model must be constructed entirely from scratch on each invocation; returning the same pre-constructed model on each call will result in an error.
proximal_strength: A nonnegative float representing the parameter of FedProx's regularization term. When set to 0, the algorithm reduces to FedAvg. Higher values more strongly penalize clients for moving away from the server model during local training.
client_optimizer_fn: A tff.learning.optimizers.Optimizer, or a no-arg callable that returns a tf.keras.Optimizer.
server_optimizer_fn: A tff.learning.optimizers.Optimizer, or a no-arg callable that returns a tf.keras.Optimizer. By default, this uses tf.keras.optimizers.SGD with a learning rate of 1.0.
client_weighting: A member of tff.learning.ClientWeighting that specifies a built-in weighting method. By default, clients are weighted by their number of examples.
model_distributor: An optional DistributionProcess that broadcasts the model weights on the server to the clients. If set to None, the distributor is constructed via distributors.build_broadcast_process.
model_aggregator: An optional tff.aggregators.WeightedAggregationFactory used to aggregate client updates on the server. If None, this is set to tff.aggregators.MeanFactory.
metrics_aggregator: A function that takes in the metric finalizers (i.e., tff.learning.Model.metric_finalizers()) and a tff.types.StructWithPythonType of the unfinalized metrics (i.e., the TFF type of tff.learning.Model.report_local_unfinalized_metrics()), and returns a tff.Computation for aggregating the unfinalized metrics. If None, this is set to tff.learning.metrics.sum_then_finalize.
use_experimental_simulation_loop: Controls the reduce loop function for the input dataset. An experimental reduce loop is used for simulation. It is currently necessary to set this flag to True for performant GPU simulations.
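The default metrics_aggregator sums unfinalized metric values across clients before finalizing them. The plain-Python sketch below illustrates why that order matters, assuming a hypothetical representation where each client reports a list of floats per metric and each finalizer maps the summed list to a value. `sum_then_finalize_sketch` is a hypothetical stand-in, not the signature of tff.learning.metrics.sum_then_finalize.

```python
from typing import Callable, Dict, List, Sequence

# Hypothetical stand-ins for TFF's structures: each client reports a flat list
# of unfinalized values per metric, and each finalizer turns the summed list
# into the final metric value.
UnfinalizedMetrics = Dict[str, List[float]]
Finalizers = Dict[str, Callable[[List[float]], float]]


def sum_then_finalize_sketch(
        finalizers: Finalizers,
        client_metrics: Sequence[UnfinalizedMetrics]) -> Dict[str, float]:
    """Sums each unfinalized metric across clients, then applies finalizers."""
    summed: UnfinalizedMetrics = {}
    for metrics in client_metrics:
        for name, values in metrics.items():
            running = summed.get(name, [0.0] * len(values))
            summed[name] = [a + v for a, v in zip(running, values)]
    return {name: finalizers[name](values) for name, values in summed.items()}


# Accuracy reported as [num_correct, num_examples], finalized as a ratio.
finalizers = {"accuracy": lambda v: v[0] / v[1]}
clients = [{"accuracy": [9.0, 10.0]}, {"accuracy": [10.0, 90.0]}]
print(sum_then_finalize_sketch(finalizers, clients))  # {'accuracy': 0.19}
```

Summing first yields 19 correct out of 100 examples. Averaging each client's already-finalized accuracy (0.9 and about 0.11) would instead give roughly 0.51, over-weighting the client with only 10 examples.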

Returns

A tff.learning.templates.LearningProcess.

Raises

ValueError: If proximal_strength is not a nonnegative float.