Layer API for Oryx.
Modules are an abstraction provided by Oryx that enable encapsulating both state and functionality, and some basic neural network layers could be implemented with the stateful function API. However, we want a few extra pieces of functionality that are important for writing practical neural networks, beyond what Module provides:
1) Handling keyword arguments. It is common for neural networks to have behavior conditional on a flag, such as whether or not we are training the network. We can handle this using a custom Jaxpr interpreter.
2) Custom batch semantics. We can't implement a layer like batch normalization without having custom behavior for batches. We can accomplish this by using a batch rule for a custom JAX primitive associated with layers.
3) Special combinators. Building a custom Module abstraction enables overloading operators like >> to build complex architectures, along with handling explicit threading of PRNG keys.
We implement these additions with the Template and Layer classes.
A Template is an object registered with core.state.init that can be initialized into a Layer. For example, the template nn.Dense(20) can be initialized into a Layer by calling core.state.init(nn.Dense(20))(random.PRNGKey(...), ...), just like a stateful function. In most ways, Templates behave like stateful functions, in that you can call them, i.e. nn.Dense(20)(x, init_key=...), and it will execute a dense layer initialized with init_key.
However, Templates have some extra functionality. We can use the >> operator to chain Templates together, i.e. nn.Dense(200) >> nn.Relu() is a new Template that composes a dense layer with a ReLU activation. It will appropriately split and thread init_key to all its inner Templates when initialized.
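To make the Template idea concrete, here is a toy, pure-Python sketch. It is not Oryx's implementation: Template, Scale, Relu, Serial and split are hypothetical names, integer "keys" stand in for JAX PRNG keys, and layers are plain closures over scalars. It only illustrates the two behaviors described above: a template can be called with init_key like a stateful function, and >> builds a composite template that splits its key among its children.

```python
import random
from typing import List


def split(key: int, num: int) -> List[int]:
    """Hypothetical stand-in for PRNG key splitting: derive `num` child keys."""
    gen = random.Random(key)
    return [gen.getrandbits(32) for _ in range(num)]


class Template:
    """Toy template: knows how to build a layer, given a key and an example input."""

    def init(self, key: int, x: float):
        raise NotImplementedError

    def __call__(self, x: float, *, init_key: int) -> float:
        # Calling a template initializes it with init_key and runs it once.
        return self.init(init_key, x)(x)

    def __rshift__(self, other: "Template") -> "Template":
        return Serial([self, other])


class Scale(Template):
    """Toy scalar 'dense' layer: y = w * x, with w drawn from the key."""

    def init(self, key, x):
        w = random.Random(key).uniform(-1.0, 1.0)
        return lambda inp: w * inp


class Relu(Template):
    def init(self, key, x):
        return lambda inp: max(inp, 0.0)


class Serial(Template):
    """Composite template built by >>; splits its key among its children."""

    def __init__(self, templates: List[Template]):
        self.templates = templates

    def __rshift__(self, other):
        # Flatten a >> b >> c into one Serial of three templates.
        return Serial(self.templates + [other])

    def init(self, key, x):
        layers = []
        for template, child_key in zip(self.templates, split(key, len(self.templates))):
            layer = template.init(child_key, x)
            x = layer(x)  # feed the example forward so later layers see its output
            layers.append(layer)

        def run(inp):
            for layer in layers:
                inp = layer(inp)
            return inp

        return run
```

Because each child receives its own derived key, initializing the same composite with the same key is reproducible, while sibling layers still get distinct randomness.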
A Layer is a subclass of Module with some additional functionality. Like Modules, Layers have a variables() method that returns a dictionary mapping names to state values. A Layer also has a call_and_update method that returns the output of a computation and a new Layer with updated state.
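The functional update pattern behind call_and_update can be sketched in plain Python. The Counter class below is hypothetical and is not Oryx's Layer; it only shows the contract described above: variables() maps names to state values, and call_and_update returns an output plus a new object instead of mutating the old one.

```python
from dataclasses import dataclass, replace
from typing import Dict, Tuple


@dataclass(frozen=True)
class Counter:
    """Toy immutable layer: adds a bias and counts how often it was updated."""

    bias: float
    calls: int = 0

    def variables(self) -> Dict[str, float]:
        # Map variable names to their current values.
        return {"bias": self.bias, "calls": self.calls}

    def __call__(self, x: float) -> float:
        # Plain call: compute the output, leave state untouched.
        return x + self.bias

    def call_and_update(self, x: float) -> Tuple[float, "Counter"]:
        # Return the output together with a new layer carrying updated state;
        # the original (frozen) object is never mutated.
        return self(x), replace(self, calls=self.calls + 1)
```

Keeping the old layer intact is what makes this style compatible with JAX transformations, which expect pure functions.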
Underneath the hood, Layers do a couple of extra things beyond Modules. A Layer's call_and_update is associated with a JAX primitive, layer_cau_p, which serves several purposes. First, it enables custom behavior when being vmap-ed. For example, a BatchNorm layer behaves differently when vmap-ed over many examples than when applied to a single example. When a Layer implements a _call_and_update_batched method, the layer_cau_p primitive uses that method instead of mapping over the default _call_and_update method. This enables an inheritance pattern for custom behavior under JAX transformations.
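The dispatch decision described above can be sketched without JAX. In this toy, a Python list plays the role of the batch axis, and vmap_call_and_update, Double and BatchShare are hypothetical names; the real logic lives in the batching rule of layer_cau_p, whose code is not shown here. As a simplifying assumption, the default path discards per-example state updates.

```python
from typing import Any, List, Tuple


def vmap_call_and_update(layer: Any, xs: List[Any]) -> Tuple[List[Any], Any]:
    """Toy batching rule (hypothetical): a Python list plays the batch axis."""
    batched = getattr(layer, "_call_and_update_batched", None)
    if batched is not None:
        # The layer opted in to custom batch semantics: give it the whole batch.
        return batched(xs)
    # Default: map the per-example method. For simplicity this toy discards
    # per-example state updates and returns the original layer.
    outs = [layer._call_and_update(x)[0] for x in xs]
    return outs, layer


class Double:
    def _call_and_update(self, x):
        return 2 * x, self


class BatchShare:
    """Each output is that example's share of the batch total."""

    def _call_and_update(self, x):
        return 1.0, self  # a lone example is the entire total

    def _call_and_update_batched(self, xs):
        total = sum(xs)
        return [x / total for x in xs], self
```

For BatchShare, mapping the per-example method over [1.0, 1.0, 2.0] would give three 1.0s, while the batched method sees the whole batch and returns each example's share, which is exactly the kind of difference a BatchNorm layer needs.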
The layer_cau_p primitive also enables threading keyword arguments from the Layer's call site (as in layer(x, training=True)). This allows Layers to be implemented with keyword arguments without worrying about whether they are being traced.
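The effect of threading call-site keywords can be illustrated with a small pure-Python sketch. This is not how Oryx does it (Oryx uses a primitive and a Jaxpr interpreter); call_with_kwargs, scale, shift_in_training and network are hypothetical, and forwarding only the keywords a layer declares is a design assumption made for the toy.

```python
import inspect


def call_with_kwargs(fn, x, **kwargs):
    """Hypothetical helper: pass fn only the keyword arguments it declares."""
    params = inspect.signature(fn).parameters
    accepted = {name: value for name, value in kwargs.items() if name in params}
    return fn(x, **accepted)


def scale(x):
    return 3 * x  # knows nothing about `training`


def shift_in_training(x, training=False):
    # Toy flag-dependent layer: behaves differently in training mode.
    return x + 0.5 if training else x


def network(x, **kwargs):
    # Call-site keywords (e.g. training=True) are threaded to every layer.
    for layer in (scale, shift_in_training):
        x = call_with_kwargs(layer, x, **kwargs)
    return x
```

The layer author writes an ordinary keyword argument; the container is responsible for getting the flag from the call site to the layers that care about it.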
Layers cannot be constructed directly. In fact, we override their construction so that calling a Layer class builds a Template instead. So, despite nn.Dense being a subclass of Layer, nn.Dense(20) will return an instance of Template. The Template acts as a factory class for Layers, and using core.state.init will actually build the Layer by calling special methods in the Layer.
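One standard Python way to make calling a class return a different object is to override __new__. The sketch below assumes that mechanism for illustration only; Template, Layer, Dense, build and setup here are hypothetical toy names and do not reproduce Oryx's classes.

```python
class Template:
    """Hypothetical factory: remembers the layer class and its constructor args."""

    def __init__(self, layer_cls, *args, **kwargs):
        self.layer_cls = layer_cls
        self.args = args
        self.kwargs = kwargs

    def init(self):
        # Build the real layer, bypassing the overridden __new__.
        return self.layer_cls.build(*self.args, **self.kwargs)


class Layer:
    def __new__(cls, *args, **kwargs):
        # Calling a Layer subclass returns a Template instead of a Layer.
        # Since a Template is not an instance of cls, __init__ is not run.
        return Template(cls, *args, **kwargs)

    @classmethod
    def build(cls, *args, **kwargs):
        obj = object.__new__(cls)  # allocate a real instance directly
        obj.setup(*args, **kwargs)
        return obj

    def setup(self, *args, **kwargs):
        pass


class Dense(Layer):
    def setup(self, units):
        self.units = units
```

With this pattern, Dense(20) is a Template even though Dense subclasses Layer, and only the factory step produces an actual Dense instance.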
Layers must implement the initialize and spec class methods. These enable the Template to know the shapes and correctly initialize the parameters in a Layer: initialize is responsible for parameter initialization, and spec is responsible for shape inference.
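The split between shape inference and parameter creation can be sketched with a toy dense layer on nested lists. ToyDense and the argument order (key, in_shape, units) are assumptions for illustration and do not match Oryx's actual initialize/spec signatures.

```python
import random
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


class ToyDense:
    """Hypothetical dense layer: y = W x + b with W of shape (units, in_dim)."""

    @classmethod
    def spec(cls, in_shape: Shape, units: int) -> Shape:
        # Shape inference only: replace the last axis with `units`.
        return in_shape[:-1] + (units,)

    @classmethod
    def initialize(cls, key: int, in_shape: Shape, units: int) -> Dict[str, List]:
        # Parameter initialization: parameter sizes are read off the input shape.
        gen = random.Random(key)
        in_dim = in_shape[-1]
        w = [[gen.gauss(0.0, 1.0) for _ in range(in_dim)] for _ in range(units)]
        return {"w": w, "b": [0.0] * units}
```

spec never allocates parameters, so a container can chain spec calls to compute every intermediate shape cheaply before any initialize call runs.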
initialize must return a LayerParams object, which represents the state of a Layer, broken down into three components:
1) params - what we traditionally consider the "weights" of a layer, i.e. quantities we'd like to differentiate with respect to.
2) state - numerical values that are part of a layer but that we would not like to differentiate with respect to, such as running averages. These quantities are automatically stop-gradiented before running the forward pass.
3) info - metadata that is not numerical, such as configuration.
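The three-part structure can be mirrored with a NamedTuple. ToyLayerParams and initialize_toy_batchnorm are hypothetical names chosen to avoid implying the real LayerParams constructor; pure Python has no gradients, so the params/state distinction appears only in comments.

```python
from typing import Any, NamedTuple


class ToyLayerParams(NamedTuple):
    """Hypothetical mirror of the three-part layer state described above."""

    params: Any  # differentiated during training ("weights")
    state: Any   # numeric but not differentiated, e.g. running averages
    info: Any    # non-numeric metadata


def initialize_toy_batchnorm(num_features: int) -> ToyLayerParams:
    return ToyLayerParams(
        params={"scale": [1.0] * num_features, "offset": [0.0] * num_features},
        state={"running_mean": [0.0] * num_features},
        info={"name": "toy_batchnorm"},
    )
```

A batch-norm-style layer is a natural example because it needs all three: trainable scale and offset, a running mean that must not receive gradients, and descriptive metadata.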
Layers must implement a _call method, which executes their forward pass. It can refer to the LayerParams returned by initialize, and it can accept keyword arguments such as training. We specially handle the keyword argument rng to ensure it is traced and split properly, so it can be used for stochasticity in the forward pass.
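Why rng needs special handling can be seen in a toy dropout stack: each layer should receive its own key, or their random masks would be correlated. ToyDropout, run_stack and split are hypothetical, integer keys stand in for JAX PRNG keys, and random.Random replaces JAX's random functions.

```python
import random
from typing import List, Optional


def split(key: int, num: int) -> List[int]:
    """Hypothetical stand-in for PRNG key splitting."""
    gen = random.Random(key)
    return [gen.getrandbits(32) for _ in range(num)]


class ToyDropout:
    def __init__(self, rate: float):
        self.rate = rate

    def _call(self, xs: List[float], *, training: bool = False,
              rng: Optional[int] = None) -> List[float]:
        if not training:
            return list(xs)  # identity at evaluation time
        if rng is None:
            raise ValueError("training-mode dropout needs an rng key")
        gen = random.Random(rng)
        keep = 1.0 - self.rate
        # Keep each element with probability `keep`, rescaling survivors.
        return [x / keep if gen.random() < keep else 0.0 for x in xs]


def run_stack(layers, xs, *, training=False, rng=None):
    # Give each layer its own key so their random masks are independent.
    keys = split(rng, len(layers)) if rng is not None else [None] * len(layers)
    for layer, key in zip(layers, keys):
        xs = layer._call(xs, training=training, rng=key)
    return xs
```

The same top-level key always reproduces the same masks, which keeps the forward pass deterministic given its inputs.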
If a Layer wants to optionally update its state, it can implement a state-update method, which takes the same input arguments as _call but instead returns a copy of the Layer with updated state. By default, this method returns the Layer itself, unchanged.
If a Layer would like a custom batching rule, it can implement _call_and_update_batched, which assumes all the input arguments have a leading batch dimension. It must return the batched outputs and a single, unbatched, updated Layer.
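The return contract (batched outputs, one unbatched updated layer) can be sketched with a simplified batch norm on scalars. ToyBatchNorm is hypothetical and only centers its inputs (no variance scaling); a Python list stands in for the leading batch dimension.

```python
import copy
from typing import List, Tuple


class ToyBatchNorm:
    """Hypothetical centering-only batch norm on scalars with a running mean."""

    def __init__(self, running_mean: float = 0.0, momentum: float = 0.9):
        self.running_mean = running_mean
        self.momentum = momentum

    def _call_and_update(self, x: float) -> Tuple[float, "ToyBatchNorm"]:
        # Single example: no batch statistics, so use the running mean.
        return x - self.running_mean, self

    def _call_and_update_batched(
        self, xs: List[float]
    ) -> Tuple[List[float], "ToyBatchNorm"]:
        # Every input carries a leading batch dimension (here: a Python list).
        batch_mean = sum(xs) / len(xs)
        outs = [x - batch_mean for x in xs]  # batched outputs
        new = copy.copy(self)  # one updated layer, not one per example
        new.running_mean = (self.momentum * self.running_mean
                            + (1 - self.momentum) * batch_mean)
        return outs, new
```

The running mean is updated once from the whole batch, so the returned layer has the same unbatched structure as the input layer.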
class Layer: Base class for neural network layers.
class LayerParams: LayerParams holds params and info of Layers.
class Template: Template class used by neural network layers.