Encapsulates a parameterized function, along with updates to its state.
oryx.core.state.Module( *, name=None )
Modules have the dual purpose of acting as containers for state and as functions of that state.
As containers, a Module's variables() method returns any encapsulated state, and its flatten and unflatten methods are used to register it as a Pytree, so that it can be passed in and out of JAX-transformed functions. For example, if we have a neural network layer module layer, layer.variables() returns the weights of the layer and grad(some_f)(layer) returns the gradient of some_f with respect to those weights.
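The container role can be sketched with JAX directly. The Dense class below is a hypothetical stand-in for a layer module, not Oryx's Module. It uses jax.tree_util.register_pytree_node_class, which requires methods named tree_flatten and tree_unflatten, whereas the Module documented here registers itself through its own flatten and unflatten methods. Because the layer is a Pytree, jax.grad can differentiate the hypothetical some_f with respect to it and returns an object of the same class holding the gradients.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class Dense:
    """Hypothetical layer used to illustrate Pytree registration (not Oryx code)."""

    def __init__(self, w, b, name=None):
        self.w = w
        self.b = b
        self.name = name  # static metadata, not a differentiable leaf

    def variables(self):
        # The encapsulated state, keyed by variable name.
        return {"w": self.w, "b": self.b}

    def __call__(self, x):
        return x @ self.w + self.b

    def tree_flatten(self):
        # Children are the array leaves; aux data carries static fields.
        return (self.w, self.b), self.name

    @classmethod
    def tree_unflatten(cls, aux, children):
        w, b = children
        return cls(w, b, name=aux)


def some_f(layer):
    # Hypothetical scalar loss of the layer's output.
    x = jnp.ones((2, 3))
    return jnp.sum(layer(x))


layer = Dense(jnp.ones((3, 4)), jnp.zeros(4), name="dense")
grads = jax.grad(some_f)(layer)
print(type(grads).__name__, grads.variables()["w"].shape)  # prints: Dense (3, 4)
```

Every gradient entry here equals 2.0, since the input x has two rows of ones and the loss is a plain sum.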
As functions, a Module has three methods: call, update, and call_and_update. Conceptually a Module represents a parameterized function f_variables(inputs), and its call method computes the output of the function. A module's update method returns a new copy of the module for a set of inputs with potentially new state (variables). The call_and_update method returns both the output of the function and an updated module.
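To make the three methods concrete, here is a hypothetical Accumulator module written as a plain frozen dataclass (not an Oryx class). Its only variable is a running total: call adds the total to the input, update returns a copy with the input folded into the total, and call_and_update returns both. The original instance is never mutated, mirroring the idea that update returns a new copy of the module.

```python
from dataclasses import dataclass, replace
from typing import Any, Dict, Optional, Tuple


@dataclass(frozen=True)
class Accumulator:
    """Hypothetical module whose only state is a running total."""

    total: float = 0.0
    name: Optional[str] = None

    def variables(self) -> Dict[str, Any]:
        return {"total": self.total}

    def call(self, x: float) -> float:
        # Output of f_variables(x); state is left unchanged.
        return self.total + x

    def update(self, x: float) -> "Accumulator":
        # A new copy with updated state; self is not modified.
        return replace(self, total=self.total + x)

    def call_and_update(self, x: float) -> Tuple[float, "Accumulator"]:
        return self.call(x), self.update(x)


acc = Accumulator()
out, acc2 = acc.call_and_update(3.0)
print(out, acc.total, acc2.total)  # prints: 3.0 0.0 3.0
```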
The __call__ method has some extra logic needed for composing stateful functions. If a module m1 is called in the body of another module m2, we would like m2 to "inherit" the state in m1, so that m1's variables appear in m2's variables. If m1 has a name other than None, then we'd like it to appear in m2's variables under that name; if not, we'd like the variables in m1 to appear in m2 directly. To emulate this behavior, the __call__ method calls assign on the updated module (or on its member variables) to propagate the updated state to an outer stateful context.
call( *args, **kwargs ) -> Any
call_and_update( *args, **kwargs ) -> Tuple[Any, 'Module']
unflatten( data, xs )
update( *args, **kwargs ) -> 'Module'
__call__( *args, **kwargs ) -> Any
Emulates a regular function call.
A Module's dunder call ensures that state is updated after the function call by calling assign on the updated state before returning the output of the function.
Args:
*args: The arguments to the module.
**kwargs: The keyword arguments to the module.
Returns:
The output of the module.