Builds a learning process that performs federated averaging.
tff.learning.algorithms.build_weighted_fed_avg(
    model_fn: Callable[[], tff.learning.Model],
    client_optimizer_fn: Union[optimizer_base.Optimizer, Callable[[], tf.keras.optimizers.Optimizer]],
    server_optimizer_fn: Union[optimizer_base.Optimizer, Callable[[], tf.keras.optimizers.Optimizer]] = DEFAULT_SERVER_OPTIMIZER_FN,
    client_weighting: Optional[tff.learning.ClientWeighting] = tff.learning.ClientWeighting.NUM_EXAMPLES,
    model_distributor: Optional[tff.learning.templates.DistributionProcess] = None,
    model_aggregator: Optional[tff.aggregators.WeightedAggregationFactory] = None,
    metrics_aggregator: Optional[Callable[[model_lib.MetricFinalizersType, computation_types.StructWithPythonType], computation_base.Computation]] = None,
    use_experimental_simulation_loop: bool = False
) -> tff.learning.templates.LearningProcess
This function creates a `tff.learning.templates.LearningProcess` that performs
federated averaging on client models. The iterative process has the following
methods inherited from `tff.learning.templates.LearningProcess`:

- `initialize`: A `tff.Computation` with the functional type signature `( -> S@SERVER)`, where `S` is a `tff.learning.templates.LearningAlgorithmState` representing the initial state of the server.
- `next`: A `tff.Computation` with the functional type signature `(<S@SERVER, {B*}@CLIENTS> -> <L@SERVER>)`, where `S` is a `tff.learning.templates.LearningAlgorithmState` whose type matches the output of `initialize` and `{B*}@CLIENTS` represents the client datasets. The output `L` contains the updated server state, as well as aggregated metrics at the server, including client training metrics and any other metrics from distribution and aggregation processes.
- `get_model_weights`: A `tff.Computation` with type signature `(S -> M)`, where `S` is a `tff.learning.templates.LearningAlgorithmState` whose type matches the output of `initialize` and `next`, and `M` represents the type of the model weights used during training.
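The interplay of these three methods can be sketched with a minimal, hypothetical stand-in object. `ToyProcess`, its state dictionary, and the list-of-lists client data layout below are all invented for illustration; this is not the TFF implementation, only the call pattern a `LearningProcess` exposes.

```python
# Hypothetical stand-in showing the LearningProcess call pattern only.
class ToyProcess:
    def initialize(self):
        # ( -> S@SERVER): the initial server state holds the global model.
        return {"global_model": [0.0, 0.0], "round_num": 0}

    def next(self, state, client_datasets):
        # (<S@SERVER, {B*}@CLIENTS> -> <L@SERVER>): one round of training.
        # Counting examples stands in for real aggregated metrics.
        num_examples = sum(len(ds) for ds in client_datasets)
        new_state = dict(state, round_num=state["round_num"] + 1)
        return new_state, {"num_examples": num_examples}

    def get_model_weights(self, state):
        # (S -> M): extract the model weights from the state.
        return state["global_model"]

process = ToyProcess()
state = process.initialize()
state, metrics = process.next(state, [[1, 2, 3], [4, 5]])
weights = process.get_model_weights(state)
```

In real TFF usage the same shape applies: call `initialize` once, call `next` repeatedly with the previous state and a round's client datasets, and call `get_model_weights` whenever the current model is needed for evaluation.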
Each time the `next` method is called, the server model is communicated to
each client using the provided `model_distributor`. For each client, local
training is performed using `client_optimizer_fn`. Each client computes the
difference between the client model after training and its initial model.
These model deltas are then aggregated at the server using a weighted
aggregation function, according to `client_weighting`. The aggregate model
delta is applied at the server using a server optimizer.
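The round described above can be sketched in plain NumPy. This is an illustrative simplification, not TFF's implementation: it assumes example-count weighting (`NUM_EXAMPLES`) and the default server optimizer (SGD with learning rate 1.0), under which applying the aggregate pseudo-gradient reduces to adding the weighted-average delta to the server model.

```python
import numpy as np

def fedavg_round(server_model, client_models, client_num_examples, server_lr=1.0):
    # Each client's delta: its model after local training minus the model
    # it started the round with.
    deltas = [m - server_model for m in client_models]
    # Weighted mean of deltas, weighted by each client's example count.
    weights = np.asarray(client_num_examples, dtype=float)
    avg_delta = sum(w * d for w, d in zip(weights, deltas)) / weights.sum()
    # Server-side SGD step on the aggregate delta; at lr=1.0 this is
    # simply adding the averaged delta.
    return server_model + server_lr * avg_delta

server = np.array([0.0, 0.0])
clients = [np.array([1.0, 1.0]), np.array([3.0, 3.0])]
new_server = fedavg_round(server, clients, client_num_examples=[1, 3])
```

With example counts 1 and 3, the second client's delta dominates the weighted mean, pulling the server model toward its update.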
Args | |
---|---|
`model_fn` | A no-arg function that returns a `tff.learning.Model`. This method must not capture TensorFlow tensors or variables and use them. The model must be constructed entirely from scratch on each invocation; returning the same pre-constructed model each call will result in an error. |
`client_optimizer_fn` | A `tff.learning.optimizers.Optimizer`, or a no-arg callable that returns a `tf.keras.Optimizer`. |
`server_optimizer_fn` | A `tff.learning.optimizers.Optimizer`, or a no-arg callable that returns a `tf.keras.Optimizer`. By default, this uses `tf.keras.optimizers.SGD` with a learning rate of 1.0. |
`client_weighting` | A member of `tff.learning.ClientWeighting` that specifies a built-in weighting method. By default, weighting by number of examples is used. |
`model_distributor` | An optional `DistributionProcess` that distributes the model weights on the server to the clients. If set to `None`, the distributor is constructed via `distributors.build_broadcast_process`. |
`model_aggregator` | An optional `tff.aggregators.WeightedAggregationFactory` used to aggregate client updates on the server. If `None`, this is set to `tff.aggregators.MeanFactory`. |
`metrics_aggregator` | A function that takes in the metric finalizers (i.e., `tff.learning.Model.metric_finalizers()`) and a `tff.types.StructWithPythonType` of the unfinalized metrics (i.e., the TFF type of `tff.learning.Model.report_local_unfinalized_metrics()`), and returns a `tff.Computation` for aggregating the unfinalized metrics. If `None`, this is set to `tff.learning.metrics.sum_then_finalize`. |
`use_experimental_simulation_loop` | Controls the reduce loop function for the input dataset. An experimental reduce loop is used for simulation. It is currently necessary to set this flag to `True` for performant GPU simulations. |
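The `model_fn` contract above (a fresh model on every call, never a cached instance) can be illustrated with a plain-Python stand-in. `DummyModel`, `good_model_fn`, and `bad_model_fn` are invented names; a real `model_fn` would construct and return a `tff.learning.Model`, but the freshness requirement is the same.

```python
# Plain-Python stand-in for the model_fn contract; not a real tff.learning.Model.
class DummyModel:
    def __init__(self):
        self.weights = [0.0]

def good_model_fn():
    # Builds the model from scratch on each invocation -- this is what
    # build_weighted_fed_avg requires.
    return DummyModel()

_cached = DummyModel()
def bad_model_fn():
    # Returns the same pre-constructed model every call -- TFF would
    # reject a model_fn with this behavior.
    return _cached
```

A quick check of the two behaviors: `good_model_fn()` yields a distinct object each call, while `bad_model_fn()` always returns the identical instance.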
Returns | |
---|---|
A `tff.learning.templates.LearningProcess`. |