Builds an iterative process that performs federated averaging.
tff.learning.build_federated_averaging_process(
    model_fn: Callable[[], tff.learning.Model],
    client_optimizer_fn: Callable[[], tf.keras.optimizers.Optimizer],
    server_optimizer_fn: Callable[[], tf.keras.optimizers.Optimizer] = DEFAULT_SERVER_OPTIMIZER_FN,
    *,
    client_weighting: Optional[Union[ClientWeighting, ClientWeightFnType]] = None,
    broadcast_process: Optional[tff.templates.MeasuredProcess] = None,
    aggregation_process: Optional[tff.templates.MeasuredProcess] = None,
    model_update_aggregation_factory: Optional[tff.aggregators.WeightedAggregationFactory] = None,
    use_experimental_simulation_loop: bool = False
) -> tff.templates.IterativeProcess
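As an illustrative sketch, the process can be constructed from a Keras model wrapped with tff.learning.from_keras_model. The model architecture, input_spec, and learning rates below are hypothetical placeholders, not recommendations.

```python
import tensorflow as tf
import tensorflow_federated as tff

# Hypothetical input spec: batches of 784-dimensional float features with
# integer labels. Any spec matching the client datasets would work.
element_spec = (
    tf.TensorSpec(shape=[None, 784], dtype=tf.float32),
    tf.TensorSpec(shape=[None, 1], dtype=tf.int32),
)

def model_fn():
  # Build a fresh Keras model on every invocation; reusing a pre-constructed
  # model here is an error (see the model_fn argument below).
  keras_model = tf.keras.Sequential([
      tf.keras.layers.Dense(10, input_shape=(784,)),
      tf.keras.layers.Softmax(),
  ])
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
  )

iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),
)
```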
This function creates a tff.templates.IterativeProcess that performs federated
averaging on client models. The iterative process has the following methods
inherited from tff.templates.IterativeProcess:

*   initialize: A tff.Computation with the functional type signature
    ( -> S@SERVER), where S is a tff.learning.framework.ServerState
    representing the initial state of the server.
*   next: A tff.Computation with the functional type signature
    (<S@SERVER, {B*}@CLIENTS> -> <S@SERVER, T@SERVER>), where S is a
    tff.learning.framework.ServerState whose type matches that of the output
    of initialize, and {B*}@CLIENTS represents the client datasets, where B is
    the type of a single batch. This computation returns a
    tff.learning.framework.ServerState representing the updated server state
    and metrics that are the result of
    tff.learning.Model.federated_output_computation during client training,
    along with any other metrics from the broadcast and aggregation processes.
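For example, continuing from the construction sketch above, initialize and next can be invoked as follows. The synthetic client datasets and the name federated_train_data are hypothetical; in practice {B*}@CLIENTS is supplied as a Python list of per-client tf.data.Datasets matching the model's input_spec.

```python
import numpy as np

def make_client_dataset(num_examples=32):
  # Hypothetical synthetic data matching element_spec above.
  x = np.random.rand(num_examples, 784).astype(np.float32)
  y = np.random.randint(0, 10, size=(num_examples, 1)).astype(np.int32)
  return tf.data.Dataset.from_tensor_slices((x, y)).batch(16)

federated_train_data = [make_client_dataset() for _ in range(3)]

state = iterative_process.initialize()  # ( -> S@SERVER)
state, metrics = iterative_process.next(state, federated_train_data)
print(metrics)
```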
The iterative process also has the following method not inherited from
tff.templates.IterativeProcess:

*   get_model_weights: A tff.Computation that takes as input the
    tff.learning.framework.ServerState and returns a tff.learning.ModelWeights
    containing the state's model weights.
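As a sketch continuing from above, the extracted weights can be loaded into a standalone Keras model, e.g. for centralized evaluation; this assumes the assign_weights_to method of tff.learning.ModelWeights.

```python
model_weights = iterative_process.get_model_weights(state)

# Push the trained weights into a freshly built Keras model for use
# outside of TFF.
eval_model = tf.keras.Sequential([
    tf.keras.layers.Dense(10, input_shape=(784,)),
    tf.keras.layers.Softmax(),
])
model_weights.assign_weights_to(eval_model)
```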
Each time the next
method is called, the server model is broadcast to each
client using a broadcast function. For each client, one epoch of local
training is performed via the tf.keras.optimizers.Optimizer.apply_gradients
method of the client optimizer. Each client computes the difference between
the client model after training and the initial broadcast model. These model
deltas are then aggregated at the server using some aggregation function. The
aggregate model delta is applied at the server by using the
tf.keras.optimizers.Optimizer.apply_gradients
method of the server
optimizer.
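In a typical simulation, each call to next performs one such round. A minimal sketch of repeated rounds (client sampling omitted for brevity; names continue from the sketches above):

```python
NUM_ROUNDS = 10

state = iterative_process.initialize()
for round_num in range(NUM_ROUNDS):
  # Each call broadcasts the server model, runs one epoch of local training
  # on every client dataset, aggregates the model deltas, and applies the
  # aggregate delta on the server.
  state, metrics = iterative_process.next(state, federated_train_data)
  print(f'round {round_num}: {metrics}')
```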
Args | 
---|---
model_fn | A no-arg function that returns a tff.learning.Model. This method must not capture TensorFlow tensors or variables and use them. The model must be constructed entirely from scratch on each invocation; returning the same pre-constructed model on each call will result in an error.
client_optimizer_fn | A no-arg callable that returns a tf.keras.Optimizer.
server_optimizer_fn | A no-arg callable that returns a tf.keras.Optimizer. By default, this uses tf.keras.optimizers.SGD with a learning rate of 1.0.
client_weighting | A value of tff.learning.ClientWeighting that specifies a built-in weighting method, or a callable that takes the output of model.report_local_outputs and returns a tensor providing the weight in the federated average of model deltas. If None, defaults to weighting by number of examples.
broadcast_process | A tff.templates.MeasuredProcess that broadcasts the model weights on the server to the clients. It must support the signature (input_values@SERVER -> output_values@CLIENT).
aggregation_process | A tff.templates.MeasuredProcess that aggregates the model updates on the clients back to the server. It must support the signature ({input_values}@CLIENTS -> output_values@SERVER). Must be None if model_update_aggregation_factory is not None.
model_update_aggregation_factory | An optional tff.aggregators.WeightedAggregationFactory or tff.aggregators.UnweightedAggregationFactory that constructs a tff.templates.AggregationProcess for aggregating the client model updates on the server. If None, uses tff.aggregators.MeanFactory. Must be None if aggregation_process is not None.
use_experimental_simulation_loop | Controls the reduce loop function for the input dataset. An experimental reduce loop is used for simulation. It is currently necessary to set this flag to True for performant GPU simulations.
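As a sketch of how the keyword arguments compose (the values below are illustrative, not recommendations): uniform client weighting combined with an explicitly supplied mean aggregation factory, which is the same factory used by default.

```python
process = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    client_weighting=tff.learning.ClientWeighting.UNIFORM,
    # MeanFactory is the default; shown here only to illustrate the argument.
    model_update_aggregation_factory=tff.aggregators.MeanFactory(),
    # Set to True for performant GPU simulations, as noted above.
    use_experimental_simulation_loop=False,
)
```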
Returns | 
---|---
A tff.templates.IterativeProcess. | 