
tff.learning.templates.compose_learning_process

Composes specialized measured processes into a learning process.

Given the 4 specialized measured processes that make up a learning process, described below, and a computation that returns the initial model weights to be used for training, this method validates that the processes fit together and returns a LearningProcess.
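The following minimal pure-Python sketch shows what "fit together" means. It is not the TFF implementation. It assumes each process can be summarized by a single input type and a single output type, written as hypothetical strings such as "weights@SERVER", and it treats the round as a linear chain. In the real process, client_work also receives client data and model_finalizer also receives the current server weights, so TFF's actual checks are richer than this.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class StageSignature:
    """Hypothetical stand-in for the type signature of a process's `next`.

    Types are plain strings such as "weights@SERVER"; TFF itself compares
    real federated types, and its checks are more detailed than this.
    """
    name: str
    input_type: str
    output_type: str


def validate_chain(server_weights_type, stages):
    """Raises TypeError unless each stage accepts the previous stage's output.

    The last stage must also return the server weights type, so that the
    next round can start from its result.
    """
    produced = server_weights_type
    for stage in stages:
        if stage.input_type != produced:
            raise TypeError(
                f"{stage.name} expects {stage.input_type!r} "
                f"but would receive {produced!r}"
            )
        produced = stage.output_type
    if produced != server_weights_type:
        raise TypeError(
            f"chain ends with {produced!r}, expected {server_weights_type!r}"
        )


stages = [
    StageSignature("model_weights_distributor", "weights@SERVER", "weights@CLIENTS"),
    StageSignature("client_work", "weights@CLIENTS", "update@CLIENTS"),
    StageSignature("model_update_aggregator", "update@CLIENTS", "update@SERVER"),
    StageSignature("model_finalizer", "update@SERVER", "weights@SERVER"),
]
validate_chain("weights@SERVER", stages)  # compatible chain: no exception
```

Replacing any stage with one whose input type does not match the previous output raises a TypeError that names the offending stage.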

The main purposes of the 4 measured processes are:

  • model_weights_distributor: Makes the global model weights at the server available to clients as the starting point for their learning work.
  • client_work: Produces an update to the model received at clients.
  • model_update_aggregator: Aggregates the model updates from clients at the server.
  • model_finalizer: Updates the global model weights at the server using the aggregated model update.
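The four roles can be illustrated with plain Python functions standing in for each process's next computation. This is a hypothetical sketch, not TFF code: the function names only mirror the argument names, model weights are a list of floats, each client moves its copy toward the mean of its local data, and the aggregator takes an unweighted mean (real TFF aggregators are often weighted).

```python
# Hypothetical pure-Python stand-ins for the four roles; no TFF involved.
# Model weights are a list of floats; each client "dataset" is a list of
# target vectors, and the toy client update moves weights toward their mean.


def model_weights_distributor(server_weights, num_clients):
    """Gives every client its own copy of the global weights."""
    return [list(server_weights) for _ in range(num_clients)]


def client_work(client_weights, client_data, learning_rate=0.5):
    """Produces a model delta (update) from one client's local data."""
    n = len(client_data)
    target = [sum(col) / n for col in zip(*client_data)]
    return [learning_rate * (t - w) for t, w in zip(target, client_weights)]


def model_update_aggregator(client_updates):
    """Averages client updates coordinate-wise at the server."""
    n = len(client_updates)
    return [sum(col) / n for col in zip(*client_updates)]


def model_finalizer(server_weights, aggregated_update):
    """Applies the aggregated update to the global weights."""
    return [w + u for w, u in zip(server_weights, aggregated_update)]


server_weights = [0.0, 0.0]
client_datasets = [[[1.0, 2.0], [3.0, 2.0]], [[2.0, 6.0]]]

# One round, in the order the roles are listed above.
per_client = model_weights_distributor(server_weights, len(client_datasets))
updates = [client_work(w, d) for w, d in zip(per_client, client_datasets)]
aggregated = model_update_aggregator(updates)
server_weights = model_finalizer(server_weights, aggregated)
print(server_weights)  # prints [1.0, 2.0]
```

Here the two client updates are [1.0, 1.0] and [1.0, 3.0]; their mean, [1.0, 2.0], is added to the initial zero weights.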

The next computation of the created learning process is composed from the next computations of the 4 measured processes, chained in the order visualized below. The type signatures of the processes must be compatible so that this chaining is possible. Each process also reports its own metrics.

┌─────────────────────────┐
│model_weights_distributor│
└△─┬─┬────────────────────┘
 │ │┌▽──────────┐
 │ ││client_work│
 │ │└┬─────┬────┘
 │┌▽─▽────┐│
 ││metrics││
 │└△─△────┘│
 │ │┌┴─────▽────────────────┐
 │ ││model_update_aggregator│
 │ │└┬──────────────────────┘
┌┴─┴─▽──────────┐
│model_finalizer│
└┬──────────────┘
┌▽─────┐
│result│
└──────┘
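The chaining in the diagram, including per-process state and metrics, can be sketched as a small higher-order function. Everything here is hypothetical pure Python: MeasuredOutput, Process, and compose are illustrative stand-ins, not TFF classes, and the state and metrics keys simply mirror the process names. The toy model is a single float, and the finalizer's state is assumed to hold a learning rate.

```python
from typing import Any, Callable, NamedTuple


class MeasuredOutput(NamedTuple):
    """Hypothetical stand-in for a (state, result, measurements) triple."""
    state: Any
    result: Any
    measurements: dict


class Process(NamedTuple):
    """Hypothetical measured process: initialize() -> state, step(state, ...) -> output."""
    initialize: Callable[[], Any]
    step: Callable[..., MeasuredOutput]


def compose(initial_weights_fn, distributor, client_work, aggregator, finalizer):
    """Returns (initialize, next_fn) chaining the four processes in diagram order."""

    def initialize():
        # Global weights plus one state slot per sub-process.
        return {
            "global_model_weights": initial_weights_fn(),
            "distributor": distributor.initialize(),
            "client_work": client_work.initialize(),
            "aggregator": aggregator.initialize(),
            "finalizer": finalizer.initialize(),
        }

    def next_fn(state, client_data):
        weights = state["global_model_weights"]
        d = distributor.step(state["distributor"], weights, len(client_data))
        c = client_work.step(state["client_work"], d.result, client_data)
        a = aggregator.step(state["aggregator"], c.result)
        # The finalizer combines current server weights with the aggregate.
        f = finalizer.step(state["finalizer"], weights, a.result)
        new_state = {
            "global_model_weights": f.result,
            "distributor": d.state,
            "client_work": c.state,
            "aggregator": a.state,
            "finalizer": f.state,
        }
        # Each process reports its own metrics under its own key.
        metrics = {
            "distributor": d.measurements,
            "client_work": c.measurements,
            "aggregator": a.measurements,
            "finalizer": f.measurements,
        }
        return new_state, metrics

    return initialize, next_fn


def _client_step(s, weights, datasets):
    # Toy update: each client's local mean minus the weight it received.
    updates = [sum(d) / len(d) - w for w, d in zip(weights, datasets)]
    return MeasuredOutput(s, updates, {"num_examples": sum(map(len, datasets))})


distributor = Process(lambda: None, lambda s, w, n: MeasuredOutput(s, [w] * n, {}))
client_work = Process(lambda: None, _client_step)
aggregator = Process(
    lambda: 0,  # state counts aggregated rounds
    lambda s, u: MeasuredOutput(s + 1, sum(u) / len(u), {"num_clients": len(u)}),
)
finalizer = Process(
    lambda: 0.5,  # state is an assumed learning rate
    lambda lr, w, u: MeasuredOutput(lr, w + lr * u, {"update": u}),
)

initialize, next_fn = compose(lambda: 0.0, distributor, client_work, aggregator, finalizer)
state = initialize()
state, metrics = next_fn(state, [[1.0, 3.0], [6.0]])
print(state["global_model_weights"], state["aggregator"])  # prints 2.0 1
```

In this sketch each process owns its own state slot and metrics key, so one of them (for example, the aggregator) can be swapped for another with a compatible signature without changing the others.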

Args:
  • initial_model_weights_fn: A tff.Computation that returns (unplaced) initial model weights.
  • model_weights_distributor: A DistributionProcess.
  • client_work: A ClientWorkProcess.
  • model_update_aggregator: A tff.templates.AggregationProcess.
  • model_finalizer: A FinalizerProcess.

Returns:
  A LearningProcess.