

Runs federated training for a given tff.templates.IterativeProcess.

We assume that the iterative process has the following functional type signatures:

  • initialize: ( -> state).
  • next: <state, client_data> -> <state, metrics>, where state matches the output type of initialize, and metrics is a Python mapping with string-valued keys.
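As an illustration only, the pure-Python stand-in below mimics these signatures without depending on TFF: initialize takes no arguments and returns a state, and next maps (state, client_data) to (state, metrics), where metrics is a dict with string keys. The class name ToyMeanProcess and its averaging logic are hypothetical; a real tff.templates.IterativeProcess wraps TFF computations instead.

```python
from typing import Dict, List, Tuple


class ToyMeanProcess:
    """Hypothetical stand-in for a tff.templates.IterativeProcess.

    The state is a single float; each round moves it halfway toward the
    mean of all client values.
    """

    def initialize(self) -> float:
        # initialize: ( -> state)
        return 0.0

    def next(self, state: float,
             client_data: List[List[float]]) -> Tuple[float, Dict[str, float]]:
        # next: <state, client_data> -> <state, metrics>
        values = [x for client in client_data for x in client]
        target = sum(values) / len(values)
        new_state = state + 0.5 * (target - state)
        # metrics is a mapping with string-valued keys.
        metrics = {'target': target, 'distance': abs(target - new_state)}
        return new_state, metrics


process = ToyMeanProcess()
state = process.initialize()
state, metrics = process.next(state, [[1.0, 3.0], [2.0]])
print(state, metrics)  # 1.0 {'target': 2.0, 'distance': 1.0}
```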

This method performs up to total_rounds updates to the state of process. At each round_num, the update applies process.next to state and to the output of client_selection_fn(round_num). We refer to this as a single "training step".
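Stripped of timing and callbacks, the training step described above reduces to a loop like the following. This is a simplified sketch, not TFF's implementation; it assumes rounds are numbered from 1, and the process built from types.SimpleNamespace is hypothetical.

```python
from types import SimpleNamespace


def run_rounds(process, client_selection_fn, total_rounds):
    """Apply process.next once per round, for rounds 1..total_rounds."""
    state = process.initialize()
    for round_num in range(1, total_rounds + 1):
        # One "training step": select client data, then update the state.
        client_data = client_selection_fn(round_num)
        state, _ = process.next(state, client_data)
    return state


# Hypothetical process: the state counts how many client datasets were seen.
counting_process = SimpleNamespace(
    initialize=lambda: 0,
    next=lambda state, data: (state + len(data), {'num_clients': len(data)}),
)
# Hypothetical selection function: round r uses r clients.
final_state = run_rounds(counting_process, lambda r: list(range(r)), total_rounds=3)
print(final_state)  # 6
```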

This method also records how long it takes (in seconds) to call client_selection_fn and process.next at each round, and adds this to the round metrics under the key tff.simulation.TRAIN_STEP_TIME_KEY. We also record how many training steps would occur per hour at that rate, under the key tff.simulation.TRAIN_STEPS_PER_HOUR_KEY.
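A minimal sketch of how these timing metrics could be computed for one step. The key strings below are hypothetical placeholders standing in for tff.simulation.TRAIN_STEP_TIME_KEY and tff.simulation.TRAIN_STEPS_PER_HOUR_KEY, whose actual values are not stated here; steps per hour is derived as 3600 divided by the step time. SlowProcess is likewise a hypothetical toy.

```python
import time

# Hypothetical placeholder values; the real keys are
# tff.simulation.TRAIN_STEP_TIME_KEY and tff.simulation.TRAIN_STEPS_PER_HOUR_KEY.
TRAIN_STEP_TIME_KEY = 'train_step_time_in_seconds'
TRAIN_STEPS_PER_HOUR_KEY = 'train_steps_per_hour'


def timed_training_step(process, state, client_selection_fn, round_num):
    """Run one training step and add timing entries to its metrics."""
    start = time.perf_counter()
    client_data = client_selection_fn(round_num)
    state, metrics = process.next(state, client_data)
    step_time = time.perf_counter() - start
    round_metrics = dict(metrics)
    round_metrics[TRAIN_STEP_TIME_KEY] = step_time
    # Guard against a zero duration on very fast (toy) steps.
    round_metrics[TRAIN_STEPS_PER_HOUR_KEY] = (
        3600.0 / step_time if step_time > 0 else float('inf'))
    return state, round_metrics


class SlowProcess:
    """Hypothetical process whose next() takes about 10 ms."""

    def next(self, state, client_data):
        time.sleep(0.01)
        return state + 1, {'loss': 0.5}


state, metrics = timed_training_step(SlowProcess(), 0, lambda r: [], round_num=1)
print(state, sorted(metrics))
```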

This method uses up to two callbacks. The first, on_loop_start, accepts the initial state of process and returns a starting state and round_num for the training loop. This callback can be used, for example, to load checkpoints.
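One way such a checkpoint-restoring callback might look, as a sketch under assumptions: the checkpoints dict (completed round number mapped to saved state) is a hypothetical in-memory store, make_on_loop_start is an illustrative helper, and the returned round_num is assumed to be the first round still to run, with rounds numbered from 1.

```python
def make_on_loop_start(checkpoints):
    """Build an on_loop_start callback that resumes from the latest checkpoint.

    `checkpoints` is a hypothetical dict mapping completed round numbers to
    saved states.
    """
    def on_loop_start(initial_state):
        if not checkpoints:
            # Nothing saved yet: start from the initial state at round 1.
            return initial_state, 1
        last_round = max(checkpoints)
        # Resume from the round after the last completed one.
        return checkpoints[last_round], last_round + 1
    return on_loop_start


fresh = make_on_loop_start({})
resumed = make_on_loop_start({1: 10.0, 2: 20.0})
print(fresh(0.0))    # (0.0, 1)
print(resumed(0.0))  # (20.0, 3)
```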

The second callback, on_round_end, is called after each training step. It accepts the output state and metrics of process.next, together with the current round number, and returns a new state and metrics mapping. This can be used for computing and saving additional metrics.
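A sketch of an on_round_end callback that both saves and extends the metrics, using the argument order (state, round number, metrics) given in the argument list. The history list, the make_on_round_end helper, and the 'loss' and 'best_loss' keys are all hypothetical; the state is passed through unchanged.

```python
def make_on_round_end(history):
    """Build an on_round_end callback that logs metrics and adds a running minimum.

    `history` is a hypothetical list used as an in-memory metrics log.
    """
    def on_round_end(state, round_num, metrics):
        updated = dict(metrics)
        previous_best = history[-1]['best_loss'] if history else float('inf')
        updated['best_loss'] = min(previous_best, metrics['loss'])
        history.append({'round_num': round_num, **updated})
        # The state is returned unchanged; only the metrics are extended.
        return state, updated
    return on_round_end


history = []
on_round_end = make_on_round_end(history)
on_round_end('s', 1, {'loss': 0.9})
state, metrics = on_round_end('s', 2, {'loss': 1.2})
print(metrics)  # {'loss': 1.2, 'best_loss': 0.9}
```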

Args:

  • process: A tff.templates.IterativeProcess instance to run. Must meet the type signature requirements documented above.
  • client_selection_fn: Callable accepting an integer round number and returning a list of client data to use as federated data for that round.
  • total_rounds: The number of federated training rounds to perform.
  • on_loop_start: An optional callable accepting the initial state of the iterative process and returning a (potentially updated) state and an integer round_num used to determine where to resume the simulation loop.
  • on_round_end: An optional callable accepting the state of the iterative process, an integer round number, and a mapping of metrics. The callable returns a (potentially updated) state of the same type and a (potentially updated) mapping of metrics.

Returns: The state of the iterative process after training.
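Putting the pieces together, the following pure-Python sketch approximates the documented behavior end to end: optional on_loop_start, per-round timing of client_selection_fn plus process.next, optional on_round_end, and the final state as the return value. It is not TFF's implementation. The function name run_training_process_sketch, the placeholder metric keys, the default start round of 1, and the toy process are all assumptions.

```python
import time
from types import SimpleNamespace

# Hypothetical placeholders for tff.simulation.TRAIN_STEP_TIME_KEY and
# tff.simulation.TRAIN_STEPS_PER_HOUR_KEY.
TRAIN_STEP_TIME_KEY = 'train_step_time_in_seconds'
TRAIN_STEPS_PER_HOUR_KEY = 'train_steps_per_hour'


def run_training_process_sketch(process, client_selection_fn, total_rounds,
                                on_loop_start=None, on_round_end=None):
    """Pure-Python approximation of the documented loop; returns the final state."""
    state = process.initialize()
    start_round = 1  # Assumed default when no on_loop_start is given.
    if on_loop_start is not None:
        state, start_round = on_loop_start(state)
    for round_num in range(start_round, total_rounds + 1):
        # Time client selection and the process.next call together.
        start = time.perf_counter()
        client_data = client_selection_fn(round_num)
        state, metrics = process.next(state, client_data)
        step_time = time.perf_counter() - start
        round_metrics = dict(metrics)
        round_metrics[TRAIN_STEP_TIME_KEY] = step_time
        round_metrics[TRAIN_STEPS_PER_HOUR_KEY] = (
            3600.0 / step_time if step_time > 0 else float('inf'))
        if on_round_end is not None:
            state, round_metrics = on_round_end(state, round_num, round_metrics)
    return state


# Hypothetical process: the state is a running sum of client values.
summing_process = SimpleNamespace(
    initialize=lambda: 0,
    next=lambda state, data: (state + sum(data), {'round_sum': sum(data)}),
)
logged_rounds = []


def log_round(state, round_num, metrics):
    logged_rounds.append(round_num)
    return state, metrics


# Resume as if rounds 1 and 2 were already done, leaving state 3 (= 1 + 2).
final_state = run_training_process_sketch(
    summing_process,
    client_selection_fn=lambda r: [r],
    total_rounds=4,
    on_loop_start=lambda initial_state: (3, 3),
    on_round_end=log_round,
)
print(final_state, logged_rounds)  # 10 [3, 4]
```

Without any callbacks, the same call would run rounds 1 through 4 from state 0 and also return 10.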