Constructs a `tff.templates.IterativeProcess` for Federated Averaging or Federated SGD.

    tff.learning.framework.build_model_delta_optimizer_process(
        model_fn: tff.learning.Model,
        model_to_client_delta_fn: Callable[[Callable[[], model_lib.Model]], ClientDeltaFn],
        server_optimizer_fn: tff.learning.optimizers.Optimizer,
        *,
        broadcast_process: Optional[tff.templates.MeasuredProcess] = None,
        model_update_aggregation_factory: Optional[factory.AggregationFactory] = None,
        metrics_aggregator: Optional[Callable[[model_lib.MetricFinalizersType, computation_types.StructWithPythonType], computation_base.Computation]] = None
    ) -> tff.templates.IterativeProcess

This function provides the TFF orchestration logic that connects the common server logic, which applies aggregated model deltas to the server model, with a `ClientDeltaFn` that specifies how `weight_deltas` are computed on device.
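
To make the division of labour concrete, here is a minimal, framework-free sketch of what one round of this orchestration does: broadcast the server weights, let each client compute a weight delta, aggregate the deltas, and let a server optimizer apply the result. Everything in it (`ToyClientDeltaFn`, `make_local_sgd_delta_fn`, `run_round`, the 1-D linear model) is hypothetical and uses plain Python lists instead of TFF computations. It also assumes the server treats the negated, example-weighted mean delta as a pseudo-gradient for SGD, so that a server learning rate of 1.0 reduces to Federated Averaging.

```python
from typing import Callable, List, Sequence, Tuple

Weights = List[float]
Dataset = Sequence[Tuple[float, float]]  # (x, y) pairs
# Hypothetical stand-in for a ClientDeltaFn:
# (local data, initial weights) -> (weights_delta, aggregation weight).
ToyClientDeltaFn = Callable[[Dataset, Weights], Tuple[Weights, float]]


def make_local_sgd_delta_fn(lr: float = 0.1, epochs: int = 1) -> ToyClientDeltaFn:
    """Client update: plain SGD on the 1-D linear model y = w[0] + w[1] * x."""

    def delta_fn(data: Dataset, initial_weights: Weights) -> Tuple[Weights, float]:
        w = list(initial_weights)
        for _ in range(epochs):
            for x, y in data:
                # err is the gradient of 0.5 * err**2 w.r.t. the prediction.
                err = (w[0] + w[1] * x) - y
                w[0] -= lr * err
                w[1] -= lr * err * x
        delta = [new - old for new, old in zip(w, initial_weights)]
        return delta, float(len(data))  # weight each client by its example count

    return delta_fn


def run_round(server_weights: Weights, client_data: Sequence[Dataset],
              client_delta_fn: ToyClientDeltaFn, server_lr: float = 1.0) -> Weights:
    # 1. Broadcast: every client starts from the current server weights.
    results = [client_delta_fn(data, server_weights) for data in client_data]
    # 2. Aggregate: example-weighted mean of the client deltas.
    total_weight = sum(weight for _, weight in results)
    mean_delta = [sum(delta[i] * weight for delta, weight in results) / total_weight
                  for i in range(len(server_weights))]
    # 3. Server update: one SGD step on the pseudo-gradient -mean_delta.
    #    With server_lr=1.0 this simply adds the averaged delta.
    return [w - server_lr * (-d) for w, d in zip(server_weights, mean_delta)]


delta_fn = make_local_sgd_delta_fn(lr=0.1)
new_weights = run_round([0.0, 0.0], [[(1.0, 2.0)], [(2.0, 4.0), (0.0, 0.0)]], delta_fn)
```

Swapping `make_local_sgd_delta_fn` for a different client routine, or `server_lr` for a different server step, changes the algorithm without touching `run_round`; that separation is the point of passing the client logic in as a function.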

| Args | |
|---|---|
| `model_fn` | A no-arg function that returns a `tff.learning.Model`. |
| `model_to_client_delta_fn` | A function from a `model_fn` to a `ClientDeltaFn`. |
| `server_optimizer_fn` | A `tff.learning.optimizers.Optimizer` or a no-arg function that returns a `tf.keras.optimizers.Optimizer`. |
| `broadcast_process` | A `tff.templates.MeasuredProcess` that broadcasts the model weights on the server to the clients. It must support the signature `(input_values@SERVER -> output_values@CLIENTS)`. If left as the default `None`, the server model is broadcast to the clients using `tff.federated_broadcast`. |
| `model_update_aggregation_factory` | An optional `tff.aggregators.WeightedAggregationFactory` or `tff.aggregators.UnweightedAggregationFactory` that constructs a `tff.templates.AggregationProcess` for aggregating the client model updates on the server. If `None`, uses `tff.aggregators.MeanFactory`. |
| `metrics_aggregator` | An optional function that takes in the metric finalizers (i.e., `tff.learning.Model.metric_finalizers()`) and a `tff.types.StructWithPythonType` of the unfinalized metrics (i.e., the TFF type of `tff.learning.Model.report_local_unfinalized_metrics()`), and returns a federated TFF computation with the type signature `local_unfinalized_metrics@CLIENTS -> aggregated_metrics@SERVER`. If `None`, uses `tff.learning.metrics.sum_then_finalize`, which returns a federated TFF computation that sums the unfinalized metrics from `CLIENTS` and then applies the corresponding metric finalizers at `SERVER`. |
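
The default `metrics_aggregator` behaviour (sum the unfinalized metrics across clients, then finalize once) can be illustrated without TFF. The sketch below is hypothetical: `sum_then_finalize_sketch` is not the library function, metrics are plain dicts of float lists, and the `[num_correct, num_examples]` encoding for accuracy is an assumption chosen for illustration.

```python
from typing import Callable, Dict, List, Sequence

# Hypothetical finalizers: each maps a metric's summed unfinalized value
# (a list of floats) to its final value.
Finalizers = Dict[str, Callable[[List[float]], float]]


def sum_then_finalize_sketch(finalizers: Finalizers,
                             client_metrics: Sequence[Dict[str, List[float]]]) -> Dict[str, float]:
    """Sums unfinalized metrics element-wise across clients, then finalizes each."""
    summed: Dict[str, List[float]] = {}
    for metrics in client_metrics:
        for name, values in metrics.items():
            if name not in summed:
                summed[name] = [0.0] * len(values)
            summed[name] = [a + b for a, b in zip(summed[name], values)]
    # Finalizers run once, on the cross-client sums (the "at SERVER" step).
    return {name: finalizers[name](values) for name, values in summed.items()}


finalizers: Finalizers = {
    # Accuracy is kept unfinalized as [num_correct, num_examples].
    "accuracy": lambda v: v[0] / v[1],
    "num_examples": lambda v: v[0],
}
clients = [
    {"accuracy": [3.0, 4.0], "num_examples": [4.0]},
    {"accuracy": [1.0, 6.0], "num_examples": [6.0]},
]
result = sum_then_finalize_sketch(finalizers, clients)
```

Summing before finalizing gives an accuracy of 4/10 = 0.4 over all examples, whereas averaging the per-client accuracies (0.75 and about 0.167) would give roughly 0.458 and over-weight the smaller client.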

| Returns | |
|---|---|
| A `tff.templates.IterativeProcess`. | |
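
The returned process exposes an `initialize` computation that builds the server state and a `next` computation that advances it by one round; the driver loop in the TFF tutorials of this era is `state = process.initialize()` followed by repeated `state, metrics = process.next(state, federated_train_data)`. The sketch below mimics only that shape with a hypothetical plain-Python `ToyIterativeProcess`; its state (a single float pulled toward the clients' mean) is invented for illustration and has nothing to do with TFF's actual server state.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Tuple


@dataclass(frozen=True)
class ToyIterativeProcess:
    """Hypothetical stand-in with the initialize/next shape of an IterativeProcess."""
    initialize: Callable[[], Any]
    next: Callable[[Any, Any], Tuple[Any, Any]]


def initialize() -> float:
    return 0.0


def next_fn(state: float, client_values: List[float]) -> Tuple[float, Dict[str, float]]:
    mean = sum(client_values) / len(client_values)
    new_state = state + 0.5 * (mean - state)  # move halfway toward the client mean
    return new_state, {"gap": abs(mean - new_state)}


process = ToyIterativeProcess(initialize=initialize, next=next_fn)
state = process.initialize()
for _ in range(3):
    state, metrics = process.next(state, [1.0, 3.0])
```

Because the state is passed in and returned explicitly, the caller owns it between rounds and can checkpoint or inspect it.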

| Raises | |
|---|---|
| `ProcessTypeError` | If `broadcast_process` does not conform to the signature of broadcast (`SERVER -> CLIENTS`). |
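
The condition in the Raises table boils down to a placement check: a valid broadcast takes its input at `SERVER` and produces its output at `CLIENTS`. The sketch below is a hypothetical simplification that reduces a signature to those two placements; the real validation inspects TFF federated types, and `ToySignature`, `check_broadcast_signature`, and this local `ProcessTypeError` class are illustrative stand-ins, not library APIs.

```python
from dataclasses import dataclass


class ProcessTypeError(Exception):
    """Hypothetical stand-in for the error listed in the Raises table."""


@dataclass(frozen=True)
class ToySignature:
    # Placements of the broadcast input and output, e.g. "SERVER" or "CLIENTS".
    input_placement: str
    output_placement: str


def check_broadcast_signature(sig: ToySignature) -> None:
    """Raises ProcessTypeError unless the signature is SERVER -> CLIENTS."""
    if (sig.input_placement, sig.output_placement) != ("SERVER", "CLIENTS"):
        raise ProcessTypeError(
            "broadcast_process must map SERVER -> CLIENTS, got "
            f"{sig.input_placement} -> {sig.output_placement}")


check_broadcast_signature(ToySignature("SERVER", "CLIENTS"))  # passes silently
```

Validating at construction time, as this function does, surfaces a mis-typed `broadcast_process` before any round runs rather than partway through training.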