Module for transforming functions into FunctionModules.
In order to init functions, we need to define a Module subclass for them,
which is the FunctionModule. A FunctionModule encapsulates a Jaxpr that
is evaluated to execute the function, with special handling for keyword
arguments. This is useful for neural networks, where a keyword argument such as
training may change the semantics of a function. The next step is to add
a rule to the kwargs_rules dictionary, which is used in the custom Jaxpr
evaluator. kwargs_rules enables implementations for primitives that change
depending on the value of a keyword argument. An example would be a neural
network layer like dropout, which behaves differently during training than
it does otherwise.
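
The idea of a keyword-dependent rule can be sketched in plain Python. This is only a toy illustration of the dispatch pattern, not the library's implementation: `toy_kwargs_rules`, `toy_default_impls`, `dropout_rule`, and `evaluate` are hypothetical names, the "program" is a simple list of steps standing in for a Jaxpr, and the dropout mask is passed in explicitly so the result is deterministic.

```python
# Toy illustration only: every name here is hypothetical, not part of the library.

def dropout_rule(x, mask, *, training=False, rate=0.5):
    """Mask and rescale while training; act as the identity otherwise."""
    if not training:
        return x
    keep = 1.0 - rate
    return [xi * mi / keep for xi, mi in zip(x, mask)]

# Maps a primitive name to an implementation that may consult keyword arguments.
toy_kwargs_rules = {"dropout": dropout_rule}

# Default implementations that never look at keyword arguments.
toy_default_impls = {
    "scale": lambda x, factor: [xi * factor for xi in x],
}

def evaluate(program, x, **kwargs):
    """Run each (primitive, extra_args) step, preferring a kwargs-aware rule."""
    for name, extra_args in program:
        if name in toy_kwargs_rules:
            x = toy_kwargs_rules[name](x, *extra_args, **kwargs)
        else:
            x = toy_default_impls[name](x, *extra_args)
    return x

program = [("scale", (2.0,)), ("dropout", ([1, 0, 1],))]
eval_out = evaluate(program, [1.0, 2.0, 3.0])
train_out = evaluate(program, [1.0, 2.0, 3.0], training=True)
```

Here the same program yields different results depending on whether `training=True` is passed, because only the step that has an entry in the rule dictionary sees the keyword arguments.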
We also register functions with init. The init for functions first
inspects whether the input function has a keyword argument init_key, and
only if that is the case does it harvest the function. This results in an
opt-in behavior for functions to be stateful.
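
The opt-in check described above can be sketched with the standard library's `inspect` module. This is a minimal sketch under the assumption that the check amounts to looking for a parameter named `init_key` that can be passed by keyword. `has_init_key`, `dense`, and `relu` are hypothetical names for illustration, and the library's actual check may differ.

```python
import inspect

def has_init_key(f):
    """Return True if `f` can receive `init_key` as a keyword argument."""
    param = inspect.signature(f).parameters.get("init_key")
    if param is None:
        return False
    return param.kind in (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )

# Hypothetical stateful function: opts in by declaring `init_key`.
def dense(x, init_key=None):
    return x

# Hypothetical stateless function: no `init_key`, so it is left alone.
def relu(x):
    return max(x, 0.0)

dense_opts_in = has_init_key(dense)
relu_opts_in = has_init_key(relu)
```

Under this sketch, only `dense` would be treated as stateful, while `relu` would be used as an ordinary function.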
To see documentation of
call_and_update and an example,
class FunctionModule: Encapsulates a staged function.