tff.learning.framework.build_model_delta_optimizer_process


Constructs a tff.templates.IterativeProcess for Federated Averaging or Federated SGD.

This provides the TFF orchestration logic that connects two pieces: the common server logic, which applies aggregated model deltas to the server model, and a ClientDeltaFn, which specifies how weight_deltas are computed on device.
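To make this division of labor concrete, here is a framework-free sketch of one round in pure Python. None of it is TFF API: LinearModel, make_client_delta_fn (standing in for model_to_client_delta_fn), ServerSGD (standing in for the server optimizer), and run_round are hypothetical names chosen for illustration, and weighting each client's delta by its example count is an assumption.

```python
from typing import Callable, List, Tuple

Weights = List[float]
Dataset = List[Tuple[float, float]]  # (x, y) pairs held by one client


class LinearModel:
    """Hypothetical model: y_hat = w[0] + w[1] * x, initialized to zeros."""

    def __init__(self) -> None:
        self.weights: Weights = [0.0, 0.0]

    def grad(self, x: float, y: float) -> Weights:
        # Gradient of the squared error 0.5 * (y_hat - y)**2.
        err = self.weights[0] + self.weights[1] * x - y
        return [err, err * x]


def make_client_delta_fn(model_fn: Callable[[], LinearModel],
                         lr: float = 0.1, epochs: int = 5):
    """Plays the role of model_to_client_delta_fn: model_fn -> client delta fn."""

    def client_delta_fn(data: Dataset, initial: Weights) -> Tuple[Weights, float]:
        model = model_fn()
        model.weights = list(initial)  # start from the broadcast server weights
        for _ in range(epochs):
            for x, y in data:
                g = model.grad(x, y)
                model.weights = [w - lr * gi for w, gi in zip(model.weights, g)]
        delta = [w - w0 for w, w0 in zip(model.weights, initial)]
        return delta, float(len(data))  # weight the delta by local example count

    return client_delta_fn


class ServerSGD:
    """Minimal stand-in for a server optimizer's apply_gradients step."""

    def __init__(self, learning_rate: float) -> None:
        self.learning_rate = learning_rate

    def apply_gradients(self, grads_and_vars: List[Tuple[float, float]]) -> Weights:
        return [v - self.learning_rate * g for g, v in grads_and_vars]


def run_round(server_weights, client_datasets, client_delta_fn, server_opt):
    # 1. Broadcast: every client starts from the current server weights.
    results = [client_delta_fn(data, server_weights) for data in client_datasets]
    # 2. Aggregate: example-weighted mean of the client deltas.
    total = sum(weight for _, weight in results)
    mean_delta = [sum(d[i] * wt for d, wt in results) / total
                  for i in range(len(server_weights))]
    # 3. Server update: the negated mean delta is passed in as a "gradient",
    #    so SGD with learning_rate=1.0 simply adds the mean delta.
    grads = [-d for d in mean_delta]
    new_weights = server_opt.apply_gradients(list(zip(grads, server_weights)))
    return new_weights, mean_delta


# Usage: every client point lies on y = 1 + 2x.
clients = [[(0.0, 1.0), (1.0, 3.0)], [(2.0, 5.0)]]
client_fn = make_client_delta_fn(LinearModel)
opt = ServerSGD(learning_rate=1.0)
w = LinearModel().weights
for _ in range(100):
    w, _ = run_round(w, clients, client_fn, opt)
print([round(v, 3) for v in w])
```

Because all client data lie on one line, the server weights converge toward (1, 2). Passing the negated delta to apply_gradients means that plain SGD with a learning rate of 1.0 just adds the weighted mean delta to the server weights; other learning rates or optimizers treat the averaged delta as a pseudo-gradient.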

Args:
model_fn: A no-arg function that returns a tff.learning.Model.
model_to_client_delta_fn: A function that takes a model_fn and returns a ClientDeltaFn.
server_optimizer_fn: A no-arg function that returns a TensorFlow optimizer (for example, a tf.keras.optimizers.Optimizer). The apply_gradients method of this optimizer is used to apply client updates to the server model.
stateful_delta_aggregate_fn: A tff.utils.StatefulAggregateFn whose next_fn performs a federated aggregation and updates state. That is, it has TFF type (state@SERVER, value@CLIENTS, weights@CLIENTS) -> (state@SERVER, aggregate@SERVER), where the value type is tff.learning.framework.ModelWeights.trainable corresponding to the object returned by model_fn.
stateful_model_broadcast_fn: A tff.utils.StatefulBroadcastFn whose next_fn performs a federated broadcast and updates state. That is, it has TFF type (state@SERVER, value@SERVER) -> (state@SERVER, value@CLIENTS), where the value type is tff.learning.framework.ModelWeights corresponding to the object returned by model_fn.
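The next_fn type signatures for the aggregate and broadcast arguments can be mirrored in plain Python to show how server state is threaded through each call. This is a hypothetical sketch, not tff.utils.StatefulAggregateFn or tff.utils.StatefulBroadcastFn: the StatefulFn container, the counters used as state, and the extra num_clients argument are illustrative assumptions.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Tuple

Weights = List[float]


@dataclass
class StatefulFn:
    """Hypothetical stand-in: a server-state initializer plus a next_fn that
    takes the state as its first argument and returns (new_state, result)."""
    initialize_fn: Callable[[], Dict[str, int]]
    next_fn: Callable[..., Tuple[Dict[str, int], Any]]


def aggregate_next(state: Dict[str, int], values: List[Weights],
                   weights: List[float]) -> Tuple[Dict[str, int], Weights]:
    # Mirrors (state@SERVER, value@CLIENTS, weights@CLIENTS)
    #      -> (state@SERVER, aggregate@SERVER), using a weighted mean.
    total = sum(weights)
    mean = [sum(v[i] * w for v, w in zip(values, weights)) / total
            for i in range(len(values[0]))]
    return {"rounds": state["rounds"] + 1}, mean


def broadcast_next(state: Dict[str, int], value: Weights,
                   num_clients: int) -> Tuple[Dict[str, int], List[Weights]]:
    # Mirrors (state@SERVER, value@SERVER) -> (state@SERVER, value@CLIENTS).
    # num_clients is needed only because plain Python has no placements.
    copies = [list(value) for _ in range(num_clients)]
    return {"broadcasts": state["broadcasts"] + 1}, copies


aggregator = StatefulFn(lambda: {"rounds": 0}, aggregate_next)
broadcaster = StatefulFn(lambda: {"broadcasts": 0}, broadcast_next)

bc_state = broadcaster.initialize_fn()
agg_state = aggregator.initialize_fn()
bc_state, client_copies = broadcaster.next_fn(bc_state, [0.5, -1.0], 3)
agg_state, mean = aggregator.next_fn(agg_state, [[1.0, 0.0], [3.0, 2.0]], [1.0, 3.0])
print(bc_state, agg_state, mean)
```

The point of the shared shape is that the caller, not the operator, owns the state: each call returns a new state that must be passed into the next round.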

Returns:
A tff.templates.IterativeProcess.
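A caller typically drives the returned process by calling initialize once and then next repeatedly, passing the server state back in each round (roughly `state = process.initialize()` followed by `state, metrics = process.next(state, federated_train_data)`). The pure-Python stand-in below only mirrors that contract; ToyIterativeProcess, ServerState, build_toy_delta_process, and the client update that moves fully to each client's target are hypothetical and are not TFF code.

```python
from typing import Callable, Dict, List, NamedTuple, Tuple


class ServerState(NamedTuple):
    model: List[float]
    round_num: int


class ToyIterativeProcess:
    """Hypothetical mirror of the initialize/next contract."""

    def __init__(self, initialize_fn: Callable[[], ServerState],
                 next_fn: Callable[..., Tuple[ServerState, Dict[str, float]]]):
        self.initialize = initialize_fn
        self.next = next_fn


def build_toy_delta_process(initial_model: List[float],
                            server_lr: float) -> ToyIterativeProcess:
    def initialize() -> ServerState:
        return ServerState(model=list(initial_model), round_num=0)

    def next_fn(state: ServerState, client_targets: List[List[float]]):
        # Toy client update: each client's delta points straight at its target.
        deltas = [[t - m for t, m in zip(target, state.model)]
                  for target in client_targets]
        mean_delta = [sum(col) / len(deltas) for col in zip(*deltas)]
        # Server step: scale the unweighted mean delta by server_lr.
        new_model = [m + server_lr * d for m, d in zip(state.model, mean_delta)]
        metrics = {"mean_delta_norm": sum(d * d for d in mean_delta) ** 0.5}
        return ServerState(new_model, state.round_num + 1), metrics

    return ToyIterativeProcess(initialize, next_fn)


process = build_toy_delta_process([0.0], server_lr=0.5)
state = process.initialize()
for _ in range(20):
    state, metrics = process.next(state, [[2.0], [4.0]])
print(state.round_num, state.model, metrics)
```

Keeping all round-to-round information in the returned state, rather than inside the process object, is what lets the same process be resumed or inspected between rounds.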