tff.simulation.run_simulation

Runs a federated training simulation for a given iterative process.

We assume that the iterative process has the following functional type signatures:

  • initialize: ( -> state).
  • next: <state, client_data> -> <state, metrics>, where state matches the output type of initialize, and metrics has a member that is a Python mapping with string-valued keys.
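To make the two signatures concrete, here is a minimal pure-Python sketch of an object with the same shape. ToyProcess, initialize, and next_fn are hypothetical names used only for illustration. The sketch does not use TFF and is not a tff.templates.IterativeProcess, and for simplicity its metrics value is itself a plain mapping with string keys.

```python
import collections

# Hypothetical stand-in that only mimics the initialize/next shape described
# above; it is not a tff.templates.IterativeProcess and does not use TFF.
ToyProcess = collections.namedtuple("ToyProcess", ["initialize", "next"])


def initialize():
    # ( -> state): in this toy example the state is a running total.
    return 0.0


def next_fn(state, client_data):
    # <state, client_data> -> <state, metrics>
    new_state = state + sum(client_data)
    # Simplification: metrics is a plain mapping with string-valued keys.
    metrics = collections.OrderedDict(num_clients=len(client_data))
    return new_state, metrics


toy_process = ToyProcess(initialize=initialize, next=next_fn)

state = toy_process.initialize()
state, metrics = toy_process.next(state, [1.0, 2.0, 3.0])
```

The state returned by next has the same type as the output of initialize, so it can be fed straight back into the following round.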

This method performs up to total_rounds updates to the state of process. At each round_num, this update occurs by applying process.next to state and the output of client_selection_fn(round_num). We refer to this as a single "training step".
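The round-by-round update can be sketched in plain Python as follows. This illustrates the loop described above and is not the TFF implementation. The helper run_training_loop and the toy functions are hypothetical, and numbering rounds from 0 is an assumption.

```python
def run_training_loop(initialize, next_fn, client_selection_fn, total_rounds):
    """Applies next_fn once per round and collects the per-round metrics."""
    state = initialize()
    metrics_per_round = []
    # Assumption: rounds are numbered from 0; the source does not say.
    for round_num in range(total_rounds):
        # One "training step": select client data, then update the state.
        client_data = client_selection_fn(round_num)
        state, metrics = next_fn(state, client_data)
        metrics_per_round.append(metrics)
    return state, metrics_per_round


# Toy stand-ins for process.initialize, process.next and client_selection_fn.
def toy_initialize():
    return 0


def toy_next(state, client_data):
    return state + sum(client_data), {"num_clients": len(client_data)}


def toy_client_selection(round_num):
    return [round_num, round_num + 1]


final_state, all_metrics = run_training_loop(
    toy_initialize, toy_next, toy_client_selection, total_rounds=3)
```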

This method also records how long it takes (in seconds) to call client_selection_fn and process.next at each round, and adds this to the round metrics under the key tff.simulation.TRAIN_STEP_TIME_KEY. It also records how many training steps would occur per hour at that rate, under the key tff.simulation.TRAIN_STEPS_PER_HOUR_KEY.
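A rough sketch of how such timing metrics could be computed is shown below. The two key constants are placeholder strings standing in for tff.simulation.TRAIN_STEP_TIME_KEY and tff.simulation.TRAIN_STEPS_PER_HOUR_KEY, whose actual values are not given here. The helper timed_train_step is hypothetical, and deriving steps per hour as 3600 divided by the step time is an assumption based on the description.

```python
import time

# Placeholder values (assumptions); the real constants live in tff.simulation.
TRAIN_STEP_TIME_KEY = "train_step_time_in_seconds"
TRAIN_STEPS_PER_HOUR_KEY = "train_steps_per_hour"


def timed_train_step(next_fn, client_selection_fn, state, round_num):
    """Runs one training step and adds timing entries to its metrics."""
    start = time.perf_counter()
    client_data = client_selection_fn(round_num)
    state, metrics = next_fn(state, client_data)
    elapsed = time.perf_counter() - start

    # Copy so the mapping returned by next_fn is not mutated.
    metrics = dict(metrics)
    metrics[TRAIN_STEP_TIME_KEY] = elapsed
    # Steps per hour at this rate; guard against a zero-length measurement.
    metrics[TRAIN_STEPS_PER_HOUR_KEY] = (
        3600.0 / elapsed if elapsed > 0 else float("inf"))
    return state, metrics


# Toy usage with stand-in functions.
new_state, round_metrics = timed_train_step(
    next_fn=lambda state, data: (state + sum(data), {"loss": 0.5}),
    client_selection_fn=lambda round_num: [1, 2, 3],
    state=0,
    round_num=0,
)
```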

In full generality, after each round, we compute validation metrics via validation_fn (if not None), add these to the metrics created by process.next (prefixing with tff.simulation.VALIDATION_METRICS_KEY), save the combined metrics using the metrics_managers (if not None), and save a checkpoint via file_checkpoint_manager (if not None).
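The end-of-round bookkeeping might look roughly like the following sketch. All names here are hypothetical: VALIDATION_METRICS_KEY is a placeholder string, the "prefix/name" key format is an assumption, and the metrics sinks and checkpoint saver are plain callables standing in for tff.simulation.MetricsManager and tff.simulation.FileCheckpointManager objects, whose methods are not shown here.

```python
# Placeholder value (assumption) for tff.simulation.VALIDATION_METRICS_KEY.
VALIDATION_METRICS_KEY = "validation"


def finish_round(state, round_num, train_metrics,
                 validation_fn=None, metrics_sinks=None, save_checkpoint=None):
    """Combines metrics for one round, then saves metrics and a checkpoint."""
    round_metrics = dict(train_metrics)

    if validation_fn is not None:
        validation_metrics = validation_fn(state, round_num)
        for name, value in validation_metrics.items():
            # Assumed prefix format: "<prefix>/<metric name>".
            round_metrics[f"{VALIDATION_METRICS_KEY}/{name}"] = value

    if metrics_sinks is not None:
        for sink in metrics_sinks:
            sink(round_metrics, round_num)

    if save_checkpoint is not None:
        save_checkpoint(state, round_num)

    return round_metrics


# Toy usage: in-memory lists stand in for metrics and checkpoint storage.
saved_metrics = []
saved_checkpoints = []
combined = finish_round(
    state=1.5,
    round_num=0,
    train_metrics={"loss": 0.25},
    validation_fn=lambda state, round_num: {"accuracy": 0.5},
    metrics_sinks=[lambda m, r: saved_metrics.append((r, m))],
    save_checkpoint=lambda s, r: saved_checkpoints.append((r, s)),
)
```

Each optional argument is skipped when it is None, mirroring the "if not None" conditions in the description.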

Args

  • process: A tff.templates.IterativeProcess instance to run.
  • client_selection_fn: Callable accepting an integer round number and returning a list of client data to use as federated data for that round.
  • total_rounds: The number of federated training rounds to perform.
  • file_checkpoint_manager: An optional tff.simulation.FileCheckpointManager used to periodically save checkpoints of the iterative process state.
  • metrics_managers: An optional list of tff.simulation.MetricsManager objects used to save training metrics throughout the simulation.
  • validation_fn: An optional callable accepting the current state of the iterative process (i.e. the first output of process.next) and the current round number, and returning a mapping of validation metrics.

Returns

The state of the iterative process after training.