Builds a learning process for FedAvg with client optimizer scheduling.
tff.learning.algorithms.build_weighted_fed_avg_with_optimizer_schedule(
    model_fn: Union[Callable[[], tff.learning.models.VariableModel], tff.learning.models.FunctionalModel],
    client_learning_rate_fn: Callable[[int], float],
    client_optimizer_fn: Callable[[float], tff.learning.optimizers.Optimizer],
    server_optimizer_fn: Optional[tff.learning.optimizers.Optimizer] = None,
    model_distributor: Optional[tff.learning.templates.DistributionProcess] = None,
    model_aggregator: Optional[tff.aggregators.WeightedAggregationFactory] = None,
    metrics_aggregator: Optional[tff.learning.metrics.MetricsAggregatorType] = None,
    loop_implementation: tff.learning.LoopImplementation = tff.learning.LoopImplementation.DATASET_REDUCE
) -> tff.learning.templates.LearningProcess
The primary purpose of this implementation of FedAvg is that it allows the
client optimizer to be scheduled across rounds. Notably, the local
optimization step uses a constant learning rate within a round, which is
scheduled across rounds of federated training. The process keeps track of how
many iterations of `.next` have occurred (starting at 0), and for each such
`round_num`, the clients will use
`client_optimizer_fn(client_learning_rate_fn(round_num))` to perform local
optimization. This allows learning rate scheduling (eg. starting with a large
learning rate and decaying it over time) as well as optimizer scheduling
(eg. switching optimizers as learning progresses).
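For instance, the scheduling hook can be a plain Python function of the round number; the step-decay constants below are illustrative, not defaults:

```python
# Illustrative schedule: the learning rate halves every 10 rounds.
# Round numbers start at 0 and increment each time `.next` is called.
def client_learning_rate_fn(round_num: int) -> float:
    return 0.1 * (0.5 ** (round_num // 10))

# The process evaluates client_optimizer_fn(client_learning_rate_fn(round_num))
# once per round; within a round the learning rate stays constant.
print(client_learning_rate_fn(0))   # 0.1
print(client_learning_rate_fn(10))  # 0.05
```

Note that this function must be serializable by TFF, so it should be a pure function of the round number.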
This function creates a `LearningProcess` that performs federated averaging on
client models. The iterative process has the following methods inherited from
`LearningProcess`:

- `initialize`: A `tff.Computation` with the functional type signature
  `( -> S@SERVER)`, where `S` is a `LearningAlgorithmState` representing the
  initial state of the server.
- `next`: A `tff.Computation` with the functional type signature
  `(<S@SERVER, {B*}@CLIENTS> -> <L@SERVER>)` where `S` is a
  `tff.learning.templates.LearningAlgorithmState` whose type matches the
  output of `initialize` and `{B*}@CLIENTS` represents the client datasets.
  The output `L` contains the updated server state, as well as aggregated
  metrics at the server, including client training metrics and any other
  metrics from distribution and aggregation processes.
- `get_model_weights`: A `tff.Computation` with type signature `(S -> M)`,
  where `S` is a `tff.learning.templates.LearningAlgorithmState` whose type
  matches the output of `initialize` and `next`, and `M` represents the type
  of the model weights used during training.
- `set_model_weights`: A `tff.Computation` with type signature
  `(<S, M> -> S)`, where `S` is a
  `tff.learning.templates.LearningAlgorithmState` whose type matches the
  output of `initialize` and `M` represents the type of the model weights
  used during training.
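As a rough sketch of how these four methods fit together, here is a pure-Python stand-in for the process protocol. None of the class internals below are TFF code; a real process is produced by the builder above, and its state and datasets are federated values rather than plain Python objects:

```python
# Toy stand-in for the LearningProcess protocol (NOT the TFF implementation).
# State S is a dict; client datasets {B*} are lists of scalar batches.
class ToyLearningProcess:
    def initialize(self):
        # ( -> S): initial server state.
        return {"weights": 0.0, "round_num": 0}

    def next(self, state, client_data):
        # (<S, {B*}> -> L): each client computes a delta from the broadcast
        # model; the server averages deltas and applies the aggregate.
        deltas = [sum(b) / len(b) - state["weights"] for b in client_data]
        new_weights = state["weights"] + sum(deltas) / len(deltas)
        new_state = {"weights": new_weights,
                     "round_num": state["round_num"] + 1}
        metrics = {"num_clients": len(client_data)}
        return new_state, metrics

    def get_model_weights(self, state):
        # (S -> M)
        return state["weights"]

    def set_model_weights(self, state, weights):
        # (<S, M> -> S)
        return {**state, "weights": weights}

process = ToyLearningProcess()
state = process.initialize()
for _ in range(3):
    state, metrics = process.next(state, client_data=[[1.0, 3.0], [5.0]])
print(process.get_model_weights(state))  # converges to 3.5
```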
Each time the `next` method is called, the server model is broadcast to each
client using a broadcast function. For each client, local training is
performed using the optimizer produced by `client_optimizer_fn`. Each client
computes the difference between the client model after training and the
initial broadcast model. These model deltas are then aggregated at the server
using a weighted aggregation function, with clients weighted by the number of
examples they see throughout local training. The aggregate model delta is
applied at the server using a server optimizer.
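The example-weighted aggregation step can be sketched in plain Python. This is a simplified scalar version (real model deltas are structures of tensors, and the aggregation is performed by the `model_aggregator`):

```python
# Simplified scalar sketch of the server aggregation step: client model
# deltas are combined as a weighted mean, each client weighted by its
# number of training examples.
def weighted_mean_delta(deltas, num_examples):
    total = sum(num_examples)
    return sum(d * n for d, n in zip(deltas, num_examples)) / total

# The server then applies the aggregate delta via its optimizer; the
# default (SGD with learning rate 1.0) reduces to adding the delta.
server_weights = 1.0
aggregate = weighted_mean_delta(deltas=[0.2, -0.1], num_examples=[30, 10])
server_weights += 1.0 * aggregate
```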
Args

- `model_fn`: Either a no-arg function that returns a
  `tff.learning.models.VariableModel`, or an instance of
  `tff.learning.models.FunctionalModel`. The no-arg function must not capture
  TensorFlow tensors or variables and use them. The model must be constructed
  entirely from scratch on each invocation; returning the same pre-constructed
  model instance on each call will result in an error.
- `client_learning_rate_fn`: A callable accepting an integer round number and
  returning a float to be used as a learning rate for the optimizer. The
  client work will call `client_optimizer_fn(client_learning_rate_fn(round_num))`
  where `round_num` is the integer round number. Note that the round numbers
  supplied will start at 0 and increment by one each time `.next` is called on
  the resulting process. Also note that this function must be serializable by
  TFF.
- `client_optimizer_fn`: A callable accepting a float learning rate and
  returning a `tff.learning.optimizers.Optimizer`.
- `server_optimizer_fn`: A `tff.learning.optimizers.Optimizer`. By default,
  this uses `tff.learning.optimizers.build_sgdm` with a learning rate of 1.0.
- `model_distributor`: An optional `DistributionProcess` that distributes the
  model weights on the server to the clients. If set to None, the distributor
  is constructed via `distributors.build_broadcast_process`.
- `model_aggregator`: An optional `tff.aggregators.WeightedAggregationFactory`
  used to aggregate client updates on the server. If None, this is set to
  `tff.aggregators.MeanFactory`.
- `metrics_aggregator`: A function that takes in the metric finalizers (i.e.,
  `tff.learning.models.VariableModel.metric_finalizers()`) and a
  `tff.types.StructWithPythonType` of the unfinalized metrics (i.e., the TFF
  type of `tff.learning.models.VariableModel.report_local_unfinalized_metrics()`),
  and returns a `tff.Computation` for aggregating the unfinalized metrics. If
  None, this is set to `tff.learning.metrics.sum_then_finalize`.
- `loop_implementation`: Changes the implementation of the training loop
  generated. See `tff.learning.LoopImplementation` for more details.

Returns

A `LearningProcess`.