Reimplementation of a subset of the Optax library using Oryx's state system.
This module is an advanced example of how to write stateful code with Oryx. For a more complete, supported optimizer package with additional transformations and features, see Optax.
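For comparison, Optax threads optimizer state by hand. Each transformation is a pair of pure functions, `init(params) -> state` and `update(updates, state) -> (updates, new_state)`. Below is a minimal sketch of that pattern in plain JAX for a momentum-style trace, presumably similar to what `trace` computes. `TraceState`, `trace_init`, and `trace_update` are hypothetical names, not Oryx's or Optax's API. Reimplementing these transformations with Oryx's state system is, roughly, a way to avoid passing this state around explicitly.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class TraceState(NamedTuple):
    """Hypothetical explicit state: a running trace (momentum buffer)."""
    trace: dict


def trace_init(params):
    # Start the trace at zero, with the same tree structure as params.
    return TraceState(trace=jax.tree_util.tree_map(jnp.zeros_like, params))


def trace_update(updates, state, decay=0.9):
    # new_trace = updates + decay * old_trace, computed leaf by leaf.
    new_trace = jax.tree_util.tree_map(
        lambda g, t: g + decay * t, updates, state.trace)
    # The trace itself becomes the outgoing update.
    return new_trace, TraceState(trace=new_trace)


params = {"w": jnp.array([1.0, 2.0])}
grads = {"w": jnp.array([0.5, 0.5])}
state = trace_init(params)
updates, state = trace_update(grads, state)  # trace = grads
updates, state = trace_update(grads, state)  # trace = grads + 0.9 * grads
```

The caller is responsible for keeping `state` and passing it back on every step; that bookkeeping is what a state system can manage instead.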
Functions
add_noise(...): Returns a function that adds noise to updates.
apply_every(...): Returns a function that accumulates updates and applies them all at once.
chain(...): Composes update functions together serially.
clip_by_global_norm(...): Returns a function that clips updates to a provided max norm.
optimize(...): Runs several iterations of optimization and returns the result.
scale_by_adam(...): Scales updates according to Adam update rules.
scale_by_rms(...): Returns a function that scales updates by the RMS of the updates.
scale_by_schedule(...): Returns a function that scales updates according to an input schedule.
scale_by_stddev(...): Returns a function that scales updates by their standard deviation.
trace(...): Returns a function that combines updates with a running state.
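To make a few of the functions listed above concrete, here is a minimal sketch in plain JAX of `clip_by_global_norm`, `scale_by_rms`, `scale_by_adam`, `scale_by_schedule`, and `chain`. It uses explicitly threaded state in the Optax `init`/`update` style rather than Oryx's state system, so the `Transform` type and every signature here are hypothetical. The default hyperparameters (`decay=0.9`, `b1=0.9`, `b2=0.999`, `eps=1e-8`) are common choices and not necessarily this module's defaults.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp

tree_map = jax.tree_util.tree_map


class Transform(NamedTuple):
    """Hypothetical Optax-style (init, update) pair; not Oryx's API."""
    init: Callable
    update: Callable


def clip_by_global_norm(max_norm):
    def update(updates, state):
        leaves = jax.tree_util.tree_leaves(updates)
        norm = jnp.sqrt(sum(jnp.sum(x ** 2) for x in leaves))
        # Shrink every leaf by the same factor if the global norm is too large.
        factor = jnp.minimum(1.0, max_norm / norm)
        return tree_map(lambda x: x * factor, updates), state
    return Transform(lambda params: (), update)  # stateless


def scale_by_rms(decay=0.9, eps=1e-8):
    def init(params):
        return tree_map(jnp.zeros_like, params)  # running mean of g**2

    def update(updates, nu):
        nu = tree_map(lambda v, g: decay * v + (1 - decay) * g ** 2, nu, updates)
        return tree_map(lambda g, v: g / (jnp.sqrt(v) + eps), updates, nu), nu
    return Transform(init, update)


def scale_by_adam(b1=0.9, b2=0.999, eps=1e-8):
    def init(params):
        zeros = tree_map(jnp.zeros_like, params)
        return (jnp.zeros([], jnp.int32), zeros, zeros)  # (count, mu, nu)

    def update(updates, state):
        count, mu, nu = state
        count = count + 1
        mu = tree_map(lambda m, g: b1 * m + (1 - b1) * g, mu, updates)
        nu = tree_map(lambda v, g: b2 * v + (1 - b2) * g ** 2, nu, updates)

        def adam_step(m, v):
            # Bias-corrected first and second moment estimates.
            m_hat = m / (1 - b1 ** count)
            v_hat = v / (1 - b2 ** count)
            return m_hat / (jnp.sqrt(v_hat) + eps)
        return tree_map(adam_step, mu, nu), (count, mu, nu)
    return Transform(init, update)


def scale_by_schedule(schedule):
    def update(updates, count):
        step_size = schedule(count)  # schedule maps step count -> scale
        return tree_map(lambda g: g * step_size, updates), count + 1
    return Transform(lambda params: jnp.zeros([], jnp.int32), update)


def chain(*transforms):
    def init(params):
        return tuple(t.init(params) for t in transforms)

    def update(updates, states):
        new_states = []
        for t, s in zip(transforms, states):
            # Feed each transform's output into the next one.
            updates, s = t.update(updates, s)
            new_states.append(s)
        return updates, tuple(new_states)
    return Transform(init, update)


# Usage: clipped Adam with a constant step of -0.1 on f(w) = sum(w**2).
opt = chain(clip_by_global_norm(1.0), scale_by_adam(),
            scale_by_schedule(lambda step: -0.1))
params = {"w": jnp.array([1.0, -2.0])}
state = opt.init(params)
grads = jax.grad(lambda p: jnp.sum(p["w"] ** 2))(params)
updates, state = opt.update(grads, state)
params = tree_map(lambda p, u: p + u, params, updates)
```

Composing with `chain` keeps clipping, Adam scaling, and the step size as independent pieces. On the first Adam step the bias-corrected update is roughly the sign of the gradient, so the negative constant schedule turns it into a descent step of about 0.1 per coordinate.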