oryx.core.state.Module

Encapsulates a parameterized function, along with updates to its state.

Inherits From: Pytree

Modules have the dual purpose of acting as containers for state and as functions of that state.

As a container, a Module's variables() method returns any encapsulated state, and its flatten and unflatten methods register it as a Pytree so that it can be passed into and out of JAX-transformed functions. For example, if layer is a neural network layer module, then layer.variables() returns the layer's weights, and grad(some_f)(layer) returns the gradient of some_f with respect to those weights.
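To illustrate the container role, here is a minimal pure-Python sketch of the flatten/unflatten contract that Pytree registration relies on. ToyLinear and its weight names are hypothetical and this is not Oryx's implementation; it only assumes the usual convention that flatten returns (leaves, static auxiliary data) and unflatten rebuilds an equivalent object from them. With JAX installed, a class shaped like this could be registered via jax.tree_util.register_pytree_node(ToyLinear, ToyLinear.flatten, ToyLinear.unflatten), after which transformations such as grad would operate on its leaves.

```python
from typing import Any, Dict, Optional, Tuple


class ToyLinear:
    """Hypothetical layer: an illustration of the container role, not Oryx's Module."""

    def __init__(self, weights: Dict[str, float], name: Optional[str] = None):
        self.weights = weights  # encapsulated state (the values a transform would trace)
        self.name = name        # static metadata, not differentiated

    def variables(self) -> Dict[str, float]:
        # Return a copy of the encapsulated state.
        return dict(self.weights)

    def flatten(self) -> Tuple[Tuple[float, ...], Tuple[Any, ...]]:
        # Split into leaves (the state) and auxiliary data needed to rebuild the object.
        keys = tuple(sorted(self.weights))
        leaves = tuple(self.weights[k] for k in keys)
        return leaves, (keys, self.name)

    @classmethod
    def unflatten(cls, aux: Tuple[Any, ...], leaves: Tuple[float, ...]) -> "ToyLinear":
        # Rebuild a module with the same structure from (possibly new) leaves.
        keys, name = aux
        return cls(dict(zip(keys, leaves)), name=name)


layer = ToyLinear({"w": 2.0, "b": 0.5}, name="dense")
leaves, aux = layer.flatten()
# A transformation would produce new leaves with the same structure;
# doubling them here stands in for that.
doubled = ToyLinear.unflatten(aux, tuple(2 * x for x in leaves))
print(doubled.variables())  # {'b': 1.0, 'w': 4.0}
```

Keeping the name in the auxiliary data rather than in the leaves means a transformation only ever sees numeric state, which is the design choice this sketch is meant to highlight.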

As functions, Modules provide three methods: call, update, and call_and_update. Conceptually, a Module represents a parameterized function f_variables(inputs), and its call method computes that function's output. Given a set of inputs, the update method returns a new copy of the module with potentially new state (variables). The call_and_update method returns both the output of the function and the updated module.
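The division of labor between call, update, and call_and_update can be sketched with a hypothetical running-mean module in plain Python. RunningMean and its count/total variables are assumptions made for illustration, not part of Oryx. The key property shown is that update returns a new module rather than mutating the existing one.

```python
from dataclasses import dataclass, replace
from typing import Dict, Tuple


@dataclass(frozen=True)
class RunningMean:
    """Hypothetical module whose state (variables) is a running count and total."""

    count: int = 0
    total: float = 0.0

    def variables(self) -> Dict[str, float]:
        return {"count": self.count, "total": self.total}

    def call(self, x: float) -> float:
        # f_variables(x): center x using the mean of previously seen inputs.
        mean = self.total / self.count if self.count else 0.0
        return x - mean

    def update(self, x: float) -> "RunningMean":
        # Return a NEW module with updated state; self is left unchanged.
        return replace(self, count=self.count + 1, total=self.total + x)

    def call_and_update(self, x: float) -> Tuple[float, "RunningMean"]:
        # Output computed from the current state, plus a module carrying the new state.
        return self.call(x), self.update(x)


m0 = RunningMean()
y1, m1 = m0.call_and_update(4.0)  # no history yet, so nothing is subtracted
y2, m2 = m1.call_and_update(6.0)  # 6.0 minus the previous mean of 4.0
print(y1, y2, m2.variables())  # 4.0 2.0 {'count': 2, 'total': 10.0}
```

Computing the output from the pre-update state is a choice made for this sketch; the point is only that the output and the new module are returned together while the original module stays intact.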

The __call__ method has some extra logic needed for composing stateful functions. If a module m1 is called in the body of another module m2, m2 should "inherit" m1's state, so that either m1 itself or m1's variables appear in m2's variables. If m1 has a non-None name, it should appear in m2's variables under that name; otherwise, m1's variables should appear in m2's variables directly. To achieve this, __call__ calls assign on the updated module (or on its member variables), which propagates the updated state to the enclosing stateful context.

Methods

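call

Computes the output of the module's parameterized function for the given inputs.

call_and_update

Returns both the output of the function and an updated module.

flatten

Flattens the module; used together with unflatten to register the module as a Pytree.

unflatten

Rebuilds a module from its flattened representation; used together with flatten to register the module as a Pytree.

update

Returns a new copy of the module, with potentially new state (variables), for a set of inputs.

variables

Returns any state encapsulated by the module.

__call__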

Emulates a regular function call.

A Module's __call__ method ensures that state is updated after the function call: it calls assign on the updated state before returning the output of the function.

Args
*args: The arguments to the module.
**kwargs: The keyword arguments to the module.

Returns
The output of the module.