tff.learning.templates.ClientWorkProcess

A stateful process capturing work at clients during learning.

Inherits From: MeasuredProcess, IterativeProcess

Client work encapsulates the main work performed by clients as part of a federated learning algorithm, such as running several steps of gradient descent on the client data and returning an update to the initial model weights.

A ClientWorkProcess is a tff.templates.MeasuredProcess that formalizes the type signature of initialize and next for the core work performed by clients in a learning process.
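The shape of that contract can be illustrated without TFF at all. The following is a plain-Python sketch, not the library's implementation: `Output`, `ClientResult`, `local_sgd`, and the field names `update` and `update_weight` are hypothetical stand-ins, and the three-argument form of `next_fn` (state, per-client weights, per-client data) is inferred from the error descriptions on this page. The sign convention of the update (final minus initial) is also an assumption.

```python
from typing import Dict, List, NamedTuple, Tuple

Example = Tuple[float, float]  # (x, y) pair for the toy model y ~ w * x


class ClientResult(NamedTuple):
    """Hypothetical per-client output: a weight update and its aggregation weight."""
    update: float
    update_weight: float


class Output(NamedTuple):
    """Hypothetical stand-in for a measured-process output triple."""
    state: Dict[str, float]
    result: List[ClientResult]
    measurements: Dict[str, float]


def initialize() -> Dict[str, float]:
    # No-arg function returning the initial state (here, only a learning rate).
    return {"learning_rate": 0.1}


def local_sgd(weight: float, data: List[Example], lr: float) -> ClientResult:
    # Several steps of gradient descent on squared error, one step per example.
    w = weight
    for x, y in data:
        w -= lr * 2.0 * (w * x - y) * x
    # Assumed convention: update = final weight minus initial weight.
    return ClientResult(update=w - weight, update_weight=float(len(data)))


def next_fn(
    state: Dict[str, float],
    client_weights: List[float],
    client_data: List[List[Example]],
) -> Output:
    # One entry per client in both client_weights and client_data.
    results = [
        local_sgd(w, d, state["learning_rate"])
        for w, d in zip(client_weights, client_data)
    ]
    measurements = {"num_examples": float(sum(len(d) for d in client_data))}
    return Output(state=state, result=results, measurements=measurements)
```

In this sketch the state passes through `next_fn` unchanged; a real process may update it each round.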

initialize_fn A tff.Computation matching the criteria above.
next_fn A tff.Computation matching the criteria above.
get_hparams_fn An optional tff.Computation matching the criteria above. If not provided, this defaults to a computation that returns an empty ordered dictionary, regardless of the contents of the state.
set_hparams_fn An optional tff.Computation matching the criteria above. If not provided, this defaults to a pass-through computation that returns the input state regardless of the hparams passed in.
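The two documented defaults behave like the plain-Python functions below. This is only a sketch of the described behavior: the names `default_get_hparams` and `default_set_hparams` are hypothetical, and the real defaults are TFF computations rather than ordinary Python functions.

```python
from collections import OrderedDict
from typing import Any


def default_get_hparams(state: Any) -> "OrderedDict[str, Any]":
    # Ignores the state entirely and reports no hyperparameters.
    return OrderedDict()


def default_set_hparams(state: Any, hparams: "OrderedDict[str, Any]") -> Any:
    # Pass-through: the supplied hparams are ignored and the state is returned as-is.
    return state
```

A process that wants tunable hyperparameters would supply its own pair of functions that read values from, and write values into, the state.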

TemplateNotFederatedError If any of the federated computations provided do not return a federated type.
TemplateNextFnNumArgsError If the next_fn has an incorrect number of arguments.
TemplatePlacementError If any of the federated computations have an incorrect placement.
ClientDataTypeError If the third input of next_fn is not a sequence type placed at CLIENTS.
ClientResultTypeError If the second output of next_fn does not meet the criteria outlined above.
GetHparamsTypeError If the type signature of get_hparams_fn does not meet the criteria above.
SetHparamsTypeError If the type signature of set_hparams_fn does not meet the criteria above.

get_hparams

initialize A no-arg tff.Computation that returns the initial state.
next A tff.Computation that produces the next state.

Its first argument should always be the current state (originally produced by tff.templates.IterativeProcess.initialize), and the first (or only) returned value is the updated state.
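The state-threading pattern described here can be sketched as a simple loop. The toy process below is a plain-Python illustration, not TFF code: `toy_initialize`, `toy_next`, and `run_rounds` are hypothetical names, and the state is just a round counter.

```python
from typing import Dict, Tuple


def toy_initialize() -> int:
    # Initial state: no rounds completed yet.
    return 0


def toy_next(state: int, round_input: float) -> Tuple[int, Dict[str, float]]:
    # First argument is the current state; first returned value is the updated state.
    new_state = state + 1
    return new_state, {"last_input": round_input}


def run_rounds(num_rounds: int) -> int:
    state = toy_initialize()
    for round_num in range(num_rounds):
        # Feed the state from the previous call back into the next call.
        state, _metrics = toy_next(state, float(round_num))
    return state
```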

set_hparams

state_type The tff.Type of the state of the process.