
tff.learning.build_federated_averaging_process


Builds the TFF computations for optimization using federated averaging.

tff.learning.build_federated_averaging_process(
    model_fn,
    server_optimizer_fn=(lambda: tf.keras.optimizers.SGD(learning_rate=1.0)),
    client_weight_fn=None,
    stateful_delta_aggregate_fn=None,
    stateful_model_broadcast_fn=None
)


Args:

  • model_fn: A no-arg function that returns a tff.learning.TrainableModel.
  • server_optimizer_fn: A no-arg function that returns a tf.Optimizer. Its apply_gradients method applies the aggregated client update to the server model, with the average client delta passed in negated, as a pseudo-gradient. The default creates a tf.keras.optimizers.SGD with a learning rate of 1.0; because an SGD step subtracts learning_rate times the gradient, this simply adds the average client delta to the server's model.
  • client_weight_fn: Optional function that takes the output of model.report_local_outputs and returns a tensor giving that client's weight in the federated average of model deltas. If not provided, the weight defaults to the total number of examples processed on the device.
  • stateful_delta_aggregate_fn: A tff.utils.StatefulAggregateFn whose next_fn performs a federated aggregation and updates state. That is, it has TFF type (state@SERVER, value@CLIENTS, weights@CLIENTS) -> (state@SERVER, aggregate@SERVER), where the value type is tff.learning.framework.ModelWeights.trainable corresponding to the object returned by model_fn. By default, it computes an arithmetic mean weighted by client_weight_fn.
  • stateful_model_broadcast_fn: A tff.utils.StatefulBroadcastFn whose next_fn performs a federated broadcast and updates state. That is, it has TFF type (state@SERVER, value@SERVER) -> (state@SERVER, value@CLIENTS), where the value type is tff.learning.framework.ModelWeights corresponding to the object returned by model_fn. By default, it performs an identity broadcast, sending the server value to clients unchanged.
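The aggregation and broadcast contracts above, together with the default server optimizer, can be sketched in plain Python. This is a minimal, hypothetical mimic rather than TFF code: lists stand in for CLIENTS-placed values, integers stand in for the aggregator and broadcaster state, and identity_broadcast_next, weighted_mean_aggregate_next and sgd_server_update are illustrative names. It assumes, as the server_optimizer_fn description implies, that the average delta reaches the optimizer negated, so an SGD step with learning rate 1.0 adds it to the server weights.

```python
from typing import List, Sequence, Tuple

# Hypothetical plain-Python stand-ins for the federated computations described
# above. A list with one entry per client plays the role of a value placed at
# CLIENTS; a single list of floats is a value placed at SERVER.
Weights = List[float]


def identity_broadcast_next(
    state: int, value: Weights, num_clients: int
) -> Tuple[int, List[Weights]]:
    """(state@SERVER, value@SERVER) -> (state@SERVER, value@CLIENTS).

    Sends an unchanged copy of the server weights to every client. The state
    simply counts rounds; num_clients is needed only because plain Python has
    no notion of placements.
    """
    return state + 1, [list(value) for _ in range(num_clients)]


def weighted_mean_aggregate_next(
    state: int, values: Sequence[Weights], weights: Sequence[float]
) -> Tuple[int, Weights]:
    """(state@SERVER, value@CLIENTS, weights@CLIENTS) -> (state@SERVER, aggregate@SERVER).

    Computes the per-coordinate weighted arithmetic mean of the client deltas.
    Assumes the weights sum to a nonzero value.
    """
    total = sum(weights)
    aggregate = [
        sum(w * v[i] for v, w in zip(values, weights)) / total
        for i in range(len(values[0]))
    ]
    return state + 1, aggregate


def sgd_server_update(
    server: Weights, avg_delta: Weights, learning_rate: float = 1.0
) -> Weights:
    """One SGD step that treats the negated average delta as the gradient."""
    gradient = [-d for d in avg_delta]
    return [x - learning_rate * g for x, g in zip(server, gradient)]


# Usage: one round with two hypothetical clients.
server_weights = [0.0, 0.0]
broadcast_state, aggregate_state = 0, 0
broadcast_state, client_weights = identity_broadcast_next(
    broadcast_state, server_weights, num_clients=2
)
# Each client's delta (trained weights minus received weights) and the number
# of examples it processed, which mirrors the default client_weight_fn.
deltas = [[1.0, 2.0], [4.0, -1.0]]
num_examples = [1.0, 3.0]
aggregate_state, avg_delta = weighted_mean_aggregate_next(
    aggregate_state, deltas, num_examples
)
server_weights = sgd_server_update(server_weights, avg_delta)
print(server_weights)  # [3.25, -0.25]
```

With weights 1 and 3, the second client's delta dominates the average. Because the optimizer only ever sees the negated delta, swapping SGD for another optimizer in server_optimizer_fn changes how that delta is applied, not how it is computed.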

Returns:

A tff.utils.IterativeProcess.
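The returned process exposes an initialize computation that builds the server state and a next computation that runs one round; typical use is state = process.initialize() followed by repeated calls to state, metrics = process.next(state, federated_train_data). The sketch below mimics that shape in plain Python on a hypothetical one-parameter model. ToyFedAvgProcess, its learning rates and epoch count, and the loss metric are illustrative assumptions, not TFF behavior.

```python
from typing import Dict, List, NamedTuple, Sequence, Tuple

# A client's local dataset: (x, y) pairs for a one-parameter model y = w * x.
ClientData = List[Tuple[float, float]]


class ServerState(NamedTuple):
    w: float        # current server model weight
    round_num: int  # number of completed rounds


class ToyFedAvgProcess:
    """Hypothetical mimic of an IterativeProcess: initialize() and next()."""

    def __init__(self, client_lr: float = 0.1, local_epochs: int = 5,
                 server_lr: float = 1.0):
        self.client_lr = client_lr
        self.local_epochs = local_epochs
        self.server_lr = server_lr

    def initialize(self) -> ServerState:
        return ServerState(w=0.0, round_num=0)

    def _client_update(self, w: float, data: ClientData) -> float:
        # Per-example SGD on 0.5 * (w * x - y) ** 2, starting from the
        # broadcast weight; returns the model delta (trained minus initial).
        local_w = w
        for _ in range(self.local_epochs):
            for x, y in data:
                local_w -= self.client_lr * (local_w * x - y) * x
        return local_w - w

    def next(self, state: ServerState,
             federated_data: Sequence[ClientData]) -> Tuple[ServerState, Dict[str, float]]:
        deltas = [self._client_update(state.w, data) for data in federated_data]
        # Weight each delta by the client's example count (the documented default).
        weights = [float(len(data)) for data in federated_data]
        avg_delta = sum(d * n for d, n in zip(deltas, weights)) / sum(weights)
        # Server SGD step on the negated average delta: w - lr * (-avg_delta).
        new_w = state.w + self.server_lr * avg_delta
        # Simplified metric: mean squared error of the new model over all data.
        loss = sum((new_w * x - y) ** 2
                   for data in federated_data for x, y in data) / sum(weights)
        return ServerState(new_w, state.round_num + 1), {"loss": loss}


# Usage: two hypothetical clients whose data both follow y = 2x.
federated_data = [[(1.0, 2.0), (2.0, 4.0)], [(0.5, 1.0)]]
process = ToyFedAvgProcess()
state = process.initialize()
for _ in range(10):
    state, metrics = process.next(state, federated_data)
print(state.round_num, round(state.w, 3))
```

All server-side quantities live in one state value threaded through next, so the caller owns the training loop. A real IterativeProcess runs the client and server steps as TFF computations rather than Python loops, and reports training metrics differently from the simplified loss used here.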