tff.learning.framework.build_model_delta_optimizer_process

Constructs a tff.templates.IterativeProcess for Federated Averaging or SGD.

This provides the TFF orchestration logic that connects the common server logic, which applies aggregated model deltas to the server model, with a ClientDeltaFn that specifies how weight_deltas are computed on device.
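The following is a minimal pure-Python sketch of one round of this orchestration. It assumes a 1-D linear model y = w * x trained with squared error. The names client_delta_fn, weighted_mean, server_sgd_step and run_round are hypothetical: they stand in for the ClientDeltaFn, an example-weighted mean aggregation (similar in spirit to the default tff.aggregators.MeanFactory), and the server optimizer. None of this is the TFF API. The server step assumes a common convention where the negated mean delta is treated as a pseudo-gradient. Under that convention, plain SGD with a server learning rate of 1.0 reduces to Federated Averaging.

```python
from typing import List, Sequence, Tuple

Weights = List[float]


def client_delta_fn(server_weights: Weights,
                    data: Sequence[Tuple[float, float]],
                    client_lr: float = 0.1) -> Tuple[Weights, float]:
    """Hypothetical ClientDeltaFn: local SGD on y = w * x with squared error.

    Returns the weight delta (local - server) and the number of examples,
    which serves as the client's aggregation weight.
    """
    w = list(server_weights)  # start from the broadcast server model
    for x, y in data:
        grad = 2.0 * (w[0] * x - y) * x  # d/dw of (w * x - y) ** 2
        w[0] -= client_lr * grad
    delta = [local - server for local, server in zip(w, server_weights)]
    return delta, float(len(data))


def weighted_mean(deltas: List[Weights], weights: List[float]) -> Weights:
    """Example-weighted mean of the client deltas."""
    total = sum(weights)
    return [sum(d[i] * wt for d, wt in zip(deltas, weights)) / total
            for i in range(len(deltas[0]))]


def server_sgd_step(server_weights: Weights, mean_delta: Weights,
                    server_lr: float = 1.0) -> Weights:
    """Apply the aggregated delta, treating its negation as a pseudo-gradient."""
    pseudo_grad = [-d for d in mean_delta]
    return [s - server_lr * g for s, g in zip(server_weights, pseudo_grad)]


def run_round(server_weights: Weights,
              client_datasets: List[Sequence[Tuple[float, float]]]) -> Weights:
    # "Broadcast": every client computes its delta from the same server weights.
    outputs = [client_delta_fn(server_weights, ds) for ds in client_datasets]
    deltas = [delta for delta, _ in outputs]
    weights = [weight for _, weight in outputs]
    return server_sgd_step(server_weights, weighted_mean(deltas, weights))


# Usage: two clients whose data both follow y = 3 * x.
clients = [[(1.0, 3.0), (2.0, 6.0)], [(0.5, 1.5)]]
server_weights = [0.0]
for _ in range(50):
    server_weights = run_round(server_weights, clients)
```

This design keeps the client logic (how deltas are computed) separate from the server logic (how aggregated deltas are applied). Swapping in a different server optimizer only changes server_sgd_step.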

Args:
model_fn: A no-arg function that returns a tff.learning.Model.
model_to_client_delta_fn: A function that takes a model_fn and returns a ClientDeltaFn.
server_optimizer_fn: A tff.learning.optimizers.Optimizer, or a no-arg function that returns a tf.keras.optimizers.Optimizer.
broadcast_process: A tff.templates.MeasuredProcess that broadcasts the model weights on the server to the clients. It must support the signature (input_values@SERVER -> output_values@CLIENTS). If None (the default), the server model is broadcast to the clients using tff.federated_broadcast.
model_update_aggregation_factory: An optional tff.aggregators.WeightedAggregationFactory or tff.aggregators.UnweightedAggregationFactory that constructs a tff.templates.AggregationProcess for aggregating the client model updates on the server. If None, uses tff.aggregators.MeanFactory.
metrics_aggregator: An optional function that takes the metric finalizers (i.e., tff.learning.Model.metric_finalizers()) and a tff.types.StructWithPythonType describing the unfinalized metrics (i.e., the TFF type of tff.learning.Model.report_local_unfinalized_metrics()). It returns a federated TFF computation with the type signature local_unfinalized_metrics@CLIENTS -> aggregated_metrics@SERVER. If None, uses tff.learning.metrics.sum_then_finalize, which returns a federated TFF computation that sums the unfinalized metrics from CLIENTS and then applies the corresponding metric finalizers at SERVER.
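The default sum-then-finalize strategy can be illustrated without TFF. The sketch below is a hypothetical plain-Python analogue, not tff.learning.metrics.sum_then_finalize itself, which builds a federated computation from finalizers and a TFF type. It assumes every unfinalized metric is a list of floats, for example [loss_sum, num_examples]. The lists are summed element-wise across clients, and then each finalizer runs once on the summed value.

```python
from typing import Callable, Dict, List

Unfinalized = Dict[str, List[float]]


def sum_then_finalize_sketch(
        metric_finalizers: Dict[str, Callable[[List[float]], float]],
        client_metrics: List[Unfinalized]) -> Dict[str, float]:
    """Hypothetical analogue: sum unfinalized metrics, then finalize once."""
    # Element-wise sum of each metric's list across all clients.
    summed = {
        name: [sum(parts) for parts in zip(*(c[name] for c in client_metrics))]
        for name in metric_finalizers
    }
    # Apply each finalizer once to the summed value (the "at SERVER" step).
    return {name: finalize(summed[name])
            for name, finalize in metric_finalizers.items()}


finalizers = {
    "loss": lambda v: v[0] / v[1],  # mean loss = total loss / total examples
    "num_examples": lambda v: v[0],
}
clients = [
    {"loss": [6.0, 3.0], "num_examples": [3.0]},
    {"loss": [4.0, 1.0], "num_examples": [1.0]},
]
result = sum_then_finalize_sketch(finalizers, clients)
```

Here result is {"loss": 2.5, "num_examples": 4.0}. Summing before finalizing yields the example-weighted mean loss (10 / 4). Averaging the per-client mean losses (2.0 and 4.0) would instead give 3.0.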

Returns:
A tff.templates.IterativeProcess.
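An iterative process exposes an initialize/next pair. initialize() builds the server state, and next(state, client_data) runs one round. For the Federated Averaging processes built on this function, next is typically called as state, metrics = process.next(state, federated_train_data). The class below is a hypothetical pure-Python stand-in that mimics only this calling pattern. Its update rule, which moves the server value halfway toward the mean of the client values, is invented for illustration.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple

State = Dict[str, float]


@dataclass(frozen=True)
class IterativeProcessSketch:
    """Hypothetical stand-in for the initialize/next contract."""
    initialize: Callable[[], State]
    next: Callable[[State, List[float]], Tuple[State, Dict[str, float]]]


def _initialize() -> State:
    return {"round_num": 0, "weight": 0.0}


def _next(state: State, client_values: List[float]) -> Tuple[State, Dict[str, float]]:
    # Each client proposes a delta toward its local value; the server
    # applies half of the mean delta and reports it as a metric.
    deltas = [v - state["weight"] for v in client_values]
    mean_delta = sum(deltas) / len(deltas)
    new_state = {"round_num": state["round_num"] + 1,
                 "weight": state["weight"] + 0.5 * mean_delta}
    return new_state, {"mean_delta": mean_delta}


process = IterativeProcessSketch(initialize=_initialize, next=_next)

# Usage: the same driver loop shape used with TFF learning processes.
state = process.initialize()
for _ in range(3):
    state, metrics = process.next(state, [1.0, 3.0])
```

Because the state is passed in and returned explicitly, the driver loop never mutates hidden state. This mirrors how TFF keeps computations functional.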

Raises:
ProcessTypeError: If broadcast_process does not conform to the broadcast signature (SERVER -> CLIENTS).
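The condition behind this error can be pictured as a placement check. The sketch below is hypothetical. It models placements as plain strings and raises a BroadcastSignatureError invented for this example rather than TFF's ProcessTypeError. The real validation inspects the TFF type signature of the MeasuredProcess.

```python
class BroadcastSignatureError(TypeError):
    """Hypothetical stand-in for ProcessTypeError in this sketch."""


def check_broadcast_placements(input_placement: str, output_placement: str) -> None:
    """Raise unless the process maps a SERVER-placed value to CLIENTS."""
    if (input_placement, output_placement) != ("SERVER", "CLIENTS"):
        raise BroadcastSignatureError(
            f"broadcast_process must map SERVER -> CLIENTS, "
            f"got {input_placement} -> {output_placement}")


check_broadcast_placements("SERVER", "CLIENTS")  # passes silently
try:
    # An aggregation-shaped signature is rejected.
    check_broadcast_placements("CLIENTS", "SERVER")
except BroadcastSignatureError as err:
    message = str(err)
```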