Module for transforming functions into FunctionModules.
In order to init functions, we need to define a Module subclass for them,
which is the FunctionModule. The FunctionModule encapsulates a Jaxpr that
is evaluated to execute the function, with special handling for keyword
arguments. This is useful for neural networks, where a keyword argument such as
training may change the semantics of a function. Such behavior is enabled by
adding a rule to the kwargs_rules dictionary, which is used by the custom Jaxpr
evaluator in FunctionModule. The kwargs_rules dictionary allows the
implementation of a primitive to change depending on the value of a keyword
argument. An example would be a neural network layer like dropout, which
behaves differently during training than it does otherwise.
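The idea of keyword-dependent rules can be illustrated with a toy evaluator. This is a minimal sketch in plain Python, not Oryx's implementation: the "program" is a simple list of steps rather than a real Jaxpr, and the names evaluate, default_rules, dropout_kwargs_rule, and the explicit mask parameter are hypothetical. Only the kwargs_rules name and the training keyword come from the text above.

```python
# Toy stand-in for a staged program: a list of (primitive_name, params) steps
# applied in order to a list of floats. This is NOT a real Jaxpr.

# Default implementations, keyed by a hypothetical primitive name.
default_rules = {
    "scale": lambda xs, factor: [x * factor for x in xs],
}


def dropout_kwargs_rule(xs, *, kwargs, rate, mask):
    """Keyword-aware rule: drop and rescale only when training=True."""
    if kwargs.get("training", False):
        keep = 1.0 - rate
        return [x * m / keep for x, m in zip(xs, mask)]
    # Outside of training, dropout is the identity.
    return list(xs)


# Rules that receive the call's keyword arguments (assumed layout).
kwargs_rules = {"dropout": dropout_kwargs_rule}


def evaluate(program, xs, **kwargs):
    """Run each step, preferring a keyword-aware rule when one is registered."""
    for name, params in program:
        if name in kwargs_rules:
            xs = kwargs_rules[name](xs, kwargs=kwargs, **params)
        else:
            xs = default_rules[name](xs, **params)
    return xs


program = [
    ("scale", {"factor": 2.0}),
    # A fixed mask keeps the sketch deterministic; real dropout samples it.
    ("dropout", {"rate": 0.5, "mask": [1, 0, 1]}),
]

eval_out = evaluate(program, [1.0, 2.0, 3.0], training=False)
train_out = evaluate(program, [1.0, 2.0, 3.0], training=True)
print(eval_out)   # [2.0, 4.0, 6.0]
print(train_out)  # [4.0, 0.0, 12.0]
```

The same program yields different results depending only on the training keyword, which is the behavior the rule dictionary is meant to enable.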
We also register functions with api.init. The init for functions first
checks whether the input function has a keyword argument init_key, and only if
that is the case does it harvest the function. This makes statefulness an
opt-in behavior for functions.
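The opt-in check can be sketched with the standard library alone. This is an illustration under assumptions, not the library's code: the text does not say how the signature is inspected, and the helper has_init_key and the two example functions are hypothetical.

```python
import inspect


def has_init_key(f):
    """Return True if `f` accepts a keyword argument named `init_key`."""
    param = inspect.signature(f).parameters.get("init_key")
    return param is not None and param.kind in (
        inspect.Parameter.KEYWORD_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )


def stateless(x):
    # No `init_key` argument, so this function would not opt in.
    return x + 1


def stateful(x, init_key=None):
    # Declaring `init_key` is what opts this function in.
    return x + 1


opted_in = [f.__name__ for f in (stateless, stateful) if has_init_key(f)]
print(opted_in)  # ['stateful']
```

Under this reading, a function without init_key is left untouched, so existing stateless functions keep their behavior unless they explicitly declare the argument.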
To see documentation of init/spec/call_and_update and an example,
see api.py.
Classes
class FunctionModule: Encapsulates a staged function.