tff.learning.templates.LearningProcess

A stateful process for learning tasks that produces metrics.

Inherits From: IterativeProcess

This class inherits the constraints documented by tff.templates.IterativeProcess, including the initialize and next attributes. The LearningProcess also provides additional attributes, including get_model_weights and get_hparams. The former extracts structures suitable for evaluation, while the latter extracts hyperparameters from the process state. The corresponding set_model_weights and set_hparams attributes set these structures in a given state.

For example, given a LearningProcess named process and client data named data, we could call the following to initialize, optionally load other model weights, update the state three times, and extract the model weights of the state:

state = process.initialize()
# Optional: state = process.set_model_weights(state, other_weights)
for _ in range(3):
    state, metrics = process.next(state, data)
model_weights = process.get_model_weights(state)
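
To make the shape of this loop concrete without depending on TFF itself, the sketch below uses a plain-Python stand-in. ToyProcess, ToyOutput, and the halfway-averaging update rule are hypothetical and are not part of the TFF API; only the method names mirror the attributes described above.

```python
from collections import namedtuple

# Hypothetical stand-in for LearningProcessOutput: a state plus metrics.
ToyOutput = namedtuple("ToyOutput", ["state", "metrics"])


class ToyProcess:
    """Plain-Python imitation of the LearningProcess interface (not TFF)."""

    def initialize(self):
        # The toy state holds a single scalar weight and a round counter.
        return {"weights": 0.0, "round": 0}

    def next(self, state, data):
        # `data` is a list with one sequence of numbers per client.
        client_means = [sum(d) / len(d) for d in data]
        target = sum(client_means) / len(client_means)
        # Move the weight halfway toward the mean of the client means.
        new_weights = state["weights"] + 0.5 * (target - state["weights"])
        new_state = {"weights": new_weights, "round": state["round"] + 1}
        return ToyOutput(new_state, {"distance": abs(target - new_weights)})

    def get_model_weights(self, state):
        return state["weights"]

    def set_model_weights(self, state, weights):
        # Returns a new state; the input state is left untouched.
        return {**state, "weights": weights}


process = ToyProcess()
data = [[1.0, 3.0], [5.0]]  # two clients
state = process.initialize()
for _ in range(3):
    state, metrics = process.next(state, data)
model_weights = process.get_model_weights(state)
```

In this toy, the weight moves from 0.0 toward 3.5 (the mean of the two client means) and reaches 3.0625 after three rounds.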

Args

initialize_fn A no-arg tff.Computation that creates the initial state of the learning process.
next_fn A tff.Computation that defines an iterated function. Given that initialize_fn returns a type S@SERVER, next_fn must accept two arguments with types assignable from values with type S@SERVER and {D*}@CLIENTS, and must return a LearningProcessOutput whose state attribute is assignable from values with type S@SERVER.
get_model_weights A tff.Computation that accepts an input S whose type is assignable from the result of initialize_fn. This computation is used to create a representation of the state that can be used for downstream tasks without requiring access to the entire server state. For example, get_model_weights could be used to extract model weights suitable for computing evaluation metrics on held-out data.
set_model_weights A tff.Computation that accepts two inputs S and M, where the type of S is assignable from values with the type returned by initialize_fn and M is a representation of the model weights stored in S. This updates the model weights representation within the state with the incoming value and returns a new value of type S.
get_hparams_fn An optional tff.Computation accepting the state S and returning the hyperparameters H. If not provided, this defaults to a computation that returns an empty ordered dictionary, regardless of the contents of the state.
set_hparams_fn An optional tff.Computation accepting the state S and hyperparameters H (matching the output of get_hparams_fn) and returning an updated state S. If not provided, this defaults to a pass-through computation that returns the input state regardless of the hparams passed in.
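
The documented defaults for get_hparams_fn and set_hparams_fn can be pictured with two plain Python functions. This is only an illustration of the behavior described above, not TFF's implementation; default_get_hparams, default_set_hparams, and the dictionary-shaped state are hypothetical.

```python
import collections


def default_get_hparams(state):
    # Ignores the state and returns an empty ordered dictionary.
    return collections.OrderedDict()


def default_set_hparams(state, hparams):
    # Pass-through: the hparams are ignored and the input state is returned.
    return state


state = {"weights": [0.1, 0.2], "learning_rate": 0.01}
hparams = default_get_hparams(state)
new_state = default_set_hparams(
    state, collections.OrderedDict(learning_rate=0.5)
)
```

With these defaults, reading hyperparameters always yields an empty structure and writing them has no effect, so a process that wants tunable hyperparameters needs to supply both computations.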

Raises

TypeError If initialize_fn and next_fn are not instances of tff.Computation.
TemplateInitFnParamNotEmptyError If initialize_fn has any input arguments.
TemplateStateNotAssignableError If the state returned by either initialize_fn or next_fn is not assignable to the first input argument of next_fn.
TemplateNextFnNumArgsError If next_fn does not have exactly two input arguments.
LearningProcessPlacementError If the placements of initialize_fn and next_fn do not match the expected type placements.
LearningProcessOutputError If next_fn does not return a LearningProcessOutput.
LearningProcessSequenceTypeError If the second argument to next_fn is not a sequence type.

Attributes

get_hparams A tff.Computation returning the hyperparameters of a server state.

This computation accepts an unplaced state of the process (originally produced by the initialize attribute), and returns an unplaced ordered dictionary representing the hyperparameters of the state.

get_model_weights A tff.Computation returning the model weights of a server state.

This computation accepts an unplaced state of the process (originally produced by the initialize attribute), and returns an unplaced representation of the model weights of the state. Note that this representation need not take the form of a tff.learning.models.ModelWeights object, and may depend on the specific LearningProcess in question.

initialize A tff.Computation that initializes the process.

This computation must have no input arguments, and its output must be the initial state of the learning process, placed at SERVER.

next A tff.Computation that runs one iteration of the process.

The first argument of this computation should always be the current state (originally produced by the initialize attribute), and the second argument must be a tff.SequenceType placed at CLIENTS. The return type must be a LearningProcessOutput, with each field placed at SERVER.

set_hparams A tff.Computation that sets the hyperparameters of a server state.

This computation accepts two arguments: an unplaced state of the process (originally produced by the initialize attribute) and an ordered dictionary representing the hyperparameters (matching the output of get_hparams), and returns a new unplaced state with updated hyperparameters.
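
A common pattern is to read the hyperparameters, change one, and write the result back. The sketch below imitates that round trip in plain Python; the get_hparams and set_hparams functions here, the dictionary-shaped state, and the learning_rate key are hypothetical stand-ins rather than TFF computations.

```python
import collections


def get_hparams(state):
    # Expose only the tunable part of the state as an ordered dictionary.
    return collections.OrderedDict(learning_rate=state["learning_rate"])


def set_hparams(state, hparams):
    # Return a new state; the input state is not modified.
    return {**state, "learning_rate": hparams["learning_rate"]}


state = {"weights": [0.0], "learning_rate": 0.1}
hparams = get_hparams(state)
hparams["learning_rate"] = 0.05
new_state = set_hparams(state, hparams)
```

The structure passed to set_hparams has the same keys as the one returned by get_hparams, which matches the requirement stated above.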

set_model_weights A tff.Computation that sets the model weights of a server state.

This computation accepts two arguments: an unplaced state of the process (originally produced by the initialize attribute) and a new structure of tensors representing the model weights, and returns a new unplaced state with the updated model weights. Note that the model weights representation need not take the form of a tff.learning.models.ModelWeights object, and may depend on the specific LearningProcess in question.
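
One use of this pair of attributes is warm-starting a fresh state from previously trained weights. The following plain-Python sketch shows the idea under the assumption that the state is a dictionary with a weights entry; the two functions and both states are hypothetical and only mimic the documented contract.

```python
def get_model_weights(state):
    # Return a copy so callers cannot mutate the state through it.
    return list(state["weights"])


def set_model_weights(state, weights):
    # Replace only the weights; every other field of the state is kept.
    return {**state, "weights": list(weights)}


trained_state = {"weights": [0.5, -1.0], "round": 10}
fresh_state = {"weights": [0.0, 0.0], "round": 0}

weights = get_model_weights(trained_state)
warm_state = set_model_weights(fresh_state, weights)
```

Here warm_state carries the trained weights but keeps the round counter of the fresh state, and fresh_state itself is unchanged.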

state_type The tff.Type of the state of the process.