tff.learning.build_federated_averaging_process

Builds an iterative process that performs federated averaging.

This function creates a tff.templates.IterativeProcess that performs federated averaging on client models. The iterative process has the following methods inherited from tff.templates.IterativeProcess:

  • initialize: A tff.Computation with the functional type signature ( -> S@SERVER), where S is a tff.learning.framework.ServerState representing the initial state of the server.
  • next: A tff.Computation with the functional type signature (<S@SERVER, {B*}@CLIENTS> -> <S@SERVER, T@SERVER>) where S is a tff.learning.framework.ServerState whose type matches that of the output of initialize, and {B*}@CLIENTS represents the client datasets, where B is the type of a single batch. This computation returns a tff.learning.framework.ServerState representing the updated server state and aggregated metrics at the server, including client training metrics and any other metrics from broadcast and aggregation processes.

The iterative process also has the following method not inherited from tff.templates.IterativeProcess:

  • get_model_weights: A tff.Computation that takes as input the tff.learning.framework.ServerState and returns a tff.learning.ModelWeights containing the state's model weights.
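The sketch below shows one way to drive these methods end to end. It assumes a hypothetical dense classifier and small synthetic client datasets; the model architecture, data, and hyperparameters are illustrative only and not prescribed by this API.

  import collections
  import numpy as np
  import tensorflow as tf
  import tensorflow_federated as tff

  def create_client_dataset():
    # Synthetic data purely for illustration: 16 examples with 784 features.
    x = np.random.rand(16, 784).astype(np.float32)
    y = np.random.randint(0, 10, size=(16, 1)).astype(np.int32)
    return tf.data.Dataset.from_tensor_slices(
        collections.OrderedDict(x=x, y=y)).batch(8)

  federated_train_data = [create_client_dataset() for _ in range(3)]

  def model_fn():
    # A fresh model (with fresh variables) must be built on every invocation.
    keras_model = tf.keras.Sequential([
        tf.keras.layers.InputLayer(input_shape=(784,)),
        tf.keras.layers.Dense(10, kernel_initializer='zeros'),
        tf.keras.layers.Softmax(),
    ])
    return tff.learning.from_keras_model(
        keras_model,
        input_spec=federated_train_data[0].element_spec,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

  iterative_process = tff.learning.build_federated_averaging_process(
      model_fn,
      client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
      server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))

  state = iterative_process.initialize()
  for round_num in range(2):
    state, metrics = iterative_process.next(state, federated_train_data)
    print('round {}, metrics={}'.format(round_num, metrics))

  # The trained model weights can be read back out of the server state.
  model_weights = iterative_process.get_model_weights(state)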

Each time the next method is called, the server model is broadcast to each client using a broadcast function. For each client, local training is performed with one pass over the pre-processed client dataset (multiple epochs are possible if the dataset is pre-processed with a repeat operation), using the tf.keras.optimizers.Optimizer.apply_gradients method of the client optimizer. Each client computes the difference between the client model after training and the initial broadcast model. These model deltas are then aggregated at the server using an aggregation function. The aggregate model delta is applied at the server via the tf.keras.optimizers.Optimizer.apply_gradients method of the server optimizer.
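For intuition, here is a small standalone sketch (not the library's internal code) of how an aggregated model delta can be applied with a Keras optimizer as described above; the weight and delta values are made up.

  import tensorflow as tf

  # Hypothetical server weights and an aggregated model delta (made-up values).
  server_weights = [tf.Variable([1.0, 2.0]), tf.Variable([0.5])]
  aggregated_delta = [tf.constant([0.1, -0.2]), tf.constant([0.05])]

  # apply_gradients performs `var -= learning_rate * grad`, so passing the
  # negative delta with a learning rate of 1.0 moves the server weights by
  # exactly the aggregated delta (plain federated averaging).
  server_optimizer = tf.keras.optimizers.SGD(learning_rate=1.0)
  server_optimizer.apply_gradients(
      [(-delta, var) for delta, var in zip(aggregated_delta, server_weights)])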

Args

  • model_fn: A no-arg function that returns a tff.learning.Model. This method must not capture TensorFlow tensors or variables and use them. The model must be constructed entirely from scratch on each invocation; returning the same pre-constructed model each call will result in an error.
  • client_optimizer_fn: A tff.learning.optimizers.Optimizer, or a no-arg callable that returns a tf.keras.Optimizer.
  • server_optimizer_fn: A tff.learning.optimizers.Optimizer, or a no-arg callable that returns a tf.keras.Optimizer. By default, this uses tf.keras.optimizers.SGD with a learning rate of 1.0.
  • client_weighting: A value of tff.learning.ClientWeighting that specifies a built-in weighting method, or a callable that takes the output of model.report_local_unfinalized_metrics and returns a tensor providing the weight in the federated average of model deltas. If None, defaults to weighting by number of examples (a uniform-weighting variant is sketched after this list).
  • broadcast_process: A tff.templates.MeasuredProcess that broadcasts the model weights on the server to the clients. It must support the signature (input_values@SERVER -> output_values@CLIENT). If None (the default), the server model is broadcast to the clients using the default tff.federated_broadcast.
  • model_update_aggregation_factory: An optional tff.aggregators.WeightedAggregationFactory or tff.aggregators.UnweightedAggregationFactory that constructs a tff.templates.AggregationProcess for aggregating the client model updates on the server. If None, uses tff.aggregators.MeanFactory.
  • metrics_aggregator: An optional function that takes in the metric finalizers (i.e., tff.learning.Model.metric_finalizers()) and a tff.types.StructWithPythonType of the unfinalized metrics (i.e., the TFF type of tff.learning.Model.report_local_unfinalized_metrics()), and returns a federated TFF computation with the type signature local_unfinalized_metrics@CLIENTS -> aggregated_metrics@SERVER. If None, uses tff.learning.metrics.sum_then_finalize, which returns a federated TFF computation that sums the unfinalized metrics from CLIENTS and then applies the corresponding metric finalizers at SERVER.
  • use_experimental_simulation_loop: Controls the reduce loop function for the input dataset. An experimental reduce loop is used for simulation. It is currently necessary to set this flag to True for performant GPU simulations.
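As a rough illustration of how some of these optional arguments might be combined (reusing the model_fn from the earlier sketch; the optimizer choices and hyperparameters are arbitrary):

  import tensorflow as tf
  import tensorflow_federated as tff

  # Assumes `model_fn` is defined as in the earlier sketch.
  uniform_process = tff.learning.build_federated_averaging_process(
      model_fn,
      client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
      server_optimizer_fn=lambda: tf.keras.optimizers.Adam(learning_rate=0.1),
      client_weighting=tff.learning.ClientWeighting.UNIFORM,
      model_update_aggregation_factory=tff.aggregators.MeanFactory(),
      use_experimental_simulation_loop=True)  # recommended for GPU simulations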

Returns

A tff.templates.IterativeProcess.