tff.learning.framework.build_model_delta_optimizer_process(
    model_fn,
    model_to_client_delta_fn,
    server_optimizer_fn,
    stateful_delta_aggregate_fn=build_stateless_mean(),
    stateful_model_broadcast_fn=build_stateless_broadcaster()
)
Constructs an iterative process for Federated Averaging or SGD.
This provides the TFF orchestration logic that connects two pieces: the common server logic, which applies aggregated model deltas to the server model, and a ClientDeltaFn, which specifies how weight_deltas are computed on device.
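The round structure described above can be sketched in plain Python. This is a minimal illustration, not TFF code: the helper names (client_delta, weighted_mean, server_sgd_step, run_round), the one-parameter linear model, and the learning rates are all hypothetical, and the server step assumes that the negated mean delta is treated as a gradient.

```python
from typing import List, Sequence, Tuple

Weights = List[float]
Example = Tuple[float, float]  # (x, y) pair for the toy model y = w * x


def client_delta(server_weights: Weights, data: Sequence[Example],
                 lr: float = 0.1) -> Tuple[Weights, float]:
    """Hypothetical client step: local SGD, then return (weights delta, example count)."""
    w = list(server_weights)
    for x, y in data:
        grad = 2.0 * (w[0] * x - y) * x  # gradient of squared error
        w[0] -= lr * grad
    delta = [wi - si for wi, si in zip(w, server_weights)]
    return delta, float(len(data))


def weighted_mean(deltas: Sequence[Weights], weights: Sequence[float]) -> Weights:
    """Aggregate client deltas, weighting each by its example count."""
    total = sum(weights)
    return [sum(d[i] * wt for d, wt in zip(deltas, weights)) / total
            for i in range(len(deltas[0]))]


def server_sgd_step(server_weights: Weights, mean_delta: Weights,
                    server_lr: float = 1.0) -> Weights:
    """Apply the aggregated delta by using its negation as a gradient (assumption)."""
    return [w - server_lr * (-d) for w, d in zip(server_weights, mean_delta)]


def run_round(server_weights: Weights,
              client_datasets: Sequence[Sequence[Example]]) -> Weights:
    """One round: broadcast weights, compute client deltas, aggregate, update server."""
    results = [client_delta(server_weights, ds) for ds in client_datasets]
    deltas = [delta for delta, _ in results]
    counts = [count for _, count in results]
    return server_sgd_step(server_weights, weighted_mean(deltas, counts))
```

With a server learning rate of 1.0, the server step in this sketch simply adds the weighted mean delta to the server weights, which is the Federated Averaging update.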
model_fn: A no-arg function that returns a
model_to_client_delta_fn: A function from a model_fn to a
server_optimizer_fn: A no-arg function that returns a
The apply_gradients method of this optimizer is used to apply client updates to the server model.
stateful_delta_aggregate_fn: A tff.utils.StatefulAggregateFn where the next_fn performs a federated aggregation and updates state. That is, it has TFF type (state@SERVER, value@CLIENTS) -> (state@SERVER, aggregate@SERVER).
stateful_model_broadcast_fn: A tff.utils.StatefulBroadcastFn where the next_fn performs a federated broadcast and updates state. That is, it has TFF type (state@SERVER, value@SERVER) -> (state@SERVER, value@CLIENTS).
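The two next_fn type shapes above can be mirrored in plain Python to show how state threads through each call. This is a hedged analogue only, not the TFF API: the function names are hypothetical, client placement is modeled as a Python list with one entry per client, and the round counter used as state is an assumption chosen for illustration.

```python
from typing import List, Tuple

Weights = List[float]


def initialize() -> int:
    """Hypothetical initial server state: a counter of completed calls."""
    return 0


def aggregate_next(state: int,
                   client_values: List[Weights]) -> Tuple[int, Weights]:
    """Shape: (state@SERVER, value@CLIENTS) -> (state@SERVER, aggregate@SERVER)."""
    n = len(client_values)
    # Unweighted element-wise mean across clients.
    aggregate = [sum(column) / n for column in zip(*client_values)]
    return state + 1, aggregate


def broadcast_next(state: int, server_value: Weights,
                   num_clients: int) -> Tuple[int, List[Weights]]:
    """Shape: (state@SERVER, value@SERVER) -> (state@SERVER, value@CLIENTS)."""
    # Each client receives its own copy of the server value.
    return state + 1, [list(server_value) for _ in range(num_clients)]
```

A stateless variant would return its input state unchanged; carrying state is useful when the aggregation or broadcast needs to adapt between rounds.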