Contains the `Template` and `Layer` API for Oryx.

`Module`s are an abstraction provided by Oryx that enables encapsulating both
state and functionality, and some basic neural network layers could be
implemented with the stateful function API. However, we want a few extra
pieces of functionality that are important to writing practical neural
networks beyond what exists in `Module`:
1) Handling keyword arguments. Neural networks commonly have behavior that is
conditional on a flag, like whether or not we are training the network. We can
handle this using a custom Jaxpr interpreter.
2) Custom batch semantics. We can't implement a layer like batch normalization
without custom behavior for batches. We can accomplish this by using a batch
rule for a custom JAX primitive associated with layers.
3) Special combinators. Building a custom `Module` abstraction enables
overloading operators like `>>` to build complex architectures, along with
handling explicit threading of `PRNGKey`s.
We implement these additions with the `Template` and `Layer` classes.
## Template

A `Template` is an object registered with `core.state.init` that can be
initialized into a `Layer`. For example, the template `nn.Dense(20)` can be
initialized into a `Layer` by calling
`core.state.init(nn.Dense(20))(random.PRNGKey(...), ...)`, just like a
stateful function. In most ways, `Template`s behave like stateful functions,
in that you can call them, i.e. `nn.Dense(20)(x, init_key=...)`, and it will
execute a dense layer initialized with `init_key`. However, `Template`s have
some extra functionality. We can use the `>>` operator to chain `Template`s
together, i.e. `nn.Dense(200) >> nn.Relu()` is a new `Template` that composes
a dense layer with a ReLU activation. It will appropriately split and thread
the `init_key` to all of its inner `Template`s when initialized.
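
As a concrete sketch (the module paths `oryx.core` and `oryx.experimental.nn`
are assumptions, as is the trailing argument to the initialized function,
which I take to be an example input used for shape inference):

```python
import jax.numpy as jnp
from jax import random

from oryx import core             # assumed import path
from oryx.experimental import nn  # assumed import path

x = jnp.ones(50)

# Initialize a Template into a Layer, just like a stateful function.
layer = core.state.init(nn.Dense(20))(random.PRNGKey(0), x)
y = layer(x)

# A Template can also be called directly with an explicit init_key.
y = nn.Dense(20)(x, init_key=random.PRNGKey(0))

# `>>` chains Templates into a new Template; initializing it splits and
# threads the init_key to each inner Template.
mlp = nn.Dense(200) >> nn.Relu() >> nn.Dense(10)
mlp_layer = core.state.init(mlp)(random.PRNGKey(1), x)
```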
## Layer

A `Layer` is a subclass of `Module` with some additional functionality. Like
`Module`s, `Layer`s have a `variables()` method that returns a dictionary
mapping names to state values. A `Layer` also has a `call_and_update` function
that returns the output of a computation along with a new `Layer` with updated
state.

Under the hood, `Layer`s do a couple of extra things beyond `Module`s. The
first is that a `Layer`'s `call_and_update` is associated with a JAX
primitive: `layer_cau_p`. This primitive serves multiple purposes. The first
is that it enables custom behavior when being `vmap`-ed. For example, a
`BatchNorm` layer has different behavior when being `vmap`-ed over many
examples vs. a single example. When a `Layer` implements a
`_call_and_update_batched` method, the `layer_cau_p` primitive knows to use
that method instead of mapping over the default `_call_and_update` method.
This enables an inheritance pattern for custom behavior under JAX
transformations.

The `layer_cau_p` primitive also enables threading keyword arguments from the
`Layer`'s call site (as in `layer(x, training=True)`). This allows `Layer`s to
be implemented with keyword arguments without worrying about whether they are
being traced.
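
Continuing the sketch above (reusing `layer` and `x` from the first sketch;
the `vmap` behavior shown follows the description here rather than tested
output):

```python
import jax

# variables() returns a dictionary mapping names to state values.
variables = layer.variables()

# call_and_update returns the output plus a new Layer with updated state.
y, updated_layer = layer.call_and_update(x)

# Keyword arguments at the call site are threaded through layer_cau_p, so a
# flag like `training` works even under tracing. (nn.Dense takes no such
# flag; a BatchNorm-style layer would branch on it.)
# y = layer(x, training=True)

# Under vmap, layer_cau_p dispatches to _call_and_update_batched when the
# Layer defines one; otherwise the default rule maps over _call_and_update.
ys = jax.vmap(layer)(jnp.ones((8, 50)))
```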
## Templates vs Layers

`Layer`s cannot be constructed directly. In fact, we override their `__new__`
method to construct a `Template` instead. So, despite `nn.Dense` being a
subclass of `nn.Layer`, `nn.Dense(20)` will return an instance of `Template`.
The `Template` acts as a factory class for `Layer`s, and `core.state.init`
actually builds the `Layer` by calling special methods on the `Layer`.
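
A small sketch of this factory behavior (the attribute path `nn.Template` is
an assumption; imports as in the first sketch):

```python
template = nn.Dense(20)
assert isinstance(template, nn.Template)   # __new__ handed back a Template
assert not isinstance(template, nn.Dense)  # no Layer was constructed yet

# core.state.init is what actually builds the Layer from the Template.
dense = core.state.init(template)(random.PRNGKey(0), jnp.ones(50))
assert isinstance(dense, nn.Dense)
```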
## initialize/spec

All `Layer`s must implement the `initialize` and `spec` class methods. These
enable a `Template` to know the shapes and correctly initialize the parameters
in a `Layer`. `initialize` is responsible for parameter initialization and
`spec` is responsible for shape inference.

`initialize` must return a `LayerParams` object, which represents the state of
a `Layer`, broken down into three components:
1) `params` - what we traditionally consider the "weights" of a layer, i.e.
quantities we'd like to differentiate with respect to.
2) `state` - numerical values that are part of a layer that we would not like
to differentiate with respect to, such as running averages. These quantities
are automatically stop-gradiented before running the forward pass.
3) `info` - metadata that is not numerical, such as configuration strings.
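
As an illustration, here is a hypothetical `Scale` layer. The signatures
below, including the assumption that `initialize` and `spec` receive an input
spec with a `.shape`, and the `self.params` accessor, are guesses about the
API rather than its documented form:

```python
class Scale(nn.Layer):
  """Hypothetical layer that multiplies its input by a learned vector."""

  @classmethod
  def initialize(cls, init_key, in_spec):
    return nn.LayerParams(
        params=random.normal(init_key, in_spec.shape),  # differentiated
        state=jnp.zeros(in_spec.shape),  # stop-gradiented numerical state
        info='elementwise scale')        # non-numerical metadata

  @classmethod
  def spec(cls, in_spec):
    # Shape inference: the output spec matches the input spec.
    return in_spec

  def _call(self, x):
    return self.params * x  # `self.params` accessor is an assumption
```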
## _call

All `Layer`s must implement a `_call` method, which executes their forward
pass. They can refer to the `LayerParams` returned in `initialize`. The
`_call` method can accept keyword arguments such as `training`. We specially
handle the keyword argument `rng` to ensure it is traced and split properly so
it can be used for stochasticity in the forward pass.
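
For instance, a hypothetical dropout-style layer could use both kinds of
keyword argument: `training` to toggle behavior and the reserved `rng` for
stochasticity (everything beyond those two names is illustrative):

```python
class Dropout(nn.Layer):
  """Hypothetical stochastic layer that drops activations while training."""

  @classmethod
  def initialize(cls, init_key, in_spec, rate=0.5):
    del init_key  # no parameters to initialize
    return nn.LayerParams(info=rate)  # keep the rate as static config

  @classmethod
  def spec(cls, in_spec, rate=0.5):
    return in_spec

  def _call(self, x, training=True, rng=None):
    # `rng` is specially threaded and split by the Layer machinery.
    if not training or rng is None:
      return x
    rate = self.info  # `self.info` accessor is an assumption
    keep = random.bernoulli(rng, 1. - rate, x.shape)
    return jnp.where(keep, x / (1. - rate), 0.)
```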
## _update

If a `Layer` wants to optionally update its state, it can implement an
`_update` method, which has the same input arguments as `_call` but instead
returns a copy of the `Layer` with updated state. By default, it just returns
`self`.
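
A hypothetical running-mean layer sketches the pattern; `replace` here stands
in for whatever copy mechanism the real API provides and is an assumption:

```python
class RunningMean(nn.Layer):
  """Hypothetical layer that centers inputs with a running mean."""

  @classmethod
  def initialize(cls, init_key, in_spec):
    del init_key  # no randomness needed
    return nn.LayerParams(state=jnp.zeros(in_spec.shape))

  @classmethod
  def spec(cls, in_spec):
    return in_spec

  def _call(self, x):
    return x - self.state  # `self.state` accessor is an assumption

  def _update(self, x):
    # Same arguments as _call; returns an updated copy rather than output.
    return self.replace(state=0.9 * self.state + 0.1 * x)
```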
## _call_and_update_batched

If a `Layer` would like a custom batching rule, it can implement
`_call_and_update_batched`, which assumes all the input arguments have a
leading batch dimension. It must return the batched outputs and an unbatched,
updated `Layer`.
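
Extending the hypothetical `RunningMean` above with a custom batching rule
that `layer_cau_p` would pick up under `vmap` (same caveats as before):

```python
class BatchedRunningMean(RunningMean):
  """RunningMean with a custom rule for batched inputs."""

  def _call_and_update_batched(self, x):
    # x arrives with a leading batch dimension; the returned Layer must be
    # unbatched, so fold the batch statistics into the state exactly once.
    batch_mean = x.mean(axis=0)
    out = x - self.state  # batched outputs
    return out, self.replace(state=0.9 * self.state + 0.1 * batch_mean)
```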
## Classes

`class Layer`: Base class for neural network layers.

`class LayerParams`: `LayerParams` holds the params and info of `Layer`s.

`class Template`: Template class used by neural network layers.