Encapsulates a parameterized function, along with updates to its state.
Inherits From: Pytree
oryx.core.state.Module(
*, name=None
)
Modules have the dual purpose of acting as containers for state and as functions of that state.
As containers, a `Module`'s `variables()` method returns any encapsulated state, and its `flatten` and `unflatten` methods are used to register it as a Pytree, so that it can be passed in and out of JAX-transformed functions. For example, if we have a neural network layer module `layer`, then `layer.variables()` returns the weights of the layer and `grad(some_f)(layer)` returns the gradient of `some_f` with respect to those weights.
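The container role can be illustrated without Oryx or JAX. The sketch below is a plain-Python stand-in, not the library's implementation: `ToyDense` and its fields are hypothetical. The return order of `flatten` (leaves first, then static data) is an assumption chosen to pair with the documented `unflatten(data, xs)` signature.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyDense:
    """Hypothetical stand-in for a layer module; not part of Oryx."""
    w: float
    b: float
    name: str = None

    def variables(self):
        # Encapsulated state, keyed by variable name.
        return {"w": self.w, "b": self.b}

    def flatten(self):
        # Assumed convention: leaves that a transformation could trace or
        # differentiate come first, static data carried unchanged comes second.
        return (self.w, self.b), (self.name,)

    @classmethod
    def unflatten(cls, data, xs):
        # Same argument order as the documented signature: static data, leaves.
        (name,) = data
        w, b = xs
        return cls(w, b, name=name)


layer = ToyDense(w=2.0, b=0.5, name="dense")
xs, data = layer.flatten()
rebuilt = ToyDense.unflatten(data, xs)  # round trip reproduces the module
```

Because flattening and unflattening round-trip losslessly, a transformation can operate on the leaves alone and rebuild a module of the same type afterwards.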
As functions, a `Module` has three methods: `call`, `update`, and `call_and_update`. Conceptually, a `Module` represents a parameterized function `f_variables(inputs)`, and its `call` method computes the output of the function. A module's `update` method returns a new copy of the module for a set of inputs, with potentially new state (variables). The `call_and_update` method returns both the output of the function and an updated module.
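As a rough illustration of how the three methods relate, here is a minimal plain-Python sketch. `ToyCounter`, its `count` variable, and its `scale` parameter are hypothetical and do not come from Oryx. The class only mirrors the contract described above: `call` computes an output, `update` returns a new module, and `call_and_update` returns both.

```python
class ToyCounter:
    """Hypothetical module whose single variable counts how often it has run."""

    def __init__(self, count=0, scale=2.0):
        self.count = count
        self.scale = scale

    def variables(self):
        return {"count": self.count}

    def call(self, x):
        # Output of the parameterized function; the state is left untouched.
        return self.scale * x

    def update(self, x):
        # New copy of the module with its state advanced for this input.
        return ToyCounter(self.count + 1, self.scale)

    def call_and_update(self, x):
        # Output and updated module together.
        return self.call(x), self.update(x)


m = ToyCounter()
y, m_next = m.call_and_update(3.0)
```

Note that `m` itself is never mutated. The updated state lives only in `m_next`, which matches the functional style that JAX transformations expect.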
The `__call__` method has some extra logic needed for composing stateful functions. If a module `m1` is called in the body of another module `m2`, we would like `m2` to "inherit" the state in `m1`, so that either `m1` or `m1`'s variables appear in `m2`'s variables. If `m1` has a non-`None` `name`, then we'd like it to appear in `m2`'s variables under that name, and if not, we'd like the variables in `m1` to appear in `m2` directly. To emulate this behavior, the `__call__` method calls `assign` on the updated module (or its member variables) to propagate the updated state to an outer stateful context.
Methods
call
@abc.abstractmethod
call( *args, **kwargs ) -> Any
call_and_update
@abc.abstractmethod
call_and_update( *args, **kwargs ) -> Tuple[Any, 'Module']
flatten
@abc.abstractmethod
flatten()
unflatten
@classmethod
@abc.abstractmethod
unflatten( data, xs )
update
@abc.abstractmethod
update( *args, **kwargs ) -> 'Module'
variables
@abc.abstractmethod
variables() -> oryx.experimental.matching.jax_rewrite.Bindings
__call__
__call__(
*args, **kwargs
) -> Any
Emulates a regular function call.
A `Module`'s `__call__` method ensures that state is updated after the function call by calling `assign` on the updated state before returning the output of the function.
Args | |
---|---|
`*args` | The arguments to the module. |
`**kwargs` | The keyword arguments to the module. |
Returns | |
---|---|
The output of the module. |