
Builds a learning process that performs Mime Lite.

This function creates a tff.learning.templates.LearningProcess that performs the Mime Lite algorithm on client models. The returned process exposes the methods inherited from tff.learning.templates.LearningProcess, including initialize, which constructs the initial server state, and next, which runs one round of training.

Each time the next method is called, the server model is communicated to each client using the provided model_distributor. For each client, local training is performed using base_optimizer, whose state is communicated by the server and kept intact during local training. The optimizer state is updated only at the server, based on the full gradients that the clients evaluate on their local datasets at the current server model weights. These client full gradients are aggregated by the unweighted full_gradient_aggregator. Each client also computes the difference between its model after training and its initial model. These model deltas are aggregated by the unweighted model_aggregator, and the aggregate model delta is then applied to the server model weights using server_optimizer.
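The round structure described above can be sketched in plain Python. This is a minimal toy under stated assumptions, not the TFF implementation: the model is a single scalar weight w with per-example loss 0.5 * (w - x)**2, the base optimizer is SGD with heavy-ball momentum written in a functional style, the server optimizer is plain SGD with server_lr=1.0 (so the averaged delta is simply added), and all function names (sgdm_next, client_update, mime_lite_round) are hypothetical.

```python
from typing import List, Tuple

# Hypothetical toy setup: scalar weight w, per-example loss 0.5 * (w - x)**2,
# so the per-example gradient is (w - x).


def grad(w: float, x: float) -> float:
    return w - x


def full_gradient(w: float, data: List[float]) -> float:
    # Gradient of the mean loss over the whole client dataset at weight w.
    return sum(grad(w, x) for x in data) / len(data)


def sgdm_next(state: float, w: float, g: float,
              lr: float, beta: float) -> Tuple[float, float]:
    # Functional SGD with heavy-ball momentum: returns (new_state, new_w)
    # without mutating anything.
    new_state = beta * state + g
    return new_state, w - lr * new_state


def client_update(w_server: float, server_state: float, data: List[float],
                  lr: float, beta: float) -> Tuple[float, float]:
    w = w_server
    for x in data:
        # Apply the optimizer with the server's state and discard the
        # returned state, so the state stays fixed during local training.
        _, w = sgdm_next(server_state, w, grad(w, x), lr, beta)
    delta = w - w_server  # model after training minus initial model
    # Full gradient is evaluated at the server weights, not the trained ones.
    return delta, full_gradient(w_server, data)


def mime_lite_round(w_server: float, server_state: float,
                    client_datasets: List[List[float]],
                    lr: float = 0.1, beta: float = 0.9,
                    server_lr: float = 1.0) -> Tuple[float, float]:
    results = [client_update(w_server, server_state, d, lr, beta)
               for d in client_datasets]
    n = len(results)
    # Unweighted means: every client counts equally.
    avg_delta = sum(r[0] for r in results) / n
    avg_full_grad = sum(r[1] for r in results) / n
    # Server-only optimizer-state update driven by the averaged full
    # gradient; the weights returned by the base optimizer are discarded.
    new_state, _ = sgdm_next(server_state, w_server, avg_full_grad, lr, beta)
    # Server optimizer (plain SGD here) applies the aggregate delta.
    new_w = w_server + server_lr * avg_delta
    return new_w, new_state
```

Note that client_update passes the server's state into every local step but throws away the state that sgdm_next returns; only mime_lite_round advances the state. Starting from w = 0.0 and state 0.0, repeatedly calling mime_lite_round(w, state, [[1.0, 2.0], [3.0], [4.0, 5.0, 6.0]]) settles near w ≈ 3.2 in this toy.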

The Mime Lite algorithm is based on the paper "Breaking the centralized barrier for cross-device federated learning" by Sai Praneeth Karimireddy, Martin Jaggi, Satyen Kale, Mehryar Mohri, Sashank Reddi, Sebastian U. Stich, and Ananda Theertha Suresh, Advances in Neural Information Processing Systems 34 (2021).

Note that Keras optimizers are not supported. This is because the Mime Lite algorithm applies the optimizer at clients without changing its state (the optimizer's tf.Variables, in the case of Keras), which is not possible with Keras optimizers without reaching into private implementation details and incurring additional computation and memory cost at clients.
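The difference between the two optimizer styles can be illustrated with a plain-Python sketch. Both classes are hypothetical stand-ins: FunctionalMomentum loosely mirrors the initialize/next(state, weights, gradients) shape of tff.learning.optimizers.Optimizer, where the state is an explicit input and output, while StatefulMomentum imitates a Keras-style optimizer whose momentum slots live inside the object and are mutated on every update. Keeping the state fixed is free in the first case and needs a snapshot and rollback in the second.

```python
import copy
from typing import List, Tuple


class FunctionalMomentum:
    """Hypothetical functional SGD with momentum: state is passed in and out."""

    def __init__(self, lr: float, beta: float):
        self.lr, self.beta = lr, beta

    def initialize(self, num_params: int) -> List[float]:
        return [0.0] * num_params

    def next(self, state: List[float], weights: List[float],
             gradients: List[float]) -> Tuple[List[float], List[float]]:
        new_state = [self.beta * s + g for s, g in zip(state, gradients)]
        new_weights = [w - self.lr * s for w, s in zip(weights, new_state)]
        return new_state, new_weights


class StatefulMomentum:
    """Hypothetical Keras-style optimizer: slots are mutated in place."""

    def __init__(self, lr: float, beta: float, num_params: int):
        self.lr, self.beta = lr, beta
        self.slots = [0.0] * num_params

    def apply(self, weights: List[float], gradients: List[float]) -> List[float]:
        self.slots = [self.beta * s + g for s, g in zip(self.slots, gradients)]
        return [w - self.lr * s for w, s in zip(weights, self.slots)]


def local_steps_functional(opt: FunctionalMomentum, server_state: List[float],
                           weights: List[float],
                           grads_seq: List[List[float]]) -> List[float]:
    for g in grads_seq:
        _, weights = opt.next(server_state, weights, g)  # new state discarded
    return weights


def local_steps_stateful(opt: StatefulMomentum, weights: List[float],
                         grads_seq: List[List[float]]) -> List[float]:
    for g in grads_seq:
        saved = copy.deepcopy(opt.slots)  # extra memory on every step
        weights = opt.apply(weights, g)
        opt.slots = saved  # roll back so the state stays fixed
    return weights
```

Both helpers produce the same weights from the same starting state, but the stateful version pays for a copy of every slot on each step, which is the extra client cost the note above refers to.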

Args:
model_fn: A no-arg function that returns a tff.learning.Model. This function must not capture TensorFlow tensors or variables and use them. The model must be constructed entirely from scratch on each invocation; returning the same pre-constructed model on each call will result in an error.
base_optimizer: A tff.learning.optimizers.Optimizer used both to create and update a global optimizer state and to perform optimization at clients given that global state, which stays fixed during client optimization.
server_optimizer: A tff.learning.optimizers.Optimizer used to apply the aggregate model update to the global model weights.
model_distributor: An optional tff.learning.templates.DistributionProcess that distributes the model weights on the server to the clients. If set to None, the distributor is constructed via distributors.build_broadcast_process.
model_aggregator: An optional tff.aggregators.UnweightedAggregationFactory used to aggregate client updates on the server. If None, this is set to tff.aggregators.UnweightedMeanFactory.
full_gradient_aggregator: An optional tff.aggregators.UnweightedAggregationFactory used to aggregate the full gradients computed on client datasets. If None, this is set to tff.aggregators.UnweightedMeanFactory.
metrics_aggregator: A function that takes in the metric finalizers (i.e., tff.learning.Model.metric_finalizers()) and a tff.types.StructWithPythonType of the unfinalized metrics (i.e., the TFF type of tff.learning.Model.report_local_unfinalized_metrics()), and returns a tff.Computation for aggregating the unfinalized metrics. If None, this is set to tff.learning.metrics.sum_then_finalize.
use_experimental_simulation_loop: Controls the reduce loop function for the input dataset. An experimental reduce loop is used for simulation. It is currently necessary to set this flag to True for performant GPU simulations.
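The default aggregators listed above are unweighted, and metrics default to being summed before they are finalized. The plain-Python sketch below illustrates both ideas under stated assumptions: client values are plain floats, an accuracy metric is represented by a hypothetical unfinalized pair [correct, total], and the helper names are hypothetical stand-ins, not the tff.aggregators.UnweightedMeanFactory or tff.learning.metrics.sum_then_finalize APIs themselves.

```python
from typing import Callable, Dict, List


def unweighted_mean(values: List[float]) -> float:
    # Every client contributes equally, regardless of its dataset size.
    return sum(values) / len(values)


def weighted_mean(values: List[float], weights: List[float]) -> float:
    # Shown only for contrast: clients count in proportion to their weight.
    return sum(v * w for v, w in zip(values, weights)) / sum(weights)


def sum_then_finalize_sketch(
    finalizers: Dict[str, Callable[[List[float]], float]],
    client_unfinalized: List[Dict[str, List[float]]],
) -> Dict[str, float]:
    # 1) Sum each unfinalized metric element-wise across clients.
    summed: Dict[str, List[float]] = {}
    for metrics in client_unfinalized:
        for name, values in metrics.items():
            if name in summed:
                summed[name] = [a + b for a, b in zip(summed[name], values)]
            else:
                summed[name] = list(values)
    # 2) Apply each metric's finalizer once, to the server-side sums.
    return {name: fn(summed[name]) for name, fn in finalizers.items()}
```

For example, with client model deltas [1.0, 4.0] from clients holding 1 and 3 examples, the unweighted mean is 2.5, whereas an example-weighted mean would be 3.25. Summing the unfinalized accuracy pairs [9.0, 10.0] and [1.0, 2.0] before finalizing gives 10/12 ≈ 0.833, not the 0.7 obtained by averaging the per-client accuracies.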

Returns: A tff.learning.templates.LearningProcess.