Runs a federated training simulation for a given iterative process.
(
    process: tff.templates.IterativeProcess,
    client_selection_fn: Callable[[int], Any],
    total_rounds: int,
    file_checkpoint_manager: Optional[tff.simulation.FileCheckpointManager] = None,
    metrics_managers: Optional[List[MetricsManager]] = None,
    validation_fn: Optional[ValidationFnType] = None
)
We assume that the iterative process has the following functional type signatures:

initialize: ( -> state).
next: <state, client_data> -> <state, metrics>, where state matches the output type of initialize, and metrics has a member that is a Python mapping with string-valued keys.
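As an illustration of these signatures, here is a minimal sketch in plain Python with no TFF dependency. The class `ToyIterativeProcess` and its moving-average state are hypothetical; a real `tff.templates.IterativeProcess` wraps TFF computations rather than ordinary Python methods.

```python
from typing import Any, Dict, List, Tuple


class ToyIterativeProcess:
    """Hypothetical stand-in mirroring the assumed signatures (not a TFF class).

    initialize: ( -> state)
    next: <state, client_data> -> <state, metrics>
    """

    def initialize(self) -> float:
        # The state is a single float: a moving average of per-round client means.
        return 0.0

    def next(self, state: float, client_data: List[List[float]]) -> Tuple[float, Dict[str, Any]]:
        # Average each client's values, then average across clients.
        client_means = [sum(values) / len(values) for values in client_data]
        round_mean = sum(client_means) / len(client_means)
        # The new state has the same type as the output of initialize().
        new_state = 0.5 * state + 0.5 * round_mean
        # metrics is a Python mapping with string-valued keys.
        metrics = {"num_clients": len(client_data), "round_mean": round_mean}
        return new_state, metrics


process = ToyIterativeProcess()
state = process.initialize()
state, metrics = process.next(state, [[1.0, 3.0], [2.0]])
print(state, metrics)  # 1.0 {'num_clients': 2, 'round_mean': 2.0}
```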
This method performs up to total_rounds updates to the state of process. At each round_num, this update occurs by applying process.next to state and the output of client_selection_fn(round_num). We refer to this as a single "training step".
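The sequence of training steps can be sketched as the loop below. `run_training_steps` is a hypothetical helper, not the library's implementation; it assumes rounds are numbered from 1 and that `next_fn` returns a `(state, metrics)` pair.

```python
from typing import Any, Callable, List, Mapping, Tuple


def run_training_steps(
    initialize: Callable[[], Any],
    next_fn: Callable[[Any, Any], Tuple[Any, Mapping[str, Any]]],
    client_selection_fn: Callable[[int], Any],
    total_rounds: int,
) -> Tuple[Any, List[Mapping[str, Any]]]:
    """Hypothetical sketch: apply next_fn once per round (one "training step")."""
    state = initialize()
    history = []
    # Assumption: round numbers run from 1 to total_rounds inclusive.
    for round_num in range(1, total_rounds + 1):
        client_data = client_selection_fn(round_num)
        state, metrics = next_fn(state, client_data)
        history.append(metrics)
    return state, history


# Toy usage: the state is a running sum of the selected "client data".
final_state, history = run_training_steps(
    initialize=lambda: 0,
    next_fn=lambda s, data: (s + sum(data), {"total": s + sum(data)}),
    client_selection_fn=lambda round_num: [round_num],
    total_rounds=3,
)
print(final_state, history)  # 6 [{'total': 1}, {'total': 3}, {'total': 6}]
```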
This method also records how long it takes (in seconds) to call process.next at each round and adds this to the round metrics with key tff.simulation.TRAIN_STEP_TIME_KEY. We also record how many training steps would occur per hour, under a separate metrics key.
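A minimal sketch of this timing bookkeeping, assuming a wall-clock measurement around each call using `time.perf_counter`. The helper `timed_step` and the key names `STEP_TIME_KEY` and `STEPS_PER_HOUR_KEY` are hypothetical placeholders and are not the values of the `tff.simulation` constants. Steps per hour follows from 3600 divided by the seconds per step.

```python
import time
from typing import Any, Callable, Dict, Mapping, Tuple

# Hypothetical key names, for illustration only.
STEP_TIME_KEY = "step_time_seconds"
STEPS_PER_HOUR_KEY = "steps_per_hour"


def timed_step(
    next_fn: Callable[[Any, Any], Tuple[Any, Mapping[str, Any]]],
    state: Any,
    client_data: Any,
) -> Tuple[Any, Dict[str, Any]]:
    """Runs one training step and adds timing metrics to its output."""
    start = time.perf_counter()
    state, metrics = next_fn(state, client_data)
    elapsed = time.perf_counter() - start
    # Copy so the mapping returned by next_fn is not mutated.
    metrics = dict(metrics)
    metrics[STEP_TIME_KEY] = elapsed
    # Guard against a zero reading on very coarse clocks.
    metrics[STEPS_PER_HOUR_KEY] = 3600.0 / elapsed if elapsed > 0 else float("inf")
    return state, metrics


state, metrics = timed_step(lambda s, data: (s + 1, {"loss": 0.5}), 0, None)
print(state, metrics["loss"])  # 1 0.5
```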
In full generality, after each round, we compute validation metrics via validation_fn (if not None), add these to the metrics created by process.next (prefixing the validation metric keys), save the combined metrics using the metrics_managers (if not None), and save a checkpoint via file_checkpoint_manager (if not None).
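The per-round post-processing can be sketched as follows. This is a hedged illustration: plain callables stand in for the `metrics_managers` and the `file_checkpoint_manager`, the helper `post_round` is hypothetical, and the `"validation/"` prefix is a placeholder for whatever prefix the library actually applies.

```python
from typing import Any, Callable, Dict, List, Mapping, Optional

VALIDATION_PREFIX = "validation/"  # hypothetical prefix, for illustration only


def post_round(
    state: Any,
    train_metrics: Mapping[str, Any],
    round_num: int,
    validation_fn: Optional[Callable[[Any], Mapping[str, Any]]] = None,
    metrics_sinks: Optional[List[Callable[[Dict[str, Any], int], None]]] = None,
    save_checkpoint: Optional[Callable[[Any, int], None]] = None,
) -> Dict[str, Any]:
    """Hypothetical sketch of the optional steps run after each round."""
    combined = dict(train_metrics)
    # 1. Validation metrics, if requested, are merged under a key prefix.
    if validation_fn is not None:
        for key, value in validation_fn(state).items():
            combined[VALIDATION_PREFIX + key] = value
    # 2. Each sink (stand-in for a metrics manager) receives the combined metrics.
    if metrics_sinks is not None:
        for sink in metrics_sinks:
            sink(combined, round_num)
    # 3. The state is checkpointed (stand-in for the file checkpoint manager).
    if save_checkpoint is not None:
        save_checkpoint(state, round_num)
    return combined


saved_metrics, checkpoints = [], []
combined = post_round(
    state=5,
    train_metrics={"loss": 1.0},
    round_num=3,
    validation_fn=lambda s: {"accuracy": 0.9},
    metrics_sinks=[lambda m, r: saved_metrics.append((r, m))],
    save_checkpoint=lambda s, r: checkpoints.append((r, s)),
)
print(combined)  # {'loss': 1.0, 'validation/accuracy': 0.9}
```

When every optional argument is None, the sketch simply returns a copy of the training metrics.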
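|client_selection_fn|Callable accepting an integer round number, and returning a list of client data to use as federated data for that round.|
|total_rounds|The number of federated training rounds to perform.|
|metrics_managers|An optional list of MetricsManager objects used to save the combined round metrics.|
|validation_fn|An optional callable accepting the current state of the iterative process (i.e. the first output of process.next) and returning validation metrics.|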
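One possible `client_selection_fn`, sketched under the assumption that each client's data is a small Python list keyed by a client id. `CLIENT_DATA` and `make_client_selection_fn` are hypothetical. Seeding the sampler from the round number is one design choice: it makes the clients chosen for a given round reproducible.

```python
import random
from typing import Callable, Dict, List

# Hypothetical per-client datasets keyed by client id.
CLIENT_DATA: Dict[str, List[float]] = {
    f"client_{i}": [float(i), float(i + 1)] for i in range(10)
}


def make_client_selection_fn(
    client_data: Dict[str, List[float]], clients_per_round: int, seed: int = 0
) -> Callable[[int], List[List[float]]]:
    """Returns a callable mapping a round number to a list of client datasets."""
    client_ids = sorted(client_data)

    def client_selection_fn(round_num: int) -> List[List[float]]:
        # Derive the RNG seed from the round number, so the same round
        # always samples the same clients.
        rng = random.Random(seed * 1_000_003 + round_num)
        sampled_ids = rng.sample(client_ids, clients_per_round)
        return [client_data[cid] for cid in sampled_ids]

    return client_selection_fn


client_selection_fn = make_client_selection_fn(CLIENT_DATA, clients_per_round=3)
print(len(client_selection_fn(1)))  # 3
```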