Builds an iterative process that performs federated averaging.
tff.learning.build_federated_averaging_process(
    model_fn: Callable[[], tff.learning.Model],
    client_optimizer_fn: Union[optimizer_base.Optimizer, Callable[[], tf.keras.optimizers.Optimizer]],
    server_optimizer_fn: Union[optimizer_base.Optimizer, Callable[[], tf.keras.optimizers.Optimizer]] = DEFAULT_SERVER_OPTIMIZER_FN,
    *,
    client_weighting: Optional[tff.learning.ClientWeighting] = None,
    broadcast_process: Optional[tff.templates.MeasuredProcess] = None,
    model_update_aggregation_factory: Optional[factory.AggregationFactory] = None,
    metrics_aggregator: Optional[Callable[[model_lib.MetricFinalizersType, computation_types.StructWithPythonType], computation_base.Computation]] = None,
    use_experimental_simulation_loop: bool = False
) -> tff.templates.IterativeProcess
This function creates a tff.templates.IterativeProcess that performs federated averaging on client models. The iterative process has the following methods inherited from tff.templates.IterativeProcess:

initialize: A tff.Computation with the functional type signature ( -> S@SERVER), where S is a tff.learning.framework.ServerState representing the initial state of the server.

next: A tff.Computation with the functional type signature (<S@SERVER, {B*}@CLIENTS> -> <S@SERVER, T@SERVER>), where S is a tff.learning.framework.ServerState whose type matches that of the output of initialize, and {B*}@CLIENTS represents the client datasets, where B is the type of a single batch. This computation returns a tff.learning.framework.ServerState representing the updated server state, together with T, the aggregated metrics at the server, which include client training metrics and any other metrics from the broadcast and aggregation processes.
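
The initialize/next contract can be illustrated without TFF. The sketch below is a pure-Python toy, not the TFF API: ToyServerState, initialize, and next_fn are hypothetical names, the "model" is a single float, and the toy client update (moving halfway toward the client's data mean) stands in for real local training. It shows how the state returned by one call to next is threaded into the following call.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass(frozen=True)
class ToyServerState:
    """Hypothetical stand-in for tff.learning.framework.ServerState."""
    model: float      # toy model: a single scalar weight
    round_num: int    # bookkeeping carried between rounds


def initialize() -> ToyServerState:
    """Analogue of `( -> S@SERVER)`: takes no input, returns the initial state."""
    return ToyServerState(model=0.0, round_num=0)


def next_fn(state: ToyServerState,
            client_datasets: List[List[float]]) -> Tuple[ToyServerState, Dict[str, float]]:
    """Analogue of `(<S@SERVER, {B*}@CLIENTS> -> <S@SERVER, T@SERVER>)`.

    Each client dataset is a list of scalar "batches". The toy client update
    moves the broadcast model halfway toward the client's data mean; the
    server then takes the unweighted average of the client models.
    """
    client_models = []
    for dataset in client_datasets:
        target = sum(dataset) / len(dataset)
        client_models.append(state.model + 0.5 * (target - state.model))
    new_state = ToyServerState(model=sum(client_models) / len(client_models),
                               round_num=state.round_num + 1)
    metrics = {"num_clients": float(len(client_datasets))}  # plays the role of T
    return new_state, metrics


# Typical driver loop: the state returned by one round feeds the next round.
federated_data = [[1.0, 3.0], [5.0]]
state = initialize()
for _ in range(3):
    state, metrics = next_fn(state, federated_data)
print(state.round_num, state.model)  # 3 3.0625
```

After three rounds the server model is 3.0625; repeated rounds move it toward 3.5, the average of the two client means.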
The iterative process also has the following method not inherited from tff.templates.IterativeProcess:

get_model_weights: A tff.Computation that takes a tff.learning.framework.ServerState as input and returns a tff.learning.ModelWeights containing the state's model weights.
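
The server state carries more than the model (for example, optimizer state), while evaluation or export usually needs only the weights. The pure-Python sketch below, not TFF code, illustrates that separation: ToyModelWeights mirrors the trainable/non_trainable split of tff.learning.ModelWeights, and ToyServerState with its model, optimizer_state, and round_num fields is a hypothetical stand-in, not the real ServerState layout.

```python
from dataclasses import dataclass
from typing import List, NamedTuple


class ToyModelWeights(NamedTuple):
    """Hypothetical mirror of tff.learning.ModelWeights."""
    trainable: List[float]
    non_trainable: List[float]


@dataclass(frozen=True)
class ToyServerState:
    """Hypothetical server state: model weights plus optimizer and bookkeeping state."""
    model: ToyModelWeights
    optimizer_state: List[float]
    round_num: int


def get_model_weights(state: ToyServerState) -> ToyModelWeights:
    # Keep only what is needed to evaluate or export the model.
    return state.model


state = ToyServerState(model=ToyModelWeights(trainable=[0.5, -1.0], non_trainable=[10.0]),
                       optimizer_state=[0.9],
                       round_num=7)
weights = get_model_weights(state)
print(weights.trainable, weights.non_trainable)  # [0.5, -1.0] [10.0]
```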
Each time the next method is called, the server model is broadcast to each client using a broadcast function. For each client, local training on one pass of the pre-processed client dataset (multiple epochs are possible if the dataset is pre-processed with the repeat operation) is performed via the tf.keras.optimizers.Optimizer.apply_gradients method of the client optimizer. Each client computes the difference between its model after training and the initial broadcast model. These model deltas are then aggregated at the server using an aggregation function. The aggregate model delta is applied at the server by using the tf.keras.optimizers.Optimizer.apply_gradients method of the server optimizer.
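
The round described above can be simulated in plain Python. This is a minimal sketch under stated assumptions, not the TFF implementation: the model is a scalar linear model y = w * x, clients run plain minibatch SGD on squared error instead of a Keras optimizer, model deltas are averaged with weights equal to each client's number of examples (one possible client weighting), and the server step assumes the averaged delta is negated and applied as a pseudo-gradient with SGD, mirroring the default server optimizer (SGD with a learning rate of 1.0). Repeating a client's list of batches plays the role of the repeat operation for multiple local epochs.

```python
from typing import List, Sequence, Tuple

Batch = List[Tuple[float, float]]   # a batch of (x, y) pairs
ClientDataset = List[Batch]         # one local pass = iterate once over the batches


def local_train(w: float, dataset: ClientDataset, client_lr: float) -> float:
    """One pass of minibatch SGD on mean squared error for the model y = w * x."""
    for batch in dataset:
        grad = sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)
        w -= client_lr * grad
    return w


def fedavg_round(w_server: float,
                 client_datasets: Sequence[ClientDataset],
                 client_lr: float = 0.05,
                 server_lr: float = 1.0) -> float:
    """One round: broadcast, local training, weighted delta averaging, server step."""
    deltas, weights = [], []
    for dataset in client_datasets:
        w_client = local_train(w_server, dataset, client_lr)  # start from broadcast model
        deltas.append(w_client - w_server)                     # client model delta
        weights.append(sum(len(batch) for batch in dataset))   # number of examples
    avg_delta = sum(d * n for d, n in zip(deltas, weights)) / sum(weights)
    pseudo_grad = -avg_delta                    # assumption: negated delta acts as a gradient
    return w_server - server_lr * pseudo_grad   # plain SGD step on the server


clients = [
    [[(1.0, 2.0), (2.0, 4.0)]],      # client 0: one batch of two examples
    [[(1.0, 2.0)], [(3.0, 6.0)]],    # client 1: two batches of one example
]
w = fedavg_round(0.0, clients)
print(round(w, 6))  # 1.16

# Two local epochs per round, analogous to pre-processing with repeat(2).
w2 = fedavg_round(0.0, [dataset * 2 for dataset in clients])
```

Starting from w = 0.0, the two clients finish local training at 0.5 and 1.82; both hold two examples, so the new server model is their average, 1.16. With server_lr=1.0 the server step simply replaces the server model with the weighted average of the client models; a different server learning rate or optimizer scales or reshapes that step.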
Args | |
---|---|
model_fn | A no-arg function that returns a tff.learning.Model. This method must not capture TensorFlow tensors or variables and use them. The model must be constructed entirely from scratch on each invocation; returning the same pre-constructed model on each call will result in an error. |
client_optimizer_fn | A tff.learning.optimizers.Optimizer, or a no-arg callable that returns a tf.keras.Optimizer. |
server_optimizer_fn | A tff.learning.optimizers.Optimizer, or a no-arg callable that returns a tf.keras.Optimizer. By default, this uses tf.keras.optimizers.SGD with a learning rate of 1.0. |
client_weighting | A value of tff.learning.ClientWeighting that specifies a built-in weighting method, or a callable that takes the output of model.report_local_unfinalized_metrics and returns a tensor that provides the weight in the federated average of model deltas. If None, the default depends on model_update_aggregation_factory: if the factory is a tff.aggregators.UnweightedAggregationFactory, clients are weighted uniformly; otherwise, clients are weighted by their number of examples. An error will be raised if client_weighting is not uniform but model_update_aggregation_factory is unweighted. |
broadcast_process | A tff.templates.MeasuredProcess that broadcasts the model weights on the server to the clients. It must support the signature (input_values@SERVER -> output_values@CLIENT). If None (the default), the server model is broadcast to the clients using the default tff.federated_broadcast. |
model_update_aggregation_factory | An optional tff.aggregators.WeightedAggregationFactory or tff.aggregators.UnweightedAggregationFactory that constructs a tff.templates.AggregationProcess for aggregating the client model updates on the server. If None, uses tff.aggregators.MeanFactory. |
metrics_aggregator | An optional function that takes in the metric finalizers (i.e., tff.learning.Model.metric_finalizers()) and a tff.types.StructWithPythonType of the unfinalized metrics (i.e., the TFF type of tff.learning.Model.report_local_unfinalized_metrics()), and returns a federated TFF computation with the type signature local_unfinalized_metrics@CLIENTS -> aggregated_metrics@SERVER. If None, uses tff.learning.metrics.sum_then_finalize, which returns a federated TFF computation that sums the unfinalized metrics from CLIENTS and then applies the corresponding metric finalizers at SERVER. |
use_experimental_simulation_loop | Controls the reduce loop function used for the input dataset. If True, an experimental reduce loop intended for simulation is used. It is currently necessary to set this flag to True for performant GPU simulations. |
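
The model_fn requirement (construct a fresh model on every call rather than returning a captured, pre-built one) can be sketched in plain Python. ToyModel, good_model_fn, bad_model_fn, and the check_fresh helper are hypothetical illustrations, not part of TFF; in real code the function body would build the Keras model or tff.learning.Model, and its variables, inside the call.

```python
class ToyModel:
    """Hypothetical stand-in for a tff.learning.Model that owns its variables."""

    def __init__(self) -> None:
        self.weights = [0.0, 0.0]  # created inside the constructor, so each model has its own


def good_model_fn() -> ToyModel:
    # Builds a brand-new model (and new variables) on every invocation.
    return ToyModel()


_PREBUILT = ToyModel()


def bad_model_fn() -> ToyModel:
    # Returns the same captured object on every call; per the model_fn
    # description, this pattern results in an error in the real API.
    return _PREBUILT


def check_fresh(model_fn) -> bool:
    """Hypothetical helper: True if two calls return distinct model objects."""
    return model_fn() is not model_fn()


print(check_fresh(good_model_fn), check_fresh(bad_model_fn))  # True False
```
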
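
The default rule for client_weighting, and its interaction with model_update_aggregation_factory, can be mirrored in plain Python. This is a sketch, not the TFF implementation: the ClientWeighting enum below only mirrors the names of the built-in UNIFORM and NUM_EXAMPLES options, the factory is reduced to a factory_is_unweighted flag, each client's local metrics are a dict with a hypothetical num_examples key, the ValueError type is an assumption, and weighted_mean stands in for the default tff.aggregators.MeanFactory.

```python
import enum
from typing import Callable, Dict, List, Optional, Sequence, Union


class ClientWeighting(enum.Enum):
    """Mirror of the built-in options of tff.learning.ClientWeighting."""
    UNIFORM = "uniform"
    NUM_EXAMPLES = "num_examples"


Weighting = Union[ClientWeighting, Callable[[Dict[str, float]], float]]


def resolve_weighting(client_weighting: Optional[Weighting],
                      factory_is_unweighted: bool) -> Weighting:
    """Applies the documented default and rejects non-uniform + unweighted."""
    if client_weighting is None:
        return (ClientWeighting.UNIFORM if factory_is_unweighted
                else ClientWeighting.NUM_EXAMPLES)
    if factory_is_unweighted and client_weighting is not ClientWeighting.UNIFORM:
        raise ValueError("non-uniform client_weighting needs a weighted aggregation factory")
    return client_weighting


def client_weights(local_metrics: Sequence[Dict[str, float]],
                   weighting: Weighting) -> List[float]:
    """Turns each client's local metrics into its averaging weight."""
    if weighting is ClientWeighting.UNIFORM:
        return [1.0] * len(local_metrics)
    if weighting is ClientWeighting.NUM_EXAMPLES:
        return [float(m["num_examples"]) for m in local_metrics]
    return [float(weighting(m)) for m in local_metrics]  # custom callable


def weighted_mean(values: Sequence[float], weights: Sequence[float]) -> float:
    return sum(v * w for v, w in zip(values, weights)) / sum(weights)


deltas = [1.0, 4.0]
metrics = [{"num_examples": 1.0}, {"num_examples": 3.0}]
for option in (ClientWeighting.UNIFORM, ClientWeighting.NUM_EXAMPLES):
    print(option.name, weighted_mean(deltas, client_weights(metrics, option)))
# UNIFORM 2.5
# NUM_EXAMPLES 3.25
```

Weighting by number of examples pulls the average toward the client that holds more data (3 of the 4 examples here).
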
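
The default metrics_aggregator, tff.learning.metrics.sum_then_finalize, sums unfinalized metrics across clients before finalizing them. The plain-Python analogue below is not the TFF function or its signature: sum_then_finalize_sketch is a hypothetical name, and the accuracy metric is assumed to be stored unfinalized as [num_correct, num_total] with finalizer correct / total. It shows why summing first differs from averaging per-client accuracies when clients hold different amounts of data.

```python
from typing import Callable, Dict, List, Sequence

Unfinalized = Dict[str, List[float]]                    # metric name -> unfinalized values
Finalizers = Dict[str, Callable[[List[float]], float]]  # metric name -> finalizer


def sum_then_finalize_sketch(finalizers: Finalizers,
                             client_metrics: Sequence[Unfinalized]) -> Dict[str, float]:
    """Element-wise sum of each metric's unfinalized values, then finalize."""
    summed: Dict[str, List[float]] = {}
    for metrics in client_metrics:
        for name, values in metrics.items():
            if name not in summed:
                summed[name] = [0.0] * len(values)
            summed[name] = [a + b for a, b in zip(summed[name], values)]
    return {name: finalizers[name](values) for name, values in summed.items()}


finalizers = {"accuracy": lambda v: v[0] / v[1]}   # correct / total
client_metrics = [
    {"accuracy": [9.0, 10.0]},     # small client: 90% accurate
    {"accuracy": [10.0, 100.0]},   # large client: 10% accurate
]
result = sum_then_finalize_sketch(finalizers, client_metrics)
naive = sum(finalizers["accuracy"](m["accuracy"]) for m in client_metrics) / len(client_metrics)
print(round(result["accuracy"], 4), round(naive, 4))  # 0.1727 0.5
```

Summing first yields 19 / 110, the accuracy over all 110 examples, whereas averaging the finalized per-client accuracies gives the small client as much influence as the large one.
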
Returns | |
---|---|
A tff.templates.IterativeProcess. |