Runs federated training for a given tff.templates.IterativeProcess.

(
    process: tff.templates.IterativeProcess,
    client_selection_fn: Callable[[int], Any],
    total_rounds: int,
    on_loop_start: Optional[Callable[[Any], Tuple[Any, int]]] = None,
    on_round_end: Optional[Callable[[Any, int, MetricsType], Tuple[Any, MetricsType]]] = None
)
We assume that the iterative process has the following functional type signatures:

initialize: ( -> state).
next: <state, client_data> -> <state, metrics>, where state matches the output type of initialize, and metrics has a member that is a Python mapping with string-valued keys.
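To make the two signatures concrete, here is a minimal plain-Python stand-in. ToyProcess is a hypothetical class, not part of TFF. It only mimics the shape of initialize and next, whereas a real tff.templates.IterativeProcess is built from TFF computations.

```python
from typing import Any, Dict, List, Tuple


class ToyProcess:
    """Hypothetical stand-in that mimics the initialize/next shape only."""

    def initialize(self) -> float:
        # ( -> state): the state here is a single running total.
        return 0.0

    def next(
        self, state: float, client_data: List[float]
    ) -> Tuple[float, Dict[str, Any]]:
        # <state, client_data> -> <state, metrics>
        new_state = state + sum(client_data)
        # Metrics are a mapping with string-valued keys.
        metrics = {"num_clients": len(client_data), "total": new_state}
        return new_state, metrics


process = ToyProcess()
state = process.initialize()
state, metrics = process.next(state, [1.0, 2.0])
```

The state returned by next has the same type as the output of initialize, so it can be fed straight back into the following round.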
This method performs up to total_rounds updates to the state of process. At each round_num, this update occurs by applying process.next to the current state and the output of client_selection_fn(round_num). We refer to this as a single "training step".
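The loop described here can be sketched in plain Python. This is an illustration under stated assumptions, not the library's implementation: run_rounds is a hypothetical helper, the process is passed as two plain callables, and rounds are assumed to be numbered from 1.

```python
from typing import Any, Callable, Dict, List, Tuple


def run_rounds(
    initialize: Callable[[], Any],
    next_fn: Callable[[Any, Any], Tuple[Any, Dict[str, Any]]],
    client_selection_fn: Callable[[int], Any],
    total_rounds: int,
) -> Tuple[Any, List[Dict[str, Any]]]:
    """Hypothetical sketch of the loop; assumes rounds are numbered from 1."""
    state = initialize()
    history = []
    for round_num in range(1, total_rounds + 1):
        client_data = client_selection_fn(round_num)
        # One "training step": apply next to the state and the selected data.
        state, metrics = next_fn(state, client_data)
        history.append(dict(metrics, round_num=round_num))
    return state, history


# Toy usage: the state is a running sum of the selected client values.
final_state, history = run_rounds(
    initialize=lambda: 0,
    next_fn=lambda state, data: (state + sum(data), {"added": sum(data)}),
    client_selection_fn=lambda round_num: [round_num, round_num],
    total_rounds=3,
)
```

Each iteration threads the state from one call of next into the following one, which is what makes the sequence of steps a training run rather than independent calls.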
This method also records how long it takes (in seconds) to call process.next at each round and adds this to the round metrics with key tff.simulation.TRAIN_STEP_TIME_KEY. We also record how many training steps would occur per hour at that rate, which is stored under its own metrics key.
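The relationship between the two timing metrics is simple arithmetic: a step that takes t seconds corresponds to 3600 / t steps per hour. The sketch below shows one way to compute both. The key strings and the helper names timed_step and steps_per_hour are placeholders chosen for illustration, not the values of the TFF constants.

```python
import time
from typing import Any, Callable, Dict, Optional, Tuple

# Placeholder key names for illustration; not the TFF constants' values.
STEP_TIME_KEY = "step_time_seconds"
STEPS_PER_HOUR_KEY = "steps_per_hour"


def steps_per_hour(seconds_per_step: float) -> Optional[float]:
    # 3600 seconds per hour divided by the seconds one step took.
    if seconds_per_step <= 0.0:
        return None  # The rate is undefined for a zero-length measurement.
    return 3600.0 / seconds_per_step


def timed_step(
    next_fn: Callable[[Any, Any], Tuple[Any, Dict[str, Any]]],
    state: Any,
    client_data: Any,
) -> Tuple[Any, Dict[str, Any]]:
    # Time only the call to next_fn, then attach both timing metrics.
    start = time.perf_counter()
    state, metrics = next_fn(state, client_data)
    elapsed = time.perf_counter() - start
    round_metrics = dict(metrics)
    round_metrics[STEP_TIME_KEY] = elapsed
    round_metrics[STEPS_PER_HOUR_KEY] = steps_per_hour(elapsed)
    return state, round_metrics
```

For example, a step that takes 2 seconds corresponds to 1800 steps per hour.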
This method uses up to two callbacks. The first, on_loop_start, accepts the initial state of process, and returns a starting state and round number for the training loop. The callback can be used for things such as loading checkpoints.
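As a rough illustration of an on_loop_start callback, the sketch below resumes from the most recent saved state when one exists. The saved_states dictionary is a hypothetical in-memory stand-in for real checkpoint storage, and starting a fresh run at round 1 is an assumption.

```python
from typing import Any, Dict, Tuple

# Hypothetical in-memory checkpoint store: round number -> saved state.
saved_states: Dict[int, Any] = {4: {"weights": [0.1, 0.2]}}


def on_loop_start(initial_state: Any) -> Tuple[Any, int]:
    """Resume after the latest saved round, or start fresh at round 1."""
    if not saved_states:
        return initial_state, 1  # Assumes rounds are numbered from 1.
    last_round = max(saved_states)
    # Continue with the round after the one that was saved.
    return saved_states[last_round], last_round + 1


state, start_round = on_loop_start({"weights": [0.0, 0.0]})
```

The return value matches the Tuple[Any, int] shape in the signature: a state to start from and the round number at which to start.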
The second callback, on_round_end, is called after each training step. It accepts the output state and metrics of process.next, and the current round number, and returns a new state and metrics mapping. This can be used for computing and saving additional metrics.
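A minimal sketch of an on_round_end callback follows. It assumes, purely for illustration, that the state is a flat list of floats, and it appends each round's metrics to metrics_log, a hypothetical in-memory sink standing in for a real metrics writer.

```python
from typing import Any, Dict, List, Mapping, Tuple

metrics_log: List[Dict[str, Any]] = []  # Hypothetical in-memory sink.


def on_round_end(
    state: List[float], round_num: int, metrics: Mapping[str, Any]
) -> Tuple[List[float], Dict[str, Any]]:
    # Copy the metrics so the caller's mapping is not mutated.
    updated = dict(metrics)
    updated["round_num"] = round_num
    # Extra metric computed from the state: its Euclidean norm.
    updated["state_norm"] = sum(x * x for x in state) ** 0.5
    metrics_log.append(updated)
    # The state is returned unchanged; only the metrics are extended.
    return state, updated


state, metrics = on_round_end([3.0, 4.0], 7, {"loss": 0.25})
```

Because the callback returns the state as well as the metrics, it could also replace the state, but leaving it untouched is the simplest choice for a pure logging hook.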
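|client_selection_fn|Callable accepting an integer round number, and returning a list of client data to use as federated data for that round.|
|total_rounds|The number of federated training rounds to perform.|
|on_loop_start|An optional callable accepting the initial state of process, and returning a starting state and round number for the training loop.|
|on_round_end|An optional callable accepting the output state and metrics of process.next and the current round number, and returning a new state and metrics mapping.|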
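One possible shape for a client selection callable is sketched below. The client_datasets dictionary and CLIENTS_PER_ROUND are hypothetical, and seeding the sampler with the round number is just one way to make each round's selection repeatable.

```python
import random
from typing import Dict, List

# Hypothetical client datasets keyed by client id.
client_datasets: Dict[str, List[int]] = {
    f"client_{i}": [i, i + 1] for i in range(10)
}
CLIENTS_PER_ROUND = 3


def client_selection_fn(round_num: int) -> List[List[int]]:
    # Seed by round so the same round always selects the same clients.
    rng = random.Random(round_num)
    client_ids = rng.sample(sorted(client_datasets), CLIENTS_PER_ROUND)
    # Return the selected clients' data as a list, one entry per client.
    return [client_datasets[cid] for cid in client_ids]


round_data = client_selection_fn(1)
```

Repeatable selection is convenient when a run is resumed partway through, since a resumed round then sees the same clients it would have seen originally.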