
tff.learning.build_federated_averaging_process

tff.learning.build_federated_averaging_process(
    model_fn,
    server_optimizer_fn=(lambda : gradient_descent.SGD(learning_rate=1.0)),
    client_weight_fn=None
)

Defined in learning/federated_averaging.py.

Builds the TFF computations for optimization using federated averaging.

Args:

  • model_fn: A no-arg function that returns a tff.learning.TrainableModel.
  • server_optimizer_fn: A no-arg function that returns a tf.Optimizer. The apply_gradients method of this optimizer is used to apply client updates to the server model. The default creates a tf.keras.optimizers.SGD with a learning rate of 1.0, which simply adds the average client delta to the server's model.
  • client_weight_fn: Optional function that takes the output of model.report_local_outputs and returns a tensor giving that client's weight in the federated average of model deltas. If not provided, the default weight is the total number of examples processed on the device.
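
The arithmetic behind server_optimizer_fn and client_weight_fn can be sketched in plain Python. This is not TFF code: weighted_average_delta and server_sgd_step are hypothetical helper names, and the sketch assumes the server treats the negated average delta as the gradient, which is one reading consistent with an SGD step at learning rate 1.0 adding the delta to the model.

```python
def weighted_average_delta(client_deltas, client_weights):
    """Weighted mean of per-client model deltas (each a list of floats)."""
    total = float(sum(client_weights))
    if total <= 0:
        raise ValueError("client weights must sum to a positive value")
    dim = len(client_deltas[0])
    return [
        sum(w * delta[i] for delta, w in zip(client_deltas, client_weights)) / total
        for i in range(dim)
    ]


def server_sgd_step(server_weights, average_delta, learning_rate=1.0):
    """Plain SGD step, assuming the negated average delta acts as the gradient."""
    gradient = [-d for d in average_delta]
    return [w - learning_rate * g for w, g in zip(server_weights, gradient)]


# Two hypothetical clients; each weight is the number of examples it processed,
# mirroring the default described for client_weight_fn.
deltas = [[1.0, -2.0], [3.0, 2.0]]
num_examples = [1, 3]

avg = weighted_average_delta(deltas, num_examples)
new_weights = server_sgd_step([0.5, 0.5], avg)
```

With learning_rate=1.0 the step reduces to server_weights + average_delta. A smaller rate would apply only a fraction of the averaged update.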

Returns:

A tff.utils.IterativeProcess.
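
The returned process is typically driven by calling initialize once and then next once per round. The loop below illustrates that shape with a toy stand-in: ToyIterativeProcess is a hypothetical class written for this sketch, not part of TFF, and its scalar "model" and metrics dictionary are assumptions made only to keep the example self-contained.

```python
class ToyIterativeProcess:
    """Hypothetical stand-in with an initialize/next shape; not TFF code."""

    def initialize(self):
        # Initial server state: here a single scalar model weight.
        return 0.0

    def next(self, state, client_data):
        # client_data holds one list of numbers per client.
        # Each client's delta moves the weight to the mean of its local data.
        deltas = [sum(xs) / len(xs) - state for xs in client_data]
        weights = [len(xs) for xs in client_data]
        avg_delta = sum(d * w for d, w in zip(deltas, weights)) / sum(weights)
        # Equivalent to an SGD server step with learning rate 1.0.
        new_state = state + avg_delta
        metrics = {"num_examples": sum(weights)}
        return new_state, metrics


process = ToyIterativeProcess()
state = process.initialize()
history = []
for round_num in range(3):
    # The same two clients take part in every round of this toy run.
    state, metrics = process.next(state, [[1.0, 3.0], [4.0, 4.0]])
    history.append(state)
```

With the real library, the object returned by tff.learning.build_federated_averaging_process(model_fn) would take the place of ToyIterativeProcess, and the second argument to next would be the federated training data for that round.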