Module: oryx.experimental.optimizers.optix

Reimplementation of jax.experimental.optix.

There are two key differences from optix:

  1. We use Oryx's state API to handle variables.
  2. We change the order of the arguments of the optimizer's update function.

Functions

adam(...): The Adam optimizer.

add_noise(...): Returns a function that adds noise to updates.

apply_every(...): Returns a function that accumulates updates and applies them all at once.

chain(...): Composes update functions together serially.

clip(...): Returns a function that clips each update element-wise to a maximum absolute value.

clip_by_global_norm(...): Returns a function that rescales updates so that their global norm does not exceed a provided maximum.

global_norm(...): Computes the L2 norm of all updates taken together.

gradient_descent(...)

noisy_sgd(...): Stochastic gradient descent with added gradient noise.

optimize(...): Runs several iterations of optimization and returns the result.

rmsprop(...): The RMSProp optimizer.

scale(...): Returns a function that multiplies updates by a fixed step size.

scale_by_adam(...): Returns a function that scales updates according to the Adam update rule.

scale_by_rms(...): Returns a function that scales updates by the root mean square (RMS) of past updates.

scale_by_schedule(...): Returns a function that scales updates according to an input schedule.

scale_by_stddev(...): Returns a function that scales updates by the running standard deviation of past updates.

sgd(...): Stochastic gradient descent.

trace(...): Returns a function that combines updates with a running trace of past updates (momentum).
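The chain(...), scale(...) and clip(...) entries describe small update transformations that are composed into an optimizer. The following is a minimal sketch of that idea in plain Python, assuming updates are a flat list of floats and each transformation is stateless. The functions here are hypothetical stand-ins that share the names in this module; they are not Oryx's signatures, which use Oryx's state API and a different argument order.

```python
from typing import Callable, List

# Hypothetical stand-ins: each "transform" maps a list of updates to a new list.
Updates = List[float]
Transform = Callable[[Updates], Updates]


def scale(step_size: float) -> Transform:
    """Multiplies every update by a constant factor."""
    return lambda updates: [step_size * u for u in updates]


def clip(max_delta: float) -> Transform:
    """Clips each update element-wise to [-max_delta, max_delta]."""
    return lambda updates: [max(-max_delta, min(max_delta, u)) for u in updates]


def chain(*transforms: Transform) -> Transform:
    """Applies the transforms left to right, feeding each output to the next."""
    def composed(updates: Updates) -> Updates:
        for transform in transforms:
            updates = transform(updates)
        return updates
    return composed


# Clip the raw gradients, then scale by -learning_rate to get descent steps.
learning_rate = 0.1
step = chain(clip(1.0), scale(-learning_rate))

grads = [0.5, -3.0, 2.0]
params = [1.0, 1.0, 1.0]
updates = step(grads)
params = [p + u for p, u in zip(params, updates)]
```

Here the large gradients -3.0 and 2.0 are first clipped to -1.0 and 1.0, so every parameter moves by at most learning_rate in this step.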
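The trace(...) entry combines updates with a running state. A minimal sketch of that momentum rule, assuming the state is an explicit list passed in and returned (Oryx instead keeps it through its state API), with a hypothetical trace_update helper and an optional Nesterov look-ahead:

```python
from typing import List, Tuple


def trace_update(
    updates: List[float],
    trace: List[float],
    decay: float,
    nesterov: bool = False,
) -> Tuple[List[float], List[float]]:
    """Hypothetical helper: returns (new_updates, new_trace) for one step."""
    # Running state: new_trace = updates + decay * old_trace.
    new_trace = [u + decay * t for u, t in zip(updates, trace)]
    if nesterov:
        # Nesterov variant: look one step further along the new trace.
        new_updates = [u + decay * t for u, t in zip(updates, new_trace)]
    else:
        new_updates = new_trace
    return new_updates, new_trace


# With a constant gradient the trace grows toward 1 / (1 - decay) = 10.
trace = [0.0]
history = []
for _ in range(3):
    out, trace = trace_update([1.0], trace, decay=0.9)
    history.append(out[0])
# history is approximately [1.0, 1.9, 2.71]
```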
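The scale_by_rms(...) and scale_by_stddev(...) entries normalize updates by statistics of past updates. One common way to track those statistics is with exponential moving averages (EMAs), which this sketch assumes. The helper names, the default decay and eps, and the placement of eps inside the square root are illustrative choices, not taken from Oryx.

```python
import math
from typing import List


def ema(new: List[float], old: List[float], decay: float) -> List[float]:
    """Exponential moving average: decay * old + (1 - decay) * new."""
    return [decay * o + (1.0 - decay) * n for n, o in zip(new, old)]


def scale_by_rms_update(updates, nu, decay=0.9, eps=1e-8):
    """Hypothetical helper: divides each update by sqrt(EMA of squared updates)."""
    nu = ema([u * u for u in updates], nu, decay)
    scaled = [u / math.sqrt(n + eps) for u, n in zip(updates, nu)]
    return scaled, nu


def scale_by_stddev_update(updates, mu, nu, decay=0.9, eps=1e-8):
    """Hypothetical helper: divides each update by sqrt(E[u^2] - E[u]^2)."""
    mu = ema(updates, mu, decay)                   # running mean
    nu = ema([u * u for u in updates], nu, decay)  # running mean of squares
    scaled = [u / math.sqrt(n - m * m + eps) for u, m, n in zip(updates, mu, nu)]
    return scaled, mu, nu


# First step from zero state with a single update of 2.0.
rms_out, nu = scale_by_rms_update([2.0], [0.0])
sd_out, mu, nu2 = scale_by_stddev_update([2.0], [0.0], [0.0])
# rms_out[0] is about 2 / sqrt(0.4); sd_out[0] is about 2 / sqrt(0.36)
```

The only difference between the two is whether the squared running mean is subtracted, which turns the RMS into a variance estimate.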
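The scale_by_adam(...) and adam(...) entries refer to the Adam update rule. This is a minimal sketch of the standard rule (bias-corrected first and second moment estimates), assuming explicit (mu, nu, count) state. It also assumes that adam amounts to scale_by_adam followed by scaling with the negative learning rate, which is how optix builds it. The helper names and defaults are hypothetical.

```python
import math
from typing import List


def scale_by_adam_update(updates, mu, nu, count, b1=0.9, b2=0.999, eps=1e-8):
    """Hypothetical helper: returns (scaled_updates, mu, nu, count) after one step."""
    mu = [b1 * m + (1 - b1) * u for u, m in zip(updates, mu)]      # 1st moment EMA
    nu = [b2 * v + (1 - b2) * u * u for u, v in zip(updates, nu)]  # 2nd moment EMA
    count += 1
    # Bias correction: both EMAs start at zero, so early estimates are too small.
    mu_hat = [m / (1 - b1 ** count) for m in mu]
    nu_hat = [v / (1 - b2 ** count) for v in nu]
    scaled = [m / (math.sqrt(v) + eps) for m, v in zip(mu_hat, nu_hat)]
    return scaled, mu, nu, count


def adam_step(params: List[float], grads: List[float], state, learning_rate=1e-3):
    """Hypothetical helper: scale_by_adam, then a step of -learning_rate."""
    mu, nu, count = state
    scaled, mu, nu, count = scale_by_adam_update(grads, mu, nu, count)
    new_params = [p - learning_rate * s for p, s in zip(params, scaled)]
    return new_params, (mu, nu, count)


params = [1.0, 1.0]
state = ([0.0, 0.0], [0.0, 0.0], 0)
params, state = adam_step(params, [0.5, -3.0], state, learning_rate=0.1)
# On the first step the bias-corrected update is about sign(grad), so each
# parameter moves by roughly learning_rate: params is about [0.9, 1.1].
```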
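The global_norm(...) and clip_by_global_norm(...) entries work on all updates at once rather than element by element. This sketch assumes updates for several parameter groups are stored in a dict of lists, standing in for a JAX pytree. Both functions are hypothetical illustrations, not Oryx's code.

```python
import math
from typing import Dict, List

Updates = Dict[str, List[float]]


def global_norm(updates: Updates) -> float:
    """L2 norm over every element of every parameter group, taken together."""
    return math.sqrt(sum(u * u for group in updates.values() for u in group))


def clip_by_global_norm(updates: Updates, max_norm: float) -> Updates:
    """Rescales all updates by one common factor if their global norm exceeds max_norm."""
    norm = global_norm(updates)
    if norm <= max_norm:
        return updates
    factor = max_norm / norm
    return {name: [factor * u for u in group] for name, group in updates.items()}


updates = {"w": [3.0], "b": [4.0]}  # global norm is sqrt(9 + 16) = 5
clipped = clip_by_global_norm(updates, max_norm=1.0)
# clipped is approximately {"w": [0.6], "b": [0.8]}
```

Because every group is multiplied by the same factor, the overall direction of the update is preserved, unlike element-wise clipping with clip(...).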
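The apply_every(...) entry accumulates updates and applies them all at once. A minimal sketch, assuming a period k and explicit (count, accumulator) state: zeros are emitted on intermediate steps and the accumulated sum on every k-th step. The apply_every_update helper is hypothetical.

```python
from typing import List, Tuple


def apply_every_update(
    updates: List[float], count: int, acc: List[float], k: int
) -> Tuple[List[float], int, List[float]]:
    """Hypothetical helper: returns (emitted_updates, new_count, new_acc)."""
    # Start a fresh accumulator at the beginning of each block of k steps.
    if count % k == 0:
        acc = [0.0] * len(updates)
    acc = [a + u for a, u in zip(acc, updates)]
    emit = count % k == k - 1
    out = list(acc) if emit else [0.0] * len(updates)
    return out, count + 1, acc


count, acc = 0, [0.0]
outputs = []
for g in ([1.0], [2.0], [3.0], [4.0]):
    out, count, acc = apply_every_update(g, count, acc, k=3)
    outputs.append(out[0])
# outputs == [0.0, 0.0, 6.0, 0.0]
```

This gives the effect of a k-times larger batch: parameters change only once per k steps, using the summed updates.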
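The scale_by_schedule(...) entry scales updates according to an input schedule, meaning a step size that depends on the step count. This sketch assumes the schedule is any function from an integer count to a float. The exponential-decay schedule is only an example chosen here, not something this module is documented to provide.

```python
from typing import Callable, List, Tuple

Schedule = Callable[[int], float]


def exponential_decay(init_value: float, rate: float) -> Schedule:
    """Example schedule (an assumption for illustration): init_value * rate**count."""
    return lambda count: init_value * rate ** count


def scale_by_schedule_update(
    updates: List[float], count: int, schedule: Schedule
) -> Tuple[List[float], int]:
    """Hypothetical helper: scales updates by schedule(count), then advances count."""
    step_size = schedule(count)
    return [step_size * u for u in updates], count + 1


schedule = exponential_decay(0.1, 0.5)
count = 0
sizes = []
for _ in range(3):
    out, count = scale_by_schedule_update([1.0], count, schedule)
    sizes.append(out[0])
# sizes is approximately [0.1, 0.05, 0.025]
```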
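The add_noise(...) and noisy_sgd(...) entries add noise to updates. This sketch follows the annealed gradient-noise rule used by optix, in which the Gaussian noise variance decays as eta / (1 + count)**gamma. The default values of eta and gamma, the helper name, and the use of Python's random module in place of JAX PRNG keys are assumptions made for illustration.

```python
import math
import random
from typing import List, Tuple


def add_noise_update(
    updates: List[float],
    count: int,
    rng: random.Random,
    eta: float = 0.01,
    gamma: float = 0.55,
) -> Tuple[List[float], int]:
    """Hypothetical helper: adds zero-mean Gaussian noise with decaying variance."""
    variance = eta / (1 + count) ** gamma  # shrinks as training progresses
    std = math.sqrt(variance)
    noisy = [u + std * rng.gauss(0.0, 1.0) for u in updates]
    return noisy, count + 1


rng = random.Random(0)  # seeded so that runs are reproducible
noisy, count = add_noise_update([1.0, -1.0], count=0, rng=rng)
```

Seeding the generator explicitly plays the role that a seed or PRNG key plays in the JAX version: the same seed always yields the same noise.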