Module: oryx.core.state.function

Module for transforming functions into FunctionModules.

To init functions, we need a Module subclass for them: the FunctionModule. A FunctionModule encapsulates a Jaxpr that is evaluated to execute the function, with special handling for keyword arguments. This is useful for neural networks, where a keyword argument such as training may change the semantics of a function. Hooking into this keyword functionality takes two steps. First, you can register a custom unzip rule in the custom_unzip_rules dictionary; it is used while init-ing and allows primitives in the trace to be substituted with others. Next, add a rule to the kwargs_rules dictionary, which is used by the custom Jaxpr evaluator in FunctionModule. A kwargs rule lets a primitive's implementation change depending on the value of a keyword argument. An example is a neural network layer like dropout, which behaves differently during training than during inference.
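The kwargs-rule dispatch described above can be illustrated with a toy, pure-Python model. This is a minimal sketch, not the oryx API: a "Jaxpr" is stood in for by a list of (primitive, inputs, output) tuples, and the names default_impls, dropout_rule, and eval_equations are hypothetical. The real kwargs_rules entries may take different arguments. The point is only that the evaluator consults a rule table keyed by primitive and hands each rule the call's keyword arguments, so dropout can act as the identity when training=False.

```python
from typing import Any, Callable, Dict, List, Tuple

# Hypothetical stand-in for a Jaxpr equation:
# (primitive name, input variable names, output variable name).
Equation = Tuple[str, List[str], str]

# Default implementations, used when no kwargs rule is registered for a primitive.
default_impls: Dict[str, Callable[..., Any]] = {
    "mul": lambda a, b: a * b,
    "dropout": lambda x, mask: x * mask,
}

# Hypothetical kwargs rules: each receives the call's keyword arguments
# followed by the primitive's inputs. (Not the real oryx rule signature.)
kwargs_rules: Dict[str, Callable[..., Any]] = {}


def dropout_rule(kwargs: Dict[str, Any], x: Any, mask: Any) -> Any:
    """Applies the (unscaled) mask only while training; identity otherwise."""
    # In this toy, training defaults to True when the keyword is absent.
    if kwargs.get("training", True):
        return x * mask
    return x


kwargs_rules["dropout"] = dropout_rule


def eval_equations(eqns: List[Equation], env: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
    """Evaluates equations in order, dispatching to a kwargs rule when one exists."""
    env = dict(env)  # Copy so the caller's inputs are not mutated.
    for prim, invars, outvar in eqns:
        args = [env[name] for name in invars]
        rule = kwargs_rules.get(prim)
        if rule is not None:
            env[outvar] = rule(kwargs, *args)
        else:
            env[outvar] = default_impls[prim](*args)
    return env


eqns: List[Equation] = [("mul", ["x", "w"], "h"), ("dropout", ["h", "mask"], "y")]
inputs = {"x": 2.0, "w": 3.0, "mask": 0.0}
print(eval_equations(eqns, inputs, training=True)["y"])   # 0.0
print(eval_equations(eqns, inputs, training=False)["y"])  # 6.0
```

Keeping the rule table separate from the default implementations means the same staged program can be re-evaluated with different keyword arguments without retracing it.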

We also register functions with api.init. The init for functions first checks whether the input function has a keyword argument init_key, and unzips the function only if it does. This makes statefulness opt-in for functions. The init_key also provides the data dependence needed to unzip the function properly.
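The opt-in check described above can be approximated with the standard library's inspect module. This is a minimal sketch under the assumption that "has a keyword argument init_key" means an explicitly declared parameter of that name that can be passed by keyword; the helper accepts_init_key is hypothetical, and the actual oryx check may be implemented differently (for example, this sketch does not treat a bare **kwargs as opting in).

```python
import inspect
from typing import Callable


def accepts_init_key(f: Callable) -> bool:
    """Returns True if `f` declares an `init_key` parameter passable by keyword."""
    try:
        params = inspect.signature(f).parameters
    except (TypeError, ValueError):
        # Some callables (e.g. certain builtins) have no retrievable signature;
        # this sketch treats them as not opting in.
        return False
    param = params.get("init_key")
    return param is not None and param.kind in (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )


def stateless(x):
    return x + 1


def stateful(x, *, init_key=None):
    return x + 1


def catch_all(x, **kwargs):
    return x + 1


print(accepts_init_key(stateless))  # False
print(accepts_init_key(stateful))   # True
print(accepts_init_key(catch_all))  # False
```

Under this sketch, only stateful would be unzipped; the other two would be left as ordinary, stateless functions.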

For documentation of init, spec, and call_and_update, along with an example, see api.py.

Classes

class FunctionModule: Encapsulates a staged function.