tff.learning.build_federated_averaging_process

Builds an iterative process that performs federated averaging.

tff.learning.build_federated_averaging_process(
    model_fn, client_optimizer_fn, server_optimizer_fn=DEFAULT_SERVER_OPTIMIZER_FN,
    client_weight_fn=None, stateful_delta_aggregate_fn=None,
    stateful_model_broadcast_fn=None
)

This function creates a tff.utils.IterativeProcess that performs federated averaging on client models. The iterative process has two methods, initialize and next.

Each time the next method is called, the server model is broadcast to each client using a broadcast function (stateful_model_broadcast_fn). Each client then performs one epoch of local training via the tf.keras.optimizers.Optimizer.apply_gradients method of the client optimizer, and computes the difference between its model after training and the initial broadcast model. These model deltas are aggregated at the server using an aggregation function (stateful_delta_aggregate_fn). The aggregate model delta is applied at the server using the tf.keras.optimizers.Optimizer.apply_gradients method of the server optimizer.
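The round described above can be sketched in plain Python, without TFF. This is an illustration under stated assumptions, not the library's implementation: the model is a single scalar weight, both optimizers are plain SGD, the deltas are weighted by example count, and the server step treats the negated mean delta as a gradient. The names client_update and run_round are hypothetical.

```python
# Plain-Python sketch of one federated averaging round (no TFF involved).
# Assumption: the model is one scalar weight w, fit to y = w * x
# with squared-error loss. All names here are hypothetical.

def client_update(w_broadcast, examples, client_lr):
    """Runs one epoch of local SGD; returns (model delta, example count)."""
    w = w_broadcast
    for x, y in examples:
        grad = 2.0 * (w * x - y) * x  # d/dw of (w * x - y) ** 2
        w -= client_lr * grad         # client optimizer step (plain SGD)
    return w - w_broadcast, len(examples)

def run_round(w_server, client_datasets, client_lr=0.1, server_lr=1.0):
    """Broadcast, local training, weighted delta averaging, server step."""
    # Every client starts from the same broadcast server weight.
    results = [client_update(w_server, ds, client_lr) for ds in client_datasets]
    total = sum(n for _, n in results)
    # Mean of the deltas, weighted by number of examples per client.
    mean_delta = sum(delta * n for delta, n in results) / total
    # Assumed server step: SGD on the negated mean delta. With
    # server_lr = 1.0 this reduces to w_server + mean_delta.
    pseudo_grad = -mean_delta
    return w_server - server_lr * pseudo_grad

client_datasets = [
    [(1.0, 2.0), (2.0, 4.0)],  # client 0: two examples
    [(3.0, 6.0)],              # client 1: one example
]

w = 0.0
for _ in range(20):
    w = run_round(w, client_datasets)
```

With a server learning rate of 1.0, the server step simply adds the averaged delta to the server weight; other server optimizers or learning rates change how that averaged delta is applied.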

Args:

  • model_fn: A no-arg function that returns a tff.learning.Model.
  • client_optimizer_fn: A no-arg callable that returns a tf.keras.optimizers.Optimizer.
  • server_optimizer_fn: A no-arg callable that returns a tf.keras.optimizers.Optimizer. By default, this uses tf.keras.optimizers.SGD with a learning rate of 1.0.
  • client_weight_fn: Optional function that takes the output of model.report_local_outputs and returns a tensor providing the weight in the federated average of model deltas. If not provided, the default is the total number of examples processed on device.
  • stateful_delta_aggregate_fn: A tff.utils.StatefulAggregateFn where the next_fn performs a federated aggregation and updates state. It must have TFF type (<state@SERVER, value@CLIENTS, weights@CLIENTS> -> <state@SERVER, aggregate@SERVER>), where the value type is tff.learning.framework.ModelWeights.trainable corresponding to the object returned by model_fn. By default, it performs arithmetic mean aggregation, weighted by client_weight_fn.
  • stateful_model_broadcast_fn: A tff.utils.StatefulBroadcastFn where the next_fn performs a federated broadcast and updates state. It must have TFF type (<state@SERVER, value@SERVER> -> <state@SERVER, value@CLIENTS>), where the value type is tff.learning.framework.ModelWeights corresponding to the object returned by model_fn. The default is the identity broadcast.
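
The shape of the stateful aggregation contract, (state, value, weights) -> (state, aggregate), can be illustrated with ordinary Python values standing in for federated ones. This sketch does not use tff.utils.StatefulAggregateFn; the function name and the round-counter state are hypothetical, chosen only to show a state being threaded through a weighted arithmetic mean.

```python
# Plain-Python analogue of a stateful aggregation next_fn (no TFF involved).
# Assumption: client values are scalars and the state is a round counter.

def weighted_mean_next_fn(state, values, weights):
    """Maps (state, client values, client weights) to (new state, aggregate)."""
    total = sum(weights)
    # Arithmetic mean of client values, weighted per client.
    aggregate = sum(v * w for v, w in zip(values, weights)) / total
    new_state = state + 1  # hypothetical state: number of rounds aggregated
    return new_state, aggregate

state = 0
# Two clients: value 1.0 with weight 3.0, value 4.0 with weight 1.0.
state, aggregate = weighted_mean_next_fn(state, [1.0, 4.0], [3.0, 1.0])
```

In the real API the values and weights are placed at CLIENTS and the state and aggregate at SERVER; the sketch only mirrors the input and output structure.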

Returns:

A tff.utils.IterativeProcess.