Contains the Template and Layer API for Oryx.
Modules are an abstraction provided by Oryx that encapsulates both
state and functionality, and some basic neural network layers could be
implemented with the stateful function API. However, writing practical
neural networks requires a few extra pieces of functionality beyond
what exists in Module:
1) Handling keyword arguments. Neural networks commonly have behavior that is conditional on a flag, such as whether or not we are training the network. We can handle this using a custom Jaxpr interpreter.
2) Custom batch semantics. We can't implement a layer like batch normalization without having custom behavior for batches. We can accomplish this by using a batch rule for a custom JAX primitive associated with layers.
3) Special combinators. Building a custom Module abstraction enables
overloading operators like >> to build complex architectures, along with
handling explicit threading of PRNGKeys.
We implement these additions with the Template and Layer classes.
Template
A Template is an object registered with core.state.init that can
be initialized into a Layer. For example, the template nn.Dense(20)
can be initialized into a Layer by calling
core.state.init(nn.Dense(20))(random.PRNGKey(...), ...), just like a
stateful function. In most ways, Templates behave like stateful functions,
in that you can call them, e.g. nn.Dense(20)(x, init_key=...), and it will
execute a dense layer initialized with init_key. However, Templates have
some extra functionality. We can use the >> operator to chain Templates
together, e.g. nn.Dense(200) >> nn.Relu() is a new Template that composes
a dense layer with a ReLU activation. It will appropriately split and thread
the init_key to all its inner Templates when initialized.
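As a hedged sketch of this usage (the concrete example input passed to
core.state.init is an assumption; an abstract shape specification may also
be accepted):

```python
import jax.numpy as jnp
from jax import random

from oryx.core import state
from oryx.experimental import nn

# The >> combinator builds a new Template; the init_key is split and
# threaded to each inner Template automatically at initialization.
mlp_template = nn.Dense(200) >> nn.Relu() >> nn.Dense(10)

# Initialize the Template into a Layer, just like a stateful function.
mlp = state.init(mlp_template)(random.PRNGKey(0), jnp.ones(784))

# Templates are also directly callable with an explicit init_key.
out = nn.Dense(20)(jnp.ones(784), init_key=random.PRNGKey(1))
```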
Layer
A Layer is a subclass of Module with some additional functionality. Like
Modules, Layers have a variables() method that returns a dictionary
mapping names to state values. A Layer also has a call_and_update method that
returns the output of a computation and a new Layer with updated state.
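For example, sketching the two methods just described (the initialization
follows the Template example above):

```python
import jax.numpy as jnp
from jax import random

from oryx.core import state
from oryx.experimental import nn

layer = state.init(nn.Dense(3))(random.PRNGKey(0), jnp.ones(5))

# variables() returns a dictionary mapping names to state values.
print(layer.variables())

# call_and_update returns the computation's output along with a new Layer
# carrying any updated state (a plain Dense layer has no state to update).
out, updated_layer = layer.call_and_update(jnp.ones(5))
```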
Under the hood, Layers do a couple of extra things beyond Modules.
The first is that a Layer's call_and_update is associated with a JAX
primitive: layer_cau_p. This primitive serves multiple purposes. One is
enabling custom behavior under vmap. For example, a
BatchNorm layer has different behavior when being vmap-ed over many
examples vs. a single example. When a Layer implements a
_call_and_update_batched method, the layer_cau_p primitive knows to use that
method instead of mapping over the default _call_and_update method. This
enables an inheritance pattern for custom behavior under JAX transformations.
The layer_cau_p primitive also enables threading keyword arguments from the
Layer's call site (like if we did layer(x, training=True)). This allows
Layers to be implemented with keyword arguments without worrying about
whether they are being traced.
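A minimal sketch of the vmap dispatch, under the same initialization
assumptions as above:

```python
import jax
import jax.numpy as jnp
from jax import random

from oryx.core import state
from oryx.experimental import nn

# Initialize a dense layer on a single (unbatched) example.
layer = state.init(nn.Dense(4))(random.PRNGKey(0), jnp.ones(5))

# vmap dispatches through layer_cau_p. Dense defines no custom batch rule,
# so _call is mapped over the leading axis; a layer that implements
# _call_and_update_batched (like BatchNorm) would use that method instead.
batched = jax.vmap(layer)(jnp.ones((32, 5)))  # shape (32, 4)

# Call-site keyword arguments are threaded through layer_cau_p as well, so
# a layer whose _call accepts them (e.g. a BatchNorm-style layer) could be
# called as layer(x, training=True), even under tracing transformations
# like jit.
```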
Templates vs Layers
Layers cannot be constructed directly. In fact, we override their __new__
method to construct a Template instead. So, despite nn.Dense being
a subclass of nn.Layer, nn.Dense(20) will return an instance of Template.
The Template acts as a factory class for Layers, and core.state.init
actually builds the Layer by calling special methods on the Layer class.
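Concretely (Template and Layer are referenced here as exported from the nn
module, per the class list below):

```python
from oryx.experimental import nn

# Layer.__new__ is overridden, so "constructing" a Layer subclass actually
# yields a Template, which acts as a factory for the real Layer.
template = nn.Dense(20)
assert isinstance(template, nn.Template)
assert not isinstance(template, nn.Layer)
```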
initialize/spec
All Layers must implement the initialize and spec class methods. These
enable a Template to know the shapes and correctly initialize the parameters
in a Layer. initialize is responsible for parameter initialization and
spec is responsible for shape inference.
initialize must return a LayerParams object, which represents the state
of a Layer, broken down into three components:
1) params - what we traditionally consider the "weights" of a layer, i.e.
quantities we'd like to differentiate with respect to.
2) state - refers to numerical values that are part of a layer that we would
not like to differentiate with respect to, such as running averages. These
quantities are automatically stop-gradiented before running the forward pass.
3) info - refers to metadata that is not numerical, such as configuration
strings.
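As an illustrative sketch, here is a hypothetical Scale layer's initialize
and spec. The exact method signatures, the LayerParams constructor
arguments, and the in_spec.shape attribute are assumptions based on the
description above, not the verbatim Oryx API:

```python
from jax import random

from oryx.experimental import nn


class Scale(nn.Layer):  # hypothetical layer, for illustration only
  """Multiplies its input elementwise by a learned vector."""

  @classmethod
  def initialize(cls, init_key, in_spec):
    # params: differentiable weights; state: non-differentiable numerical
    # values (none here); info: non-numerical metadata (none here).
    weights = random.normal(init_key, in_spec.shape)  # assumed attribute
    return nn.LayerParams(params=weights, state=(), info=())

  @classmethod
  def spec(cls, in_spec):
    # Shape inference: elementwise scaling preserves the input spec.
    return in_spec
```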
_call
All Layers must implement a _call method, which executes their forward pass.
They can refer to the LayerParams returned in initialize. The _call method
can accept keyword arguments such as training. We specially handle the
keyword argument rng to ensure it is traced and split properly so it can be
used for stochasticity in the forward pass.
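Continuing the hypothetical Scale layer sketched above (the self.params
attribute access is an assumption):

```python
  # Inside the hypothetical Scale class from the previous sketch.
  def _call(self, x, rng=None):
    # Keyword arguments like `training` could also appear here; `rng` is
    # traced and split specially for stochastic forward passes (unused in
    # this toy layer).
    del rng
    return self.params * x
```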
_update
If a Layer wants to optionally update its state, it can implement an _update
method, which has the same input arguments as _call but instead returns a
copy of the Layer with updated state. By default, it just returns self.
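For the hypothetical Scale layer, the default behavior suffices:

```python
  # Still inside the hypothetical Scale class.
  def _update(self, x, rng=None):
    # Scale carries no state, so return the layer unchanged; a layer like
    # BatchNorm would instead return a copy with updated running averages.
    return self
```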
_call_and_update_batched
If a Layer would like a custom batching rule, it can implement
_call_and_update_batched, which assumes all the input arguments have a leading
batch dimension. It must return the batched outputs and an unbatched, updated
Layer.
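A hedged sketch of a custom batch rule for the hypothetical Scale layer
(trivial here, since broadcasting already handles a leading batch axis):

```python
  # Optional custom batch rule, still inside the hypothetical Scale class.
  def _call_and_update_batched(self, x, rng=None):
    # `x` carries a leading batch dimension. Return batched outputs plus a
    # single unbatched, updated layer; BatchNorm, for instance, would
    # compute moments over the batch here rather than once per example.
    return self.params * x, self
```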
Classes
class Layer: Base class for neural network layers.
class LayerParams: LayerParams holds the params, state, and info of Layers.
class Template: Template class used by neural network layers.