tff.learning.templates.build_apply_optimizer_finalizer
Builds a finalizer that applies one step of an optimizer.
tff.learning.templates.build_apply_optimizer_finalizer(
    optimizer_fn: tff.learning.optimizers.Optimizer,
    model_weights_type: tff.types.StructType,
    should_reject_update: Callable[[Any, Any], tuple[Union[bool, tf.Tensor], Optional[_MeasurementsType]]] = tff.learning.templates.reject_non_finite_update
)
The provided model_weights_type must be a non-federated tff.Type with the tff.learning.models.ModelWeights container.

The 2nd input argument of the created FinalizerProcess.next expects a value matching model_weights_type, and its 3rd argument expects a value matching model_weights_type.trainable. The optimizer is applied to the trainable model weights only, leaving non_trainable weights unmodified.

The state of the process is the state of the optimizer, and the process returns empty measurements.
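The behavior described above can be sketched in plain Python. This is only an illustration of the semantics, not TFF code: the ModelWeights dataclass, sgd_step, and finalizer_next are hypothetical stand-ins, weights are plain tuples of floats, and the optimizer is assumed to be plain SGD that treats the update as a gradient. Update rejection is left out of this sketch.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class ModelWeights:
    """Hypothetical stand-in for tff.learning.models.ModelWeights."""
    trainable: Tuple[float, ...]
    non_trainable: Tuple[float, ...]


def sgd_step(state, trainable, update):
    """Toy optimizer step (assumption: plain SGD, update used as a gradient)."""
    lr = state["learning_rate"]
    new_trainable = tuple(w - lr * u for w, u in zip(trainable, update))
    # Plain SGD keeps no per-step statistics, so the state passes through.
    return state, new_trainable


def finalizer_next(state, model_weights, update):
    """Mirrors the argument order described above: state, weights, update."""
    new_state, new_trainable = sgd_step(state, model_weights.trainable, update)
    # Only the trainable weights change; non_trainable is copied through.
    new_weights = ModelWeights(new_trainable, model_weights.non_trainable)
    # The process state is the optimizer state; measurements are empty.
    return new_state, new_weights, {}


state = {"learning_rate": 0.5}
weights = ModelWeights(trainable=(1.0, 2.0), non_trainable=(10.0,))
new_state, new_weights, measurements = finalizer_next(state, weights, (2.0, -4.0))
print(new_weights)
```

Note that the update has the structure of the trainable weights only, which is why the 3rd argument is typed as model_weights_type.trainable rather than the full weights type.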
Args
optimizer_fn: A tff.learning.optimizers.Optimizer. This optimizer is used to apply client updates to the server model.
model_weights_type: A non-federated tff.Type of the model weights to be optimized, which must have a tff.learning.models.ModelWeights container.
should_reject_update: A callable that takes the optimizer state and the model weights update, and returns a pair: a boolean (or bool tensor) indicating whether the model weights update should be rejected, and an OrderedDict of measurements. If the update is rejected, the process falls back to the previous round's optimizer state and model weights; otherwise the callable has no effect on the step. The default is reject_non_finite_update, which checks whether the model update contains any non-finite value and returns the result.
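The should_reject_update contract can be illustrated with a small pure-Python sketch. Everything here is hypothetical and not part of TFF: reject_non_finite only imitates the idea behind reject_non_finite_update, the measurement key update_non_finite is an assumed name, guarded_step is an invented helper, and the optimizer is a toy SGD with a step counter so that the state fallback is visible.

```python
import math
from collections import OrderedDict


def reject_non_finite(state, update):
    """Hypothetical reject function: flags updates containing NaN or inf."""
    has_non_finite = any(not math.isfinite(u) for u in update)
    # The measurement key below is an assumption made for this sketch.
    return has_non_finite, OrderedDict(update_non_finite=int(has_non_finite))


def guarded_step(state, trainable, update, should_reject_update=reject_non_finite):
    """Applies a toy SGD step unless the update is rejected."""
    reject, measurements = should_reject_update(state, update)
    if reject:
        # Fall back to the previous round's optimizer state and weights.
        return state, trainable, measurements
    lr = state["learning_rate"]
    new_state = {**state, "step": state["step"] + 1}
    new_trainable = tuple(w - lr * u for w, u in zip(trainable, update))
    return new_state, new_trainable, measurements


state = {"learning_rate": 0.5, "step": 0}
trainable = (1.0, 2.0)

# A finite update is applied and advances the optimizer state.
ok_state, ok_weights, ok_meas = guarded_step(state, trainable, (2.0, 2.0))

# An update containing NaN is rejected: state and weights are unchanged.
bad_state, bad_weights, bad_meas = guarded_step(state, trainable, (float("nan"), 1.0))
```

A custom callable with the same two-argument signature could implement other policies, for example rejecting updates whose norm exceeds a threshold, as long as it returns the (reject, measurements) pair.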
Returns
A FinalizerProcess that applies the optimizer.
Raises
TypeError: If model_weights_type does not have a tff.learning.models.ModelWeights Python container, or contains a tff.types.FederatedType.
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2024-09-20 UTC.