Module: oryx.experimental.nn.base

Contains the Template and Layer API for Oryx.

Modules are an abstraction provided by Oryx that encapsulates both state and functionality, and some basic neural network layers could be implemented with the stateful function API alone. However, writing practical neural networks requires a few pieces of functionality that are important beyond what exists in Module:

1) Handling keyword arguments. Neural networks commonly condition their behavior on a flag, such as whether or not we are training the network. We can handle this using a custom Jaxpr interpreter.

2) Custom batch semantics. A layer like batch normalization cannot be implemented without custom behavior for batches. We can accomplish this with a batching rule for a custom JAX primitive associated with layers.

3) Special combinators. Building a custom Module abstraction enables overloading operators like >> to build complex architectures, along with handling explicit threading of PRNGKeys.

We implement these additions with the Template and Layer classes.


A Template is an object registered with core.state.init that can be initialized into a Layer. For example, the template nn.Dense(20) can be initialized into a Layer by calling core.state.init(nn.Dense(20))(random.PRNGKey(...), ...), just like a stateful function. In most ways, Templates behave like stateful functions: you can call them, e.g. nn.Dense(20)(x, init_key=...), and this will execute a dense layer initialized with init_key. However, Templates have some extra functionality. We can use the >> operator to chain Templates together, i.e. nn.Dense(200) >> nn.Relu() is a new Template that composes a dense layer with a ReLU activation. It will appropriately split and thread the init_key to all its inner Templates when initialized.
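
The chaining behavior above can be sketched in plain Python (hypothetical, simplified names; not Oryx's actual implementation, which works on PRNGKeys and real layers):

```python
import random

class Template:
  """A template wraps an init function mapping a key to a layer callable."""

  def __init__(self, init_fn):
    self.init_fn = init_fn  # key -> layer (here, a plain callable)

  def __rshift__(self, other):
    # Compose two templates into one. The combined template splits the
    # init key and threads one piece to each child, mirroring how
    # nn.Dense(200) >> nn.Relu() threads init_key when initialized.
    def init(key):
      k1, k2 = split_key(key)
      f, g = self.init_fn(k1), other.init_fn(k2)
      return lambda x: g(f(x))
    return Template(init)

def split_key(key):
  # Stand-in for random.split on a PRNGKey.
  rng = random.Random(key)
  return rng.getrandbits(32), rng.getrandbits(32)

def Scale():
  # Toy stand-in for a parameterized layer: scales input by an
  # "initialized" positive weight derived from the key.
  return Template(lambda key: (lambda x: x * ((key % 7) + 1)))

def Relu():
  return Template(lambda key: (lambda x: max(x, 0)))

# Chain the templates, then initialize the composite with a single key.
model = (Scale() >> Relu()).init_fn(42)
```

Calling model(-1.0) yields 0 (the ReLU clips the scaled negative input), while positive inputs pass through scaled.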


A Layer is a subclass of Module with some additional functionality. Like Modules, Layers have a variables() method that returns a dictionary mapping names to state values, and a call_and_update method that returns the output of a computation along with a new Layer with updated state. Under the hood, Layers do a couple of extra things beyond Modules.

The first is that a Layer's call_and_update is associated with a JAX primitive: layer_cau_p. This primitive serves two purposes. The first is enabling custom behavior under vmap: for example, a BatchNorm layer behaves differently when vmap-ed over many examples than on a single example. When a Layer implements a _call_and_update_batched method, the layer_cau_p primitive uses that method instead of mapping over the default _call_and_update method. This enables an inheritance pattern for custom behavior under JAX transformations. The second is threading keyword arguments from the Layer's call site (as in layer(x, training=True)), which allows Layers to be implemented with keyword arguments without worrying about whether they are being traced.
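
The dispatch performed by the batching rule can be sketched as follows (hypothetical, simplified code; the real rule is registered on the layer_cau_p primitive and operates under vmap tracing, not a Python loop):

```python
def call_and_update_batched(layer, xs, **kwargs):
  """Batched call: prefer a custom batch rule if the layer defines one."""
  if hasattr(type(layer), '_call_and_update_batched'):
    # Custom batch semantics, e.g. BatchNorm computing batch statistics.
    return layer._call_and_update_batched(xs, **kwargs)
  # Default: map the per-example rule over the leading batch dimension,
  # threading the updated layer through each example.
  outs = []
  for x in xs:
    out, layer = layer._call_and_update(x, **kwargs)
    outs.append(out)
  return outs, layer

class Counter:
  """Toy layer that counts how many examples it has processed."""

  def __init__(self, n=0):
    self.n = n

  def _call_and_update(self, x, **kwargs):
    return x + 1, Counter(self.n + 1)

class BatchCounter(Counter):
  """Same layer, but with custom batch semantics: one update per batch."""

  def _call_and_update_batched(self, xs, **kwargs):
    return [x + 1 for x in xs], BatchCounter(self.n + 1)

outs, new = call_and_update_batched(Counter(), [1, 2, 3])
bouts, bnew = call_and_update_batched(BatchCounter(), [1, 2, 3])
```

Both layers produce the same outputs, but the default path updates state once per example (new.n is 3) while the custom batch rule updates it once per batch (bnew.n is 1).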

Templates vs Layers

Layers cannot be constructed directly. In fact, we override their __new__ method to construct a Template instead. So, although nn.Dense is a subclass of nn.Layer, nn.Dense(20) will return an instance of Template. The Template acts as a factory class for Layers, and core.state.init actually builds the Layer by calling special methods on the Layer class.
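
The factory trick can be sketched in a few lines of plain Python (hypothetical, simplified; the real Template goes through core.state.init and the initialize/spec class methods described below):

```python
class Template:
  """Remembers a Layer class and its constructor arguments."""

  def __init__(self, cls, *args, **kwargs):
    self.cls, self.args, self.kwargs = cls, args, kwargs

  def init(self, key):
    # Build the real Layer, bypassing the overridden __new__.
    layer = object.__new__(self.cls)
    layer.setup(key, *self.args, **self.kwargs)
    return layer

class Layer:
  def __new__(cls, *args, **kwargs):
    # Constructing any Layer subclass hands back a Template instead.
    return Template(cls, *args, **kwargs)

class Dense(Layer):
  def setup(self, key, features):
    self.key = key
    self.features = features

template = Dense(20)       # a Template, not a Dense instance
layer = template.init(0)   # now an actual Dense layer
```

Note that because __new__ returns a non-Layer object, Python never calls Dense's __init__; initialization is deferred until the Template is initialized with a key.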


All Layers must implement the initialize and spec class methods. These enable a Template to know the shapes and correctly initialize the parameters in a Layer. initialize is responsible for parameter initialization and spec is responsible for shape inference.

initialize must return a LayerParams object, which represents the state of a Layer, broken down into three components:

1) params - what we traditionally consider the "weights" of a layer, i.e. quantities we'd like to differentiate with respect to.

2) state - numerical values that are part of a layer but that we would not like to differentiate with respect to, such as running averages. These quantities are automatically stop-gradiented before running the forward pass.

3) info - metadata that is not numerical, such as configuration strings.
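
As a rough sketch, LayerParams can be pictured as a three-field container (the field names follow the text above; the actual Oryx definition may differ in detail). A BatchNorm-style layer illustrates how the three components split:

```python
from dataclasses import dataclass
from typing import Any

@dataclass
class LayerParams:
  params: Any = ()  # differentiable weights
  state: Any = ()   # non-differentiable numerics (e.g. running averages)
  info: Any = ()    # non-numeric metadata (e.g. configuration strings)

# Hypothetical BatchNorm-like breakdown: scale/shift are trained, the
# running statistics are stop-gradiented state, and the axis is metadata.
bn_params = LayerParams(
    params={'scale': [1.0], 'shift': [0.0]},
    state={'running_mean': [0.0], 'running_var': [1.0]},
    info={'axis': 0},
)
```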


All Layers must implement a _call method, which executes their forward pass. They can refer to the LayerParams returned in initialize. The _call method can accept keyword arguments such as training. We specially handle the keyword argument rng to ensure it is traced and split properly so it can be used for stochasticity in the forward pass.
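
A dropout-flavored sketch shows how _call can consume such keyword arguments (hypothetical, simplified: plain Python randomness stands in for the traced and split rng key, and the usual 1/(1-rate) rescaling is omitted for brevity):

```python
import random

class Dropout:
  """Toy layer whose forward pass is conditional on `training`."""

  def __init__(self, rate):
    self.rate = rate

  def _call(self, xs, training=False, rng=None):
    if not training:
      return xs  # identity at evaluation time
    gen = random.Random(rng)  # stand-in for a properly split PRNGKey
    # Zero out each element with probability `rate`.
    return [0.0 if gen.random() < self.rate else x for x in xs]

layer = Dropout(0.5)
eval_out = layer._call([1.0, 2.0], training=False)
```

At evaluation time the inputs pass through unchanged; with rate 1.0 and training=True, every element is dropped.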


If a Layer wants to optionally update its state, it can implement an _update method, which has the same input arguments as _call but instead returns a copy of the Layer with updated state. By default, it just returns self.
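
A sketch of this pattern (hypothetical, simplified): a stateless layer inherits the default _update and returns self, while a layer tracking a running mean overrides it to return a fresh copy with new state.

```python
class Layer:
  def _update(self, x, **kwargs):
    return self  # default: no state change

class RunningMean(Layer):
  """Toy layer that centers inputs by a running mean of what it has seen."""

  def __init__(self, mean=0.0, count=0):
    self.mean, self.count = mean, count

  def _call(self, x, **kwargs):
    return x - self.mean  # forward pass uses the current running mean

  def _update(self, x, **kwargs):
    # Return a *new* layer with updated state; the original is untouched.
    n = self.count + 1
    return RunningMean(self.mean + (x - self.mean) / n, n)

layer = RunningMean()
layer2 = layer._update(4.0)
```

Note that _update is functional: layer.mean is still 0.0 afterwards, while layer2 carries the updated mean of 4.0.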


If a Layer would like a custom batching rule, it can implement _call_and_update_batched, which assumes all the input arguments have a leading batch dimension. It must return the batched outputs and an unbatched, updated Layer.
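
A BatchNorm-flavored sketch of this contract (hypothetical, simplified to scalars and Python lists): the method sees the whole leading batch dimension, returns batched outputs, and returns a single unbatched updated layer.

```python
class BatchStats:
  """Toy layer that centers a batch and tracks a running mean."""

  def __init__(self, running_mean=0.0, momentum=0.9):
    self.running_mean = running_mean
    self.momentum = momentum

  def _call_and_update_batched(self, xs):
    # Statistics are computed over the leading batch dimension, which is
    # exactly what the default per-example mapping could not express.
    batch_mean = sum(xs) / len(xs)
    outs = [x - batch_mean for x in xs]  # batched outputs
    new_mean = (self.momentum * self.running_mean
                + (1 - self.momentum) * batch_mean)
    return outs, BatchStats(new_mean, self.momentum)  # unbatched layer

layer = BatchStats()
outs, new_layer = layer._call_and_update_batched([1.0, 2.0, 3.0])
```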


class Layer: Base class for neural network layers.

class LayerParams: LayerParams holds the params, state, and info of Layers.

class Template: Template class used by neural network layers.