tff.learning.framework.build_model_delta_optimizer_process

Constructs a tff.templates.IterativeProcess for Federated Averaging or Federated SGD.

This function provides the TFF orchestration logic that connects two pieces: the common server logic, which applies aggregated model deltas to the server model, and a ClientDeltaFn, which specifies how weight_deltas are computed on device.
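The round structure described above can be sketched in plain Python. This is not the TFF API: weights are plain lists of floats, placements are simulated with ordinary lists, `run_round` and `toy_client_delta_fn` are hypothetical names, and aggregation is an unweighted mean for simplicity. The server step treats the negated mean delta as a gradient for plain SGD, so with a server learning rate of 1.0 it reduces to adding the mean delta, i.e. a Federated Averaging update.

```python
from typing import Callable, List, Sequence

Weights = List[float]
# Hypothetical stand-in for a ClientDeltaFn: (initial weights, local data) -> weight delta.
ClientDeltaFn = Callable[[Weights, Sequence[float]], Weights]


def toy_client_delta_fn(weights: Weights, data: Sequence[float], lr: float = 0.5) -> Weights:
    """One local gradient step on the mean squared error of a scalar model."""
    w = weights[0]
    grad = sum(2.0 * (w - x) for x in data) / len(data)
    new_w = w - lr * grad
    # The client reports the change in weights, not the new weights themselves.
    return [new_w - w]


def run_round(server_weights: Weights,
              client_datasets: Sequence[Sequence[float]],
              client_delta_fn: ClientDeltaFn,
              server_lr: float = 1.0) -> Weights:
    # 1. Broadcast (SERVER -> CLIENTS): each client gets a copy of the server weights.
    client_weights = [list(server_weights) for _ in client_datasets]
    # 2. On device: each client computes a weight delta on its local data.
    deltas = [client_delta_fn(w, data) for w, data in zip(client_weights, client_datasets)]
    # 3. Aggregate (CLIENTS -> SERVER): coordinate-wise unweighted mean of the deltas.
    mean_delta = [sum(d[i] for d in deltas) / len(deltas) for i in range(len(server_weights))]
    # 4. Server update: SGD with the negated delta as the gradient, w - lr * (-delta).
    return [w - server_lr * (-d) for w, d in zip(server_weights, mean_delta)]


print(run_round([0.0], [[1.0, 3.0], [6.0]], toy_client_delta_fn))  # prints [4.0]
```

With one local step at `lr=0.5`, each toy client moves exactly to the mean of its own data (2.0 and 6.0), so the server lands on their unweighted average, 4.0.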

Args:

model_fn: A no-arg function that returns a tff.learning.Model.
model_to_client_delta_fn: A function from a model_fn to a ClientDeltaFn.
server_optimizer_fn: A no-arg function that returns a tf.Optimizer. The apply_gradients method of this optimizer is used to apply client updates to the server model.
broadcast_process: A tff.templates.MeasuredProcess that broadcasts the model weights on the server to the clients. It must support the signature (input_values@SERVER -> output_values@CLIENTS).
aggregation_process: A tff.templates.MeasuredProcess that aggregates the model updates on the clients back to the server. It must support the signature ({input_values}@CLIENTS -> output_values@SERVER). Must be None if model_update_aggregation_factory is not None.
model_update_aggregation_factory: An optional tff.aggregators.AggregationProcessFactory that constructs a tff.templates.AggregationProcess for aggregating the client model updates on the server. If None, a default-constructed tff.aggregators.MeanFactory is used, which creates a stateless mean aggregation. Must be None if aggregation_process is not None.
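The two server-side pieces configured by these arguments, the update aggregation and the server optimizer, can be illustrated without TFF. This sketch assumes the mean is weighted by a per-client weight (for example, the number of local examples) and that the server optimizer is SGD with optional heavy-ball momentum; `weighted_mean`, `SketchSGD` and `server_update` are hypothetical names. Unlike a Keras optimizer, whose `apply_gradients` updates variables in place, `SketchSGD.apply_gradients` returns new values.

```python
from typing import List, Sequence, Tuple


def weighted_mean(updates: Sequence[Sequence[float]], weights: Sequence[float]) -> List[float]:
    """Stateless weighted mean of client updates: sum(w_k * u_k) / sum(w_k)."""
    total = sum(weights)
    if total <= 0:
        raise ValueError("total client weight must be positive")
    dim = len(updates[0])
    return [sum(w * u[i] for u, w in zip(updates, weights)) / total for i in range(dim)]


class SketchSGD:
    """Hypothetical server optimizer: SGD with optional heavy-ball momentum."""

    def __init__(self, learning_rate: float = 1.0, momentum: float = 0.0):
        self.learning_rate = learning_rate
        self.momentum = momentum
        self.velocity: List[float] = []  # Optimizer state, kept across rounds.

    def apply_gradients(self, grads_and_vars: Sequence[Tuple[float, float]]) -> List[float]:
        grads = [g for g, _ in grads_and_vars]
        params = [v for _, v in grads_and_vars]
        if not self.velocity:
            self.velocity = [0.0] * len(params)
        # v <- momentum * v + g ; w <- w - lr * v
        self.velocity = [self.momentum * v + g for v, g in zip(self.velocity, grads)]
        return [w - self.learning_rate * v for w, v in zip(params, self.velocity)]


def server_update(optimizer: SketchSGD, server_weights: List[float],
                  client_deltas: Sequence[Sequence[float]],
                  client_weights: Sequence[float]) -> List[float]:
    mean_delta = weighted_mean(client_deltas, client_weights)
    # Deltas point toward lower client loss, so their negation acts as a gradient.
    pseudo_grads = [-d for d in mean_delta]
    return optimizer.apply_gradients(list(zip(pseudo_grads, server_weights)))
```

With `momentum=0.0` and `learning_rate=1.0` the server simply adds the weighted mean delta. A nonzero momentum makes the server optimizer stateful across rounds, even though the mean aggregation itself keeps no state.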

Returns:

A tff.templates.IterativeProcess.
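A tff.templates.IterativeProcess exposes an `initialize` computation that builds the initial server state and a `next` computation that runs one round. The pure-Python class below only mimics that two-method contract to show the typical training loop; `SketchIterativeProcess`, the `ServerState` dataclass, the scalar model and the metrics dictionary returned by `next` are all assumptions for illustration, not the structures TFF produces.

```python
from dataclasses import dataclass
from typing import Dict, List, Sequence, Tuple


@dataclass
class ServerState:
    weights: List[float]
    round_num: int = 0


class SketchIterativeProcess:
    """Hypothetical stand-in mimicking the initialize/next contract."""

    def __init__(self, initial_weights: Sequence[float], client_lr: float = 0.5):
        self._initial_weights = list(initial_weights)
        self._client_lr = client_lr

    def initialize(self) -> ServerState:
        return ServerState(weights=list(self._initial_weights))

    def next(self, state: ServerState,
             client_datasets: Sequence[Sequence[float]]) -> Tuple[ServerState, Dict[str, float]]:
        w = state.weights[0]
        deltas = []
        for data in client_datasets:
            # One local SGD step on mean squared error, reported as a delta.
            grad = sum(2.0 * (w - x) for x in data) / len(data)
            deltas.append(-self._client_lr * grad)
        mean_delta = sum(deltas) / len(deltas)
        new_state = ServerState(weights=[w + mean_delta], round_num=state.round_num + 1)
        return new_state, {"mean_delta": mean_delta}


process = SketchIterativeProcess(initial_weights=[0.0], client_lr=0.25)
state = process.initialize()
for _ in range(3):
    state, metrics = process.next(state, [[1.0, 3.0], [6.0]])
    print(state.round_num, state.weights, metrics)
```

The state is threaded explicitly from one `next` call to the following one; in this toy setup the server weight moves 0.0, 2.0, 3.0, 3.5 toward the fixed point 4.0.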

Raises:

ProcessTypeError: If broadcast_process or aggregation_process does not conform to the broadcast signature (SERVER -> CLIENTS) or the aggregation signature (CLIENTS -> SERVER), respectively.
DisjointArgumentError: If aggregation_process and model_update_aggregation_factory are both provided (neither is None).
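The two error conditions can be sketched as plain argument checks. Placements are modeled as strings, the exception classes are local stand-ins that only share their names with the documented errors, and `check_placements`, `choose_aggregator` and the `"default-mean"` placeholder are hypothetical; TFF itself checks the processes' type signatures rather than string pairs.

```python
from typing import Any, Optional, Tuple


class ProcessTypeError(Exception):
    """Local stand-in for the documented error of the same name."""


class DisjointArgumentError(Exception):
    """Local stand-in for the documented error of the same name."""


def check_placements(name: str, placements: Tuple[str, str],
                     expected: Tuple[str, str]) -> None:
    # placements is (input placement, output placement), e.g. ("SERVER", "CLIENTS").
    if placements != expected:
        raise ProcessTypeError(
            f"{name} must map {expected[0]} -> {expected[1]}, "
            f"got {placements[0]} -> {placements[1]}")


def choose_aggregator(aggregation_process: Optional[Any] = None,
                      model_update_aggregation_factory: Optional[Any] = None,
                      default_factory: Any = "default-mean") -> Tuple[str, Any]:
    # The two aggregation arguments are mutually exclusive.
    if aggregation_process is not None and model_update_aggregation_factory is not None:
        raise DisjointArgumentError(
            "Specify at most one of aggregation_process and "
            "model_update_aggregation_factory.")
    if aggregation_process is not None:
        return ("process", aggregation_process)
    if model_update_aggregation_factory is not None:
        return ("factory", model_update_aggregation_factory)
    # Neither given: fall back to a default mean factory (a placeholder value here).
    return ("factory", default_factory)
```

For example, `check_placements("broadcast_process", ("SERVER", "CLIENTS"), ("SERVER", "CLIENTS"))` passes silently, while swapping the placements raises `ProcessTypeError`.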