A stateful process for learning tasks that produces metrics.
Inherits From: IterativeProcess
```python
tff.learning.templates.LearningProcess(
    initialize_fn: tff.Computation,
    next_fn: tff.Computation,
    get_model_weights: tff.Computation,
    set_model_weights: tff.Computation
)
```
This class inherits the constraints documented by `tff.templates.IterativeProcess`, including the `initialize` and `next` attributes. A `LearningProcess` additionally has `get_model_weights` and `set_model_weights` attributes.
All of `initialize`, `next`, `get_model_weights`, and `set_model_weights` must be `tff.Computation`s with the following type signatures:

- `initialize`: `( -> S@SERVER)`
- `next`: `(<S@SERVER, {D*}@CLIENTS> -> <state=S@SERVER, metrics=M@SERVER>)`
- `get_model_weights`: `(S -> W)`
- `set_model_weights`: `(<S, W> -> S)`

where `{D*}@CLIENTS` represents the sequence of data at a client, with `D` denoting the type of a single member of that sequence, `M` representing the type of the metrics, and `W` representing the (unplaced) output type of the `get_model_weights` function.
Note that here, "model weights" is a loosely defined term referring to some representation of the model being learned. This is typically a nested structure of tensors and is often suitable for evaluation purposes.
For example, given a `LearningProcess` `process` and client data `data`, we could call the following to initialize, optionally load other model weights, update the state three times, and extract the model weights of the state:
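```python
state = process.initialize()
# Optional: state = process.set_model_weights(state, other_weights)
for _ in range(3):
    # `next` returns a LearningProcessOutput with `state` and `metrics` fields.
    output = process.next(state, data)
    state, metrics = output.state, output.metrics
# Extract the model weights, e.g. for evaluation on held-out data.
model_weights = process.get_model_weights(state)
```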
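The loop above can be exercised without TFF against a toy stand-in. The sketch below is plain Python, not the TFF API: the functions `initialize`, `next_fn`, `get_model_weights`, and `set_model_weights`, the `Output` named tuple (standing in for `LearningProcessOutput`), the `LEARNING_RATE` constant, and the scalar averaging update rule are all hypothetical. Placements (`@SERVER`, `@CLIENTS`) have no analogue here; client data is simply a list of per-client sequences, loosely mirroring `{D*}@CLIENTS`. The round function is called `next_fn` so it does not shadow Python's built-in `next`.

```python
from typing import Any, Dict, List, NamedTuple

State = Dict[str, Any]


class Output(NamedTuple):
    """Hypothetical stand-in for LearningProcessOutput."""
    state: State
    metrics: Dict[str, float]


LEARNING_RATE = 0.5  # Assumed server step size for the toy update.


def initialize() -> State:
    # ( -> S): takes no arguments and returns the initial server state.
    return {"weights": 0.0, "round": 0}


def next_fn(state: State, client_data: List[List[float]]) -> Output:
    # (<S, {D*}> -> <state=S, metrics=M>): one round of the toy process.
    # Each client summarizes its local sequence by its mean...
    client_means = [sum(seq) / len(seq) for seq in client_data]
    target = sum(client_means) / len(client_means)
    # ...and the server moves its weights part of the way toward the
    # average of those client means.
    new_weights = state["weights"] + LEARNING_RATE * (target - state["weights"])
    new_state = {"weights": new_weights, "round": state["round"] + 1}
    metrics = {"num_clients": float(len(client_data)), "target": target}
    return Output(state=new_state, metrics=metrics)


def get_model_weights(state: State) -> float:
    # (S -> W): expose only the weights, not the whole server state.
    return state["weights"]


def set_model_weights(state: State, weights: float) -> State:
    # (<S, W> -> S): return a new state with the weights replaced;
    # the input state is not modified.
    return {**state, "weights": weights}


# Usage, following the loop in the text.
data = [[1.0, 2.0, 3.0], [4.0, 6.0]]  # two clients' local sequences
state = initialize()
# Optional: state = set_model_weights(state, other_weights)
for _ in range(3):
    output = next_fn(state, data)
    state, metrics = output.state, output.metrics
model_weights = get_model_weights(state)
print(model_weights)  # 3.0625
```

Keeping `get_model_weights` separate from the state means an evaluation loop only ever sees the weights, while `next_fn` remains free to carry extra bookkeeping (here the round counter) in the state.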
Args | |
---|---|
`initialize_fn` | A no-arg `tff.Computation` that creates the initial state of the learning process. |
`next_fn` | A `tff.Computation` that defines an iterated function. Given that `initialize_fn` returns a type `S@SERVER`, `next_fn` must accept two arguments with types assignable from values of type `S@SERVER` and `{D*}@CLIENTS`, and must return a `LearningProcessOutput` whose `state` attribute is assignable from values of type `S@SERVER`. |
`get_model_weights` | A `tff.Computation` that accepts an input `S` whose type is assignable from the result of `initialize_fn`. This computation creates a representation of the state that can be used for downstream tasks without requiring access to the entire server state. For example, `get_model_weights` could extract model weights suitable for computing evaluation metrics on held-out data. |
`set_model_weights` | A `tff.Computation` that accepts two inputs, `S` and `W`, where the type of `S` is assignable from values of the type returned by `initialize_fn` and `W` is a representation of the model weights stored in `S`. It replaces the model weights representation within the state with the incoming value and returns a new value of type `S`. |
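The round-trip behavior implied by the `get_model_weights` and `set_model_weights` descriptions can be sketched on a plain-Python state. Everything here is an assumption for illustration, not TFF code: the state layout (a `model_weights` entry holding nested lists plus an `optimizer_state` entry), the helper `round_trips`, and the use of `copy.deepcopy` to leave the input state unchanged. It shows `get_model_weights` returning only part of the state, and `set_model_weights` replacing that part while carrying every other entry over.

```python
import copy
from typing import Any, Dict

State = Dict[str, Any]


def get_model_weights(state: State) -> Any:
    # Return only the weights (a nested structure), copied so that callers
    # cannot mutate the server state through the returned value.
    return copy.deepcopy(state["model_weights"])


def set_model_weights(state: State, weights: Any) -> State:
    # Build a new state: the optimizer state and any other entries are
    # carried over, and only the weights are replaced.
    new_state = copy.deepcopy(state)
    new_state["model_weights"] = copy.deepcopy(weights)
    return new_state


def round_trips(state: State, weights: Any) -> bool:
    # Reading back weights that were just set should yield the same value.
    return get_model_weights(set_model_weights(state, weights)) == weights


# Hypothetical server state: weights plus state the evaluator never needs.
state = {
    "model_weights": {"trainable": [[0.1, 0.2], [0.3]], "non_trainable": []},
    "optimizer_state": {"step": 7},
}
pretrained = {"trainable": [[1.0, 1.0], [0.5]], "non_trainable": []}
new_state = set_model_weights(state, pretrained)
```

Returning a fresh state instead of mutating in place matches the description above, where `set_model_weights` "returns a new value of type `S`".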
Raises | |
---|---|
`TypeError` | If `initialize_fn` or `next_fn` is not an instance of `tff.Computation`. |
`TemplateInitFnParamNotEmptyError` | If `initialize_fn` has any input arguments. |
`TemplateStateNotAssignableError` | If the state returned by either `initialize_fn` or `next_fn` is not assignable to the first input argument of `next_fn`. |
`TemplateNextFnNumArgsError` | If `next_fn` does not have exactly two input arguments. |
`LearningProcessPlacementError` | If the placements of `initialize_fn` and `next_fn` do not match the expected type placements. |
`LearningProcessOutputError` | If `next_fn` does not return a `LearningProcessOutput`. |
`LearningProcessSequenceTypeError` | If the second argument to `next_fn` is not a sequence type. |
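Three of the checks listed above (the `tff.Computation` instance check and the argument-count rules for `initialize_fn` and `next_fn`) can be mimicked on ordinary Python callables with `inspect.signature`. This is only an analogue: TFF inspects `tff.Computation` type signatures, whereas this sketch inspects Python parameter lists and substitutes a `callable()` test for the instance check. The exception classes are local stand-ins that reuse names from the table but are not imported from TFF, and `validate_process_fns` is a hypothetical helper.

```python
import inspect
from typing import Any


# Local stand-ins reusing the names from the table above; they are not
# TFF's classes and carry no TFF semantics.
class TemplateInitFnParamNotEmptyError(Exception):
    pass


class TemplateNextFnNumArgsError(Exception):
    pass


def validate_process_fns(initialize_fn: Any, next_fn: Any) -> None:
    """Hypothetical checks mirroring three of the constructor's rules."""
    # Analogue of the TypeError row: both arguments must be computations
    # (here, simply callables).
    for name, fn in (("initialize_fn", initialize_fn), ("next_fn", next_fn)):
        if not callable(fn):
            raise TypeError(f"{name} must be callable, got {type(fn).__name__}")
    # initialize_fn must take no input arguments.
    if inspect.signature(initialize_fn).parameters:
        raise TemplateInitFnParamNotEmptyError(
            "initialize_fn must take no input arguments")
    # next_fn must take exactly two: the state and the client data.
    num_args = len(inspect.signature(next_fn).parameters)
    if num_args != 2:
        raise TemplateNextFnNumArgsError(
            f"next_fn must take exactly two input arguments, got {num_args}")


# A well-formed pair passes silently.
validate_process_fns(lambda: 0, lambda state, data: state)
```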
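Attributes | |
---|---|
`get_model_weights` | A `tff.Computation` returning the model weights of a server state. This computation accepts an unplaced state of the process (originally produced by the `initialize` attribute) and returns an unplaced representation of the model weights of that state. |
`initialize` | A `tff.Computation` that initializes the process. This computation must have no input arguments, and its output must be the initial state of the iterative process, placed at `SERVER`. |
`next` | A `tff.Computation` that runs one iteration of the process. The first argument of this computation should always be the current state (originally produced by the `initialize` attribute), and the second argument is the client data, a sequence placed at `CLIENTS`. It returns a `LearningProcessOutput` containing the updated state and the metrics. |
`set_model_weights` | A `tff.Computation` that sets the model weights of a server state. This computation accepts two arguments: an unplaced state of the process (originally produced by the `initialize` attribute) and a model weights representation of the kind returned by `get_model_weights`. It returns a new state with the updated model weights. |
`state_type` | The `tff.Type` of the state of the process. |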