Encapsulates a parameterized function, along with updates to its state.
Inherits From: Pytree
oryx.core.state.Module(
*, name=None
)
Modules have the dual purpose of acting as containers for state and as functions of that state.
As containers, a Module's variables() method returns any encapsulated state,
and its flatten and unflatten methods are used to register it as a Pytree,
so that it can be passed in and out of JAX-transformed functions. For example,
if we have a neural network layer module layer, then layer.variables()
returns the weights of the layer and grad(some_f)(layer) returns the
gradient of some_f with respect to those weights.
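The container role can be illustrated with a minimal plain-Python sketch. It uses neither Oryx nor JAX: the class ToyLayer and its fields are hypothetical, and the (leaves, metadata) split returned by flatten is an assumption inferred from the unflatten(data, xs) signature rather than a confirmed description of the Pytree contract.

```python
class ToyLayer:
    """Hypothetical stand-in for a layer module; not an Oryx class."""

    def __init__(self, weights, bias, name=None):
        self.weights = weights
        self.bias = bias
        self.name = name

    def variables(self):
        # The encapsulated state, keyed by variable name.
        return {"weights": self.weights, "bias": self.bias}

    def flatten(self):
        # Assumed convention: (array-like leaves, static metadata).
        return (self.weights, self.bias), (self.name,)

    @classmethod
    def unflatten(cls, data, xs):
        # Rebuild the object from static metadata `data` and leaves `xs`.
        (name,) = data
        weights, bias = xs
        return cls(weights, bias, name=name)


layer = ToyLayer(weights=[1.0, 2.0], bias=0.5, name="dense")
xs, data = layer.flatten()
rebuilt = ToyLayer.unflatten(data, xs)
```

Round-tripping through flatten and unflatten preserves the variables, which is the property a transformation needs in order to take the object apart and put it back together.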
As a function, a Module has three methods: call, update, and
call_and_update. Conceptually, a Module represents a parameterized function
f_variables(inputs), and its call method computes the output of that
function. Its update method takes the same inputs and returns a new copy of
the module with potentially new state (variables). The call_and_update
method returns both the output of the function and an updated module.
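As a rough illustration of how the three methods relate, here is a plain-Python sketch. ToyCounter is a hypothetical class, not part of Oryx, and the rule that the state advances by one per input is chosen only for the example.

```python
from typing import Any, Tuple


class ToyCounter:
    """Hypothetical module computing f_count(x) = x + count; not an Oryx class."""

    def __init__(self, count=0):
        self.count = count

    def variables(self):
        return {"count": self.count}

    def call(self, x) -> Any:
        # Output only; the module itself is left unchanged.
        return x + self.count

    def update(self, x) -> "ToyCounter":
        # New module only; in this toy the state advances by one per input.
        return ToyCounter(self.count + 1)

    def call_and_update(self, x) -> Tuple[Any, "ToyCounter"]:
        # Both results, computed against the same starting state.
        return self.call(x), self.update(x)


m = ToyCounter(count=2)
out, new_m = m.call_and_update(10)
```

Note that update returns a new object instead of mutating m, matching the "new copy of the module" wording above.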
The __call__ method has some extra logic needed for composing stateful
functions. If a module m1 is called in the body of another module m2,
we would like m2 to "inherit" the state in m1, so that either m1
or m1's variables appear in m2's variables. If m1 has a non-None
name, it should appear in m2's variables under that name; otherwise,
m1's variables should appear in m2's variables directly. To emulate
this behavior, the __call__ method calls assign on the updated
module (or on its member variables), propagating the updated state to the
outer stateful context.
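The named versus unnamed behavior can be emulated in a few lines of plain Python. Everything here is a hypothetical sketch: the _context_stack list, this two-argument assign, and run_in_context are invented for illustration and are not Oryx's implementation or API.

```python
_context_stack = []  # hypothetical stand-in for the outer stateful context


def assign(value, name):
    # Toy helper: record `value` under `name` in the innermost active context.
    if _context_stack:
        _context_stack[-1][name] = value


class ToyModule:
    """Hypothetical module; not an Oryx class."""

    def __init__(self, count=0, name=None):
        self.count = count
        self.name = name

    def variables(self):
        return {"count": self.count}

    def call_and_update(self, x):
        return x + self.count, ToyModule(self.count + 1, name=self.name)

    def __call__(self, x):
        out, new_module = self.call_and_update(x)
        if self.name is not None:
            # Named: the whole updated module appears under its name.
            assign(new_module, self.name)
        else:
            # Unnamed: its member variables appear directly.
            for key, value in new_module.variables().items():
                assign(value, key)
        return out


def run_in_context(fn, *args):
    """Run `fn` and collect whatever it assigns to the outer context."""
    _context_stack.append({})
    try:
        out = fn(*args)
    finally:
        state = _context_stack.pop()
    return out, state


named = ToyModule(count=1, name="inner")
unnamed = ToyModule(count=1)
out_named, state_named = run_in_context(named, 5)
out_unnamed, state_unnamed = run_in_context(unnamed, 5)
```

In this sketch the named module shows up in the collected state as a single entry keyed by "inner", while the unnamed module contributes its count variable directly.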
Methods
call
@abc.abstractmethod
call(*args, **kwargs) -> Any
call_and_update
@abc.abstractmethod
call_and_update(*args, **kwargs) -> Tuple[Any, 'Module']
flatten
@abc.abstractmethod
flatten()
unflatten
@classmethod
@abc.abstractmethod
unflatten(data, xs)
update
@abc.abstractmethod
update(*args, **kwargs) -> 'Module'
variables
@abc.abstractmethod
variables() -> oryx.experimental.matching.jax_rewrite.Bindings
__call__
__call__(
*args, **kwargs
) -> Any
Emulates a regular function call.
A Module's __call__ method ensures state is updated after the function
call by calling assign on the updated state before returning the output of
the function.
| Args | |
|---|---|
| *args | The arguments to the module. |
| **kwargs | The keyword arguments to the module. |
| Returns | |
|---|---|
| The output of the module. |