Constructs a tff.templates.IterativeProcess for Federated Averaging or Federated SGD.

This provides the TFF orchestration logic that connects the common server logic, which applies aggregated model deltas to the server model, with a ClientDeltaFn that specifies how weight_deltas are computed on device.
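The client half of this split can be pictured with a plain-Python sketch. It assumes a hypothetical two-parameter linear model and a hypothetical `make_client_delta_fn` helper; it is not TFF's ClientDeltaFn API. It only illustrates the idea that each client trains locally starting from the broadcast weights and reports the difference (plus a weight such as its example count) instead of the trained weights themselves.

```python
from typing import Callable, List, Sequence, Tuple

Weights = List[float]
Example = Tuple[float, float]  # (x, y) pair for a 1-D linear model
ClientDeltaFn = Callable[[Weights, Sequence[Example]], Tuple[Weights, float]]


def make_client_delta_fn(learning_rate: float, local_epochs: int) -> ClientDeltaFn:
    """Hypothetical stand-in for a ClientDeltaFn: local SGD, then report a delta."""

    def client_delta_fn(initial_weights: Weights,
                        examples: Sequence[Example]) -> Tuple[Weights, float]:
        w = list(initial_weights)  # copy so the broadcast weights stay unchanged
        for _ in range(local_epochs):
            for x, y in examples:
                # Gradient of 0.5 * (w[0] * x + w[1] - y) ** 2 w.r.t. (w[0], w[1]).
                err = w[0] * x + w[1] - y
                w[0] -= learning_rate * err * x
                w[1] -= learning_rate * err
        # weight_delta = locally trained weights minus the weights received.
        weights_delta = [after - before for after, before in zip(w, initial_weights)]
        # Weight the update by the number of local examples (a common FedAvg choice).
        return weights_delta, float(len(examples))

    return client_delta_fn


delta_fn = make_client_delta_fn(learning_rate=0.1, local_epochs=1)
delta, num_examples = delta_fn([0.0, 0.0], [(1.0, 2.0)])
```

Returning a delta rather than full weights is what lets the server treat client results uniformly as an update to aggregate and apply.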

Args:
  model_fn: A no-arg function that returns a tff.learning.Model.
  model_to_client_delta_fn: A function from a model_fn to a ClientDeltaFn.
  server_optimizer_fn: A no-arg function that returns a tf.Optimizer. The optimizer's apply_gradients method is used to apply client updates to the server model.
  broadcast_process: A tff.templates.MeasuredProcess that broadcasts the model weights on the server to the clients. It must support the signature (input_values@SERVER -> output_values@CLIENTS).
  aggregation_process: A tff.templates.MeasuredProcess that aggregates the model updates on the clients back to the server. It must support the signature ({input_values}@CLIENTS -> output_values@SERVER). Must be None if model_update_aggregation_factory is not None.
  model_update_aggregation_factory: An optional tff.aggregators.WeightedAggregationFactory that constructs a tff.templates.AggregationProcess for aggregating the client model updates on the server. If None, a default-constructed tff.aggregators.MeanFactory is used, creating a stateless mean aggregation. Must be None if aggregation_process is not None.
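On the server side, the default path (a weighted mean of client deltas, then a server optimizer step) can be sketched in plain Python. Everything here is a hypothetical simplification: `weighted_mean_delta` mimics what a weighted mean aggregation computes, and `ServerSGD` is a toy optimizer, not a `tf` or TFF class. The sketch assumes the negated mean delta is handed to the optimizer as a gradient, so plain SGD with learning rate 1.0 simply adds the mean delta to the server weights.

```python
from typing import List, Sequence

Weights = List[float]


def weighted_mean_delta(deltas: Sequence[Weights], client_weights: Sequence[float]) -> Weights:
    """Hypothetical weighted mean of client deltas, one entry per model parameter."""
    total = sum(client_weights)
    if total <= 0:
        raise ValueError("total client weight must be positive")
    num_params = len(deltas[0])
    return [
        sum(w * delta[i] for delta, w in zip(deltas, client_weights)) / total
        for i in range(num_params)
    ]


class ServerSGD:
    """Toy server optimizer: one SGD step on (gradient, variable) pairs."""

    def __init__(self, learning_rate: float) -> None:
        self.learning_rate = learning_rate

    def apply_update(self, gradients: Weights, variables: Weights) -> Weights:
        return [v - self.learning_rate * g for g, v in zip(gradients, variables)]


def server_update(server_weights: Weights, mean_delta: Weights,
                  optimizer: ServerSGD) -> Weights:
    # The delta points toward better weights, so its negation acts as a gradient.
    pseudo_gradient = [-d for d in mean_delta]
    return optimizer.apply_update(pseudo_gradient, server_weights)


mean = weighted_mean_delta([[1.0, 0.0], [3.0, 2.0]], [1.0, 3.0])
new_weights = server_update([0.0, 0.0], mean, ServerSGD(learning_rate=1.0))
```

Keeping the optimizer separate from the aggregation is what allows a different server_optimizer_fn (for example, one with momentum) to change the server step without touching the client code.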

Returns:
  A tff.templates.IterativeProcess.
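A tff.templates.IterativeProcess pairs an `initialize` computation, which builds the initial server state, with a `next` computation that runs one round. The plain-Python analogue below is only a sketch of that contract: `SimpleIterativeProcess`, `ServerState`, and `build_mean_process` are hypothetical names, the model is a single scalar, and returning a `(state, metrics)` pair from `next` is an assumption made for illustration.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Sequence, Tuple


@dataclass(frozen=True)
class ServerState:
    weight: float   # the (scalar) server model
    round_num: int  # number of completed rounds


@dataclass(frozen=True)
class SimpleIterativeProcess:
    initialize: Callable[[], ServerState]
    next: Callable[[ServerState, Sequence[List[float]]], Tuple[ServerState, Dict[str, float]]]


def build_mean_process(server_learning_rate: float) -> SimpleIterativeProcess:
    """Hypothetical process whose scalar model converges to the global data mean."""

    def initialize() -> ServerState:
        return ServerState(weight=0.0, round_num=0)

    def next_fn(state: ServerState,
                client_datasets: Sequence[List[float]]) -> Tuple[ServerState, Dict[str, float]]:
        # Client step: each client reports (local mean - broadcast weight) and its size.
        deltas = [sum(data) / len(data) - state.weight for data in client_datasets]
        sizes = [float(len(data)) for data in client_datasets]
        # Aggregation: example-weighted mean of the client deltas.
        mean_delta = sum(d * n for d, n in zip(deltas, sizes)) / sum(sizes)
        # Server step: SGD on the negated delta, i.e. move toward the clients.
        new_weight = state.weight + server_learning_rate * mean_delta
        metrics = {"num_examples": sum(sizes)}
        return ServerState(new_weight, state.round_num + 1), metrics

    return SimpleIterativeProcess(initialize=initialize, next=next_fn)


process = build_mean_process(server_learning_rate=0.5)
state = process.initialize()
clients = [[1.0, 2.0, 3.0], [10.0]]  # example-weighted global mean is 4.0
for _ in range(30):
    state, metrics = process.next(state, clients)
```

With a server learning rate of 0.5 the distance to the global mean halves every round, so the loop above ends very close to 4.0.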

Raises:
  ProcessTypeError: if broadcast_process or aggregation_process does not conform to the broadcast (SERVER -> CLIENTS) or aggregation (CLIENTS -> SERVER) signature, respectively.
  DisjointArgumentError: if both aggregation_process and model_update_aggregation_factory are not None.
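Both failure modes amount to argument checks that happen before any round runs. The sketch below illustrates them in plain Python under stated assumptions: the exception classes are local stand-ins that only share the names used above, and each process is represented solely by a hypothetical (input placement, output placement) tuple rather than a real TFF type signature.

```python
from typing import Optional, Tuple

Signature = Tuple[str, str]  # hypothetical: (input placement, output placement)


class ProcessTypeError(TypeError):
    """Local stand-in for the error raised on a placement mismatch."""


class DisjointArgumentError(ValueError):
    """Local stand-in for the error raised on conflicting aggregation arguments."""


def validate_arguments(broadcast_process: Optional[Signature],
                       aggregation_process: Optional[Signature],
                       model_update_aggregation_factory: Optional[object]) -> None:
    # Broadcast must move values from the server to the clients.
    if broadcast_process is not None and broadcast_process != ("SERVER", "CLIENTS"):
        raise ProcessTypeError(f"broadcast must be SERVER -> CLIENTS, got {broadcast_process}")
    # Aggregation must move values from the clients back to the server.
    if aggregation_process is not None and aggregation_process != ("CLIENTS", "SERVER"):
        raise ProcessTypeError(f"aggregation must be CLIENTS -> SERVER, got {aggregation_process}")
    # An explicit aggregation process and an aggregation factory are mutually exclusive.
    if aggregation_process is not None and model_update_aggregation_factory is not None:
        raise DisjointArgumentError(
            "pass aggregation_process or model_update_aggregation_factory, not both")


# Valid: explicit processes with the expected placements and no factory.
validate_arguments(("SERVER", "CLIENTS"), ("CLIENTS", "SERVER"), None)
```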