
tff.learning.templates.ClientWorkProcess

A stateful process capturing work at clients during learning.

Inherits From: MeasuredProcess, IterativeProcess

Client work encapsulates the main work performed by clients as part of a federated learning algorithm, such as running several steps of gradient descent on the client data and returning an update to the initial model weights.
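To make "client work" concrete, here is a minimal plain-Python sketch of one client's contribution: a few steps of SGD on a one-output linear model, followed by computing the update relative to the initial weights. It does not use TFF; the function name `client_work`, the squared-error loss, the sign convention (final minus initial), and the choice of the number of examples as the aggregation weight are all assumptions made for illustration.

```python
from typing import List, Tuple

Example = Tuple[List[float], float]  # (features, label); hypothetical data layout


def client_work(initial_weights: List[float],
                data: List[Example],
                learning_rate: float = 0.1,
                num_epochs: int = 1) -> Tuple[List[float], float]:
    """Runs local SGD on a linear model and returns (update, update_weight)."""
    weights = list(initial_weights)
    for _ in range(num_epochs):
        for features, label in data:
            prediction = sum(w * x for w, x in zip(weights, features))
            error = prediction - label
            # Gradient of 0.5 * error**2 with respect to each weight is error * x.
            weights = [w - learning_rate * error * x
                       for w, x in zip(weights, features)]
    # Assumed convention for this sketch: update = final weights - initial weights.
    update = [new - old for new, old in zip(weights, initial_weights)]
    # Assumed weighting for this sketch: number of local examples.
    update_weight = float(len(data))
    return update, update_weight
```

In a real ClientWorkProcess this logic would live inside a tff.Computation placed at CLIENTS; the sketch only shows the arithmetic a single client might perform.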

A ClientWorkProcess is a tff.templates.MeasuredProcess that formalizes the type signature of initialize_fn and next_fn for the core work performed by clients in a learning process.

The initialize_fn and next_fn must have the following type signatures:

  - initialize_fn: ( -> S@SERVER)
  - next_fn: (<S@SERVER,
               ModelWeights(TRAINABLE, NON_TRAINABLE)@CLIENTS,
               DATA@CLIENTS>
              ->
              <state=S@SERVER,
               result=ClientResult(TRAINABLE, W)@CLIENTS,
               measurements=M@SERVER>)

with W and M being arbitrary types not dependent on other types here.
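The shape of these signatures can be mimicked in plain Python, which may help when reading the type notation. The sketch below is not TFF code: the three NamedTuples are hypothetical stand-ins that only borrow the names of the corresponding TFF structures, a Python list plays the role of a CLIENTS-placed value, and the "training" step is a placeholder.

```python
from typing import Any, List, NamedTuple


class ModelWeights(NamedTuple):  # stand-in, not tff.learning.ModelWeights
    trainable: List[float]
    non_trainable: List[float]


class ClientResult(NamedTuple):  # stand-in for the ClientResult structure
    update: List[float]
    update_weight: float


class MeasuredProcessOutput(NamedTuple):  # stand-in for the TFF output type
    state: Any
    result: Any
    measurements: Any


def initialize_fn() -> int:
    """( -> S@SERVER): in this sketch, S is simply a round counter."""
    return 0


def next_fn(state: int,
            client_weights: List[ModelWeights],
            client_data: List[List[float]]) -> MeasuredProcessOutput:
    """(<S, weights per client, data per client> -> <state, result, measurements>)."""
    results = []
    for weights, data in zip(client_weights, client_data):
        if data:
            # Placeholder client work: move each trainable weight to the data mean.
            mean = sum(data) / len(data)
            update = [mean - w for w in weights.trainable]
        else:
            update = [0.0 for _ in weights.trainable]
        results.append(ClientResult(update, float(len(data))))
    # Measurements are a single server-side value, not one value per client.
    measurements = {"num_examples": sum(len(d) for d in client_data)}
    return MeasuredProcessOutput(state + 1, results, measurements)
```

A round then looks like `state = initialize_fn()` followed by `output = next_fn(state, weights_per_client, data_per_client)`, with `output.state` fed into the next call.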

ClientWorkProcess requires next_fn with a second and a third input argument, which are both values placed at CLIENTS. The second argument is the initial model weights to be used for the work performed by clients. It must be of a type matching tff.learning.ModelWeights, so that the weights can be assigned to the weights of a tff.learning.Model. The third argument must be a tff.SequenceType representing the data available at clients.

The result field of the returned tff.templates.MeasuredProcessOutput must be placed at CLIENTS, and be of type matching ClientResult, of which the update field represents the update to the trainable model weights, and update_weight represents the weight to be used for weighted aggregation of the updates.
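As an illustration of how update and update_weight are typically consumed downstream, the following plain-Python sketch computes a weighted mean of client updates. The aggregation itself is not part of ClientWorkProcess; the function name `weighted_mean_update` and the representation of each result as an `(update, update_weight)` pair are assumptions of this sketch.

```python
from typing import List, Sequence, Tuple


def weighted_mean_update(
        results: Sequence[Tuple[List[float], float]]) -> List[float]:
    """Averages client updates, weighting each by its update_weight."""
    total_weight = sum(weight for _, weight in results)
    if total_weight == 0:
        raise ValueError("At least one client needs a positive update_weight.")
    dim = len(results[0][0])
    return [
        sum(update[i] * weight for update, weight in results) / total_weight
        for i in range(dim)
    ]
```

For example, an update of `[1.0, 0.0]` with weight 1 and an update of `[4.0, 2.0]` with weight 3 combine to `[3.25, 1.5]`.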

The measurements field of the returned tff.templates.MeasuredProcessOutput must be placed at SERVER. Thus, an implementation of this process must include aggregation of any metrics computed during training.
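One way to picture the required metric aggregation is the plain-Python sketch below, which reduces per-client metrics to a single server-side value. The metric names `loss` and `num_examples`, and the choice of an example-weighted mean for the loss, are hypothetical; an actual implementation would perform the reduction with federated aggregation operators inside next_fn.

```python
from typing import Dict, List


def aggregate_metrics(client_metrics: List[Dict[str, float]]) -> Dict[str, float]:
    """Reduces per-client metrics to one server-side measurements dict."""
    total_examples = sum(m["num_examples"] for m in client_metrics)
    # Weight each client's mean loss by how many examples produced it.
    total_loss = sum(m["loss"] * m["num_examples"] for m in client_metrics)
    mean_loss = total_loss / total_examples if total_examples else 0.0
    return {"num_examples": total_examples, "loss": mean_loss}
```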

Args:

  - initialize_fn: A no-arg tff.Computation that returns the initial state of the measured process. Let the type of this state be called S.
  - next_fn: A tff.Computation that represents the iterated function. The first or only argument must match the state type S. The return value must be a MeasuredProcessOutput whose state member matches the state type S.
  - next_is_multi_arg: An optional boolean indicating that next_fn will receive more than just the state argument (if True) or only the state argument (if False). This parameter is primarily used to provide better error messages.

Raises:

  - TypeError: If initialize_fn or next_fn is not an instance of tff.Computation.
  - TemplateInitFnParamNotEmptyError: If initialize_fn has any input arguments.
  - TemplateStateNotAssignableError: If the state returned by either initialize_fn or next_fn is not assignable to the first input argument of next_fn.
  - TemplateNotMeasuredProcessOutputError: If next_fn does not return a MeasuredProcessOutput.

Attributes:

  - initialize: A no-arg tff.Computation that returns the initial state.
  - next: A tff.Computation that runs one iteration of the process. Its first argument should always be the current state (originally produced by tff.templates.MeasuredProcess.initialize), and the return type must be a tff.templates.MeasuredProcessOutput.
  - state_type: The tff.Type of the state of the process.