tff.learning.templates.compose_learning_process

Composes specialized measured processes into a learning process.

Given 4 specialized measured processes (described below) that together make up a learning process, and a computation that returns the initial model weights to be used for training, this function validates that the processes fit together and returns a LearningProcess. See the tutorial at https://www.tensorflow.org/federated/tutorials/composing_learning_algorithms for more details on composing learning processes.

The main purposes of the 4 measured processes are:

  • model_weights_distributor: Makes the global model weights at the server available as the starting point for the learning work done at clients.
  • client_work: Produces an update to the model received at clients.
  • model_update_aggregator: Aggregates the model updates from clients to the server.
  • model_finalizer: Updates the global model weights at the server using the aggregated model update.

The next computation of the created learning process is composed from the next computations of the 4 measured processes, chained in the order visualized below. The type signatures of the processes must make this chaining possible. Each process also reports its own metrics.

┌─────────────────────────┐
│model_weights_distributor│
└△─┬─┬────────────────────┘
 │ │┌▽──────────┐
 │ ││client_work│
 │ │└┬─────┬────┘
 │┌▽─▽────┐│
 ││metrics││
 │└△─△────┘│
 │ │┌┴─────▽────────────────┐
 │ ││model_update_aggregator│
 │ │└┬──────────────────────┘
┌┴─┴─▽──────────┐
│model_finalizer│
└┬──────────────┘
┌▽─────┐
│result│
└──────┘
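
The chaining in the diagram can be sketched in plain Python. This is an illustrative analogy only, not TFF code: the Output tuple, the four *_next functions, the scalar "weights", the toy update rule, and every dictionary key are hypothetical stand-ins for the state, result, and measurements that real measured processes carry.

```python
from typing import Any, NamedTuple


class Output(NamedTuple):
    """Hypothetical stand-in for the output of a measured process."""
    state: Any
    result: Any
    measurements: dict


def distributor_next(state, global_weights, num_clients):
    # Make the server weights available to every client.
    return Output(state, [global_weights] * num_clients,
                  {"num_broadcasts": num_clients})


def client_work_next(state, client_weights, client_data):
    # Toy client work: each update moves the weight toward the local data
    # mean and is weighted by the client's number of examples.
    results = [(sum(data) / len(data) - w, len(data))
               for w, data in zip(client_weights, client_data)]
    return Output(state, results,
                  {"num_examples": sum(len(data) for data in client_data)})


def aggregator_next(state, client_results):
    # Weighted mean of the client updates.
    total_weight = sum(weight for _, weight in client_results)
    mean_update = sum(u * weight for u, weight in client_results) / total_weight
    return Output(state, mean_update, {"total_weight": total_weight})


def finalizer_next(state, global_weights, aggregated_update):
    # Apply the aggregated update to the server weights.
    new_weights = global_weights + state["learning_rate"] * aggregated_update
    return Output(state, new_weights, {})


def composed_next(state, client_data):
    # Chain the four steps in the order shown in the diagram.
    d = distributor_next(state["distributor"],
                         state["global_model_weights"], len(client_data))
    c = client_work_next(state["client_work"], d.result, client_data)
    a = aggregator_next(state["aggregator"], c.result)
    f = finalizer_next(state["finalizer"],
                       state["global_model_weights"], a.result)
    new_state = {
        "global_model_weights": f.result,
        "distributor": d.state,
        "client_work": c.state,
        "aggregator": a.state,
        "finalizer": f.state,
    }
    # Each process reports its own metrics.
    metrics = {
        "distributor": d.measurements,
        "client_work": c.measurements,
        "aggregator": a.measurements,
        "finalizer": f.measurements,
    }
    return new_state, metrics


state = {
    "global_model_weights": 0.0,
    "distributor": (),
    "client_work": (),
    "aggregator": (),
    "finalizer": {"learning_rate": 1.0},
}
new_state, metrics = composed_next(state, [[1.0, 3.0], [5.0]])
```

Each step consumes the result of the step before it, which is why the type signatures of the real processes have to line up. In this toy version a mismatch would only surface as a runtime error, whereas the actual function checks compatibility when the processes are composed.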

The get_hparams computation of the created learning process produces a nested ordered dictionary containing the results of client_work.get_hparams and model_finalizer.get_hparams. The set_hparams computation operates similarly, delegating to client_work.set_hparams and model_finalizer.set_hparams to set the hyperparameters in their associated states.
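
The delegation described above can be sketched in plain Python. This is a rough analogy rather than the TFF implementation: the per-component helper functions, the "learning_rate" hyperparameter, the "step" field, and the "client_work" and "finalizer" dictionary keys are assumptions made for illustration.

```python
from collections import OrderedDict


def client_work_get_hparams(state):
    # Hypothetical: client work exposes only its learning rate.
    return OrderedDict(learning_rate=state["learning_rate"])


def client_work_set_hparams(state, hparams):
    # Return a new state with the given hyperparameters overwritten.
    return {**state, **hparams}


def finalizer_get_hparams(state):
    # Hypothetical: the finalizer exposes only its learning rate.
    return OrderedDict(learning_rate=state["learning_rate"])


def finalizer_set_hparams(state, hparams):
    return {**state, **hparams}


def get_hparams(state):
    # Nested ordered dictionary: one entry per delegating component.
    return OrderedDict(
        client_work=client_work_get_hparams(state["client_work"]),
        finalizer=finalizer_get_hparams(state["finalizer"]),
    )


def set_hparams(state, hparams):
    # Delegate each sub-dictionary to the component that owns that state.
    new_state = dict(state)
    new_state["client_work"] = client_work_set_hparams(
        state["client_work"], hparams["client_work"])
    new_state["finalizer"] = finalizer_set_hparams(
        state["finalizer"], hparams["finalizer"])
    return new_state


state = {
    "global_model_weights": 0.0,
    "client_work": {"learning_rate": 0.1, "step": 7},
    "finalizer": {"learning_rate": 1.0},
}
hparams = get_hparams(state)
hparams["client_work"]["learning_rate"] = 0.05
updated_state = set_hparams(state, hparams)
```

Only the states of the two delegating components change. The rest of the composed state, including the global model weights, is carried over untouched.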

Args

  • initial_model_weights_fn: A tff.Computation that returns (unplaced) initial model weights.
  • model_weights_distributor: A tff.learning.templates.DistributionProcess.
  • client_work: A tff.learning.templates.ClientWorkProcess.
  • model_update_aggregator: A tff.templates.AggregationProcess.
  • model_finalizer: A tff.learning.templates.FinalizerProcess.

Returns

A tff.learning.templates.LearningProcess.

Raises

  • ClientSequenceTypeError: If the first argument of the next method of the resulting LearningProcess is not a structure of sequences placed at tff.CLIENTS.