tff.learning.framework.build_model_delta_optimizer_process

Constructs a tff.templates.IterativeProcess for Federated Averaging or Federated SGD.

This provides the TFF orchestration logic that connects two pieces: the common server logic, which applies aggregated model deltas to the server model, and a ClientDeltaFn, which specifies how weight_deltas are computed on device.
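The round structure described above (broadcast the server weights, compute a delta on each client, aggregate the deltas, apply the result with a server optimizer) can be sketched in plain Python. This is a minimal, framework-free illustration under stated assumptions, not TFF's implementation: the model is a hypothetical single-weight regressor y = w * x, `client_delta` is a hypothetical stand-in for a ClientDeltaFn that runs a few local SGD steps and returns its weight change together with its example count as the aggregation weight, and `server_apply` assumes an SGD-style server optimizer that treats the negated aggregate delta as a gradient.

```python
from typing import List, Tuple

# Hypothetical toy model: one scalar weight w for y = w * x.
Dataset = List[Tuple[float, float]]


def client_delta(w: float, data: Dataset, lr: float = 0.1,
                 local_steps: int = 5) -> Tuple[float, float]:
    """Hypothetical ClientDeltaFn stand-in: train locally, return (delta, weight)."""
    local_w = w
    for _ in range(local_steps):
        # Mean gradient of 0.5 * (w * x - y) ** 2 over this client's examples.
        grad = sum((local_w * x - y) * x for x, y in data) / len(data)
        local_w -= lr * grad
    # The delta is the change in weights; the example count is the weight.
    return local_w - w, float(len(data))


def weighted_mean(values: List[float], weights: List[float]) -> float:
    """Aggregation stand-in: example-weighted average of client deltas."""
    return sum(v * wt for v, wt in zip(values, weights)) / sum(weights)


def server_apply(w: float, delta: float, server_lr: float = 1.0) -> float:
    """Server optimizer stand-in: plain SGD on the pseudo-gradient -delta."""
    pseudo_gradient = -delta
    return w - server_lr * pseudo_gradient


def run_round(w: float, clients: List[Dataset]) -> float:
    # 1. Broadcast: every client starts from the current server weight.
    results = [client_delta(w, data) for data in clients]
    # 2. Aggregate the weighted client deltas back to the server.
    deltas, weights = zip(*results)
    avg_delta = weighted_mean(list(deltas), list(weights))
    # 3. Apply the aggregated delta to the server model.
    return server_apply(w, avg_delta)


if __name__ == "__main__":
    # Every example on both clients is consistent with w = 2.0.
    clients = [[(1.0, 2.0), (2.0, 4.0)], [(3.0, 6.0)]]
    w = 0.0
    for _ in range(20):
        w = run_round(w, clients)
    print(round(w, 3))
```

In this toy setup, calling `client_delta` with `local_steps=1` (and keeping `server_lr=1.0`) makes each round a single example-weighted gradient step of size `lr`, which corresponds to Federated SGD; more local steps give Federated Averaging behavior.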

Args:
model_fn: A no-arg function that returns a tff.learning.Model.
model_to_client_delta_fn: A function that takes a model_fn and returns a ClientDeltaFn.
server_optimizer_fn: A no-arg function that returns a tf.keras.optimizers.Optimizer. The apply_gradients method of this optimizer is used to apply client updates to the server model.
stateful_delta_aggregate_fn: A tff.utils.StatefulAggregateFn whose next_fn performs a federated aggregation and updates state. That is, it has TFF type (state@SERVER, value@CLIENTS, weights@CLIENTS) -> (state@SERVER, aggregate@SERVER), where the value type is tff.learning.framework.ModelWeights.trainable for the model returned by model_fn.
stateful_model_broadcast_fn: A tff.utils.StatefulBroadcastFn whose next_fn performs a federated broadcast and updates state. That is, it has TFF type (state@SERVER, value@SERVER) -> (state@SERVER, value@CLIENTS), where the value type is tff.learning.framework.ModelWeights for the model returned by model_fn.
broadcast_process: A tff.templates.MeasuredProcess that broadcasts the model weights on the server to the clients. It must support the signature (input_values@SERVER -> output_values@CLIENTS).
aggregation_process: A tff.templates.MeasuredProcess that aggregates the model updates on the clients back to the server. It must support the signature ({input_values}@CLIENTS -> output_values@SERVER).
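The placement signatures of the broadcast and aggregation arguments can be illustrated without TFF by modeling a CLIENTS-placed value as a Python list with one entry per client and a SERVER-placed value as a single object. The sketch below is hypothetical: `MeasuredOutput`, `ToyMeasuredProcess`, and the two `next` functions are invented names that only loosely mirror the (state, result, measurements) output of a tff.templates.MeasuredProcess; they are not TFF classes, and plain floats stand in for ModelWeights.

```python
from typing import Callable, List, NamedTuple


class MeasuredOutput(NamedTuple):
    """Hypothetical stand-in for a (state, result, measurements) triple."""
    state: int
    result: object
    measurements: dict


class ToyMeasuredProcess(NamedTuple):
    """Hypothetical stand-in for a process with initialize() and next()."""
    initialize: Callable[[], int]
    next: Callable[..., MeasuredOutput]


def _broadcast_next(state: int, server_value: float,
                    num_clients: int) -> MeasuredOutput:
    # SERVER -> CLIENTS: every client receives a copy of the server value.
    return MeasuredOutput(state + 1, [server_value] * num_clients,
                          {"num_clients": num_clients})


def _aggregate_next(state: int, client_values: List[float],
                    client_weights: List[float]) -> MeasuredOutput:
    # CLIENTS -> SERVER: weighted mean of the per-client values.
    total = sum(client_weights)
    mean = sum(v * w for v, w in zip(client_values, client_weights)) / total
    return MeasuredOutput(state + 1, mean, {"total_weight": total})


# The state here is just a round counter, updated on every call.
broadcast = ToyMeasuredProcess(initialize=lambda: 0, next=_broadcast_next)
aggregate = ToyMeasuredProcess(initialize=lambda: 0, next=_aggregate_next)
```

For example, `broadcast.next(broadcast.initialize(), 1.5, 3).result` is `[1.5, 1.5, 1.5]`, and `aggregate.next(aggregate.initialize(), [1.0, 4.0], [3.0, 1.0]).result` is `1.75`. Threading the returned state into the next call is what makes these processes stateful across rounds.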

Returns: A tff.templates.IterativeProcess.

Raises: ProcessTypeError if broadcast_process or aggregation_process does not conform to the signature of a broadcast (SERVER -> CLIENTS) or an aggregation (CLIENTS -> SERVER), respectively.
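The placement check behind this error can be sketched as a comparison of (input placement, output placement) pairs. This is a hypothetical illustration, not TFF's validation code: `Placement`, `check_signature`, and the local `ProcessTypeError` class (derived from TypeError here purely for the sketch) are stand-ins, and a real TFF type signature carries much more than two placements.

```python
from enum import Enum


class Placement(Enum):
    SERVER = "SERVER"
    CLIENTS = "CLIENTS"


class ProcessTypeError(TypeError):
    """Local stand-in for the error named above; not imported from TFF."""


# Expected (input placement, output placement) for each argument.
BROADCAST = (Placement.SERVER, Placement.CLIENTS)
AGGREGATION = (Placement.CLIENTS, Placement.SERVER)


def check_signature(name: str, signature: tuple, expected: tuple) -> None:
    """Raise ProcessTypeError if the placements differ from the expected pair."""
    if signature != expected:
        got = " -> ".join(p.value for p in signature)
        want = " -> ".join(p.value for p in expected)
        raise ProcessTypeError(f"{name} must map {want}, got {got}")


# A correctly placed broadcast passes silently.
check_signature("broadcast_process", (Placement.SERVER, Placement.CLIENTS), BROADCAST)
```

Passing `(Placement.SERVER, Placement.SERVER)` as the aggregation signature would raise, because an aggregation must take CLIENTS-placed inputs to a SERVER-placed output.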