Module for transforming functions into FunctionModules.
In order to init functions, we need to define a Module subclass for them, which is the FunctionModule.
FunctionModule encapsulates a Jaxpr that is evaluated to execute the function, with special handling for keyword arguments. This is useful for neural networks, where a keyword argument such as training may change the semantics of a function. Hooking into this keyword functionality takes two steps. First, register a custom unzip rule in the custom_unzip_rules dictionary; these rules are used while init-ing and allow substituting primitives in the trace for others. Second, add a rule to the kwargs_rules dictionary, which is used by the custom Jaxpr evaluator in FunctionModule. kwargs_rules lets a primitive have implementations that change depending on the value of a keyword argument. An example is a neural network layer like dropout, which behaves differently depending on whether the network is training.
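The kwargs_rules idea can be illustrated with a minimal, pure-Python sketch. It assumes a toy program format (a list of single-input equations) standing in for a real Jaxpr, and the names DEFAULT_IMPLS, KWARGS_RULES, dropout_rule, and evaluate are hypothetical; this is not oryx's actual evaluator or rule signature. It only shows how an evaluator can forward keyword arguments such as training to a keyword-sensitive primitive like dropout.

```python
import random
from typing import Any, Callable, Dict, List, Tuple

# Hypothetical stand-in for a Jaxpr: each equation is
# (primitive_name, input_variable, output_variable).
Equation = Tuple[str, str, str]

# Default, keyword-independent implementations of primitives.
DEFAULT_IMPLS: Dict[str, Callable[[List[float]], List[float]]] = {
    "relu": lambda xs: [max(0.0, x) for x in xs],
}


def dropout_rule(xs: List[float], *, training: bool = False,
                 rate: float = 0.5, seed: int = 0) -> List[float]:
    """Keyword-dependent implementation of a hypothetical "dropout" primitive."""
    if not training:
        return list(xs)  # identity outside of training
    rng = random.Random(seed)
    keep = 1.0 - rate
    # Inverted dropout: keep each element with probability `keep` and
    # rescale survivors so the expected value is unchanged.
    return [x / keep if rng.random() < keep else 0.0 for x in xs]


# Hypothetical analogue of kwargs_rules: primitive name -> rule that
# also receives the caller's keyword arguments.
KWARGS_RULES: Dict[str, Callable[..., List[float]]] = {
    "dropout": dropout_rule,
}


def evaluate(program: List[Equation], inputs: Dict[str, List[float]],
             **kwargs: Any) -> Dict[str, List[float]]:
    """Evaluate `program`, letting keyword rules override primitive behavior."""
    env = dict(inputs)
    for prim, invar, outvar in program:
        if prim in KWARGS_RULES:
            # Keyword-sensitive primitive: forward the keyword arguments.
            env[outvar] = KWARGS_RULES[prim](env[invar], **kwargs)
        else:
            env[outvar] = DEFAULT_IMPLS[prim](env[invar])
    return env


program = [("relu", "x", "h"), ("dropout", "h", "y")]
print(evaluate(program, {"x": [-1.0, 2.0, 3.0]}, training=False)["y"])  # [0.0, 2.0, 3.0]
print(evaluate(program, {"x": [-1.0, 2.0, 3.0]}, training=True, seed=0)["y"])
```

The same program runs unchanged in both modes; only the keyword argument selects which behavior the dropout rule takes. A real Jaxpr has multi-input equations and primitive parameters, which this sketch deliberately omits.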
We also register functions with init. For functions, init first inspects whether the input function has a keyword argument init_key, and unzips the function only if it does. This makes statefulness opt-in for functions; additionally, init_key provides the data dependence needed to unzip the function properly.
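The opt-in check can be sketched with Python's inspect module. This is a minimal illustration under stated assumptions, not oryx's implementation: has_init_key and init_sketch are hypothetical names, unzip is a caller-supplied placeholder for the real unzipping transformation, and treating a bare **kwargs as "no init_key" is an assumption of this sketch.

```python
import inspect
from typing import Any, Callable


def has_init_key(fn: Callable[..., Any]) -> bool:
    """Return True if `fn` can be called with an `init_key=` keyword argument."""
    try:
        params = inspect.signature(fn).parameters
    except (TypeError, ValueError):
        return False  # some callables expose no introspectable signature
    param = params.get("init_key")
    # Assumption: a bare **kwargs does not count; the name must be explicit.
    return param is not None and param.kind in (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )


def init_sketch(fn: Callable[..., Any],
                unzip: Callable[[Callable[..., Any]], Any]) -> Any:
    """Hypothetical opt-in dispatch: unzip only functions that accept init_key."""
    if has_init_key(fn):
        return unzip(fn)  # placeholder for the real unzipping step
    return fn  # no init_key: the function is treated as stateless


def dense(x, *, init_key=None):
    return x


def plain(x):
    return x


print(has_init_key(dense))  # True
print(has_init_key(plain))  # False
```

Checking the signature rather than calling the function keeps the decision free of side effects, which matches the opt-in behavior described above: functions that never mention init_key pass through untouched.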
To see documentation of
call_and_update and an example,
class FunctionModule: Encapsulates a staged function.