oryx.experimental.nn.Layer

Base class for neural network layers.

Inherits From: Module, Pytree

A Layer is a subclass of Module with some additional functionality. Like Modules, Layers have a variables() method that returns a dictionary mapping names to state values. They also have a call_and_update method that returns the output of a computation together with a new Layer holding the updated state. Under the hood, Layers do a couple of extra things beyond what Modules do.
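The contract described above (a variables() dictionary plus a call_and_update that returns an output and a new object) can be sketched in plain Python. This is only an illustration of the pattern, not Oryx code: the RunningMean class, its fields, and the running-mean computation are all hypothetical and were chosen to keep the example self-contained.

```python
from dataclasses import dataclass
from typing import Dict, Tuple


@dataclass(frozen=True)
class RunningMean:
    """Hypothetical immutable, layer-like object (not part of Oryx)."""
    num_seen: int = 0
    mean: float = 0.0

    def variables(self) -> Dict[str, float]:
        # Map each state name to its current value.
        return {"num_seen": self.num_seen, "mean": self.mean}

    def call_and_update(self, x: float) -> Tuple[float, "RunningMean"]:
        # Compute the output and build a new object holding the updated
        # state; `self` is frozen and is never modified.
        num_seen = self.num_seen + 1
        mean = self.mean + (x - self.mean) / num_seen
        return mean, RunningMean(num_seen=num_seen, mean=mean)


layer = RunningMean()
out, new_layer = layer.call_and_update(4.0)
```

After the call, layer still reports its initial variables, while new_layer carries the updated ones. Returning a fresh object instead of mutating in place is what lets this style compose with pure function transformations.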

Attributes
info Returns the info for this Layer.
params Returns the parameters of this Layer.
state Returns the state of this Layer.

Methods

call

Calls the Layer's call_and_update and returns the first result.

call_and_update

Uses the layer_cau primitive to call self._call_and_update.

flatten

Converts the Layer to a tuple suitable for PyTree.

initialize

Initializes a Layer from an init_key and input specification.

new

Creates a Layer given a LayerParams namedtuple.

Args
layer_params LayerParams namedtuple that defines the Layer.
name a string name for the Layer.

Returns
A Layer object.

replace

Returns a copy of the layer with replaced properties.

unflatten

Reconstructs the Layer from a flattened version.
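A flatten/unflatten pair of this kind can be sketched as a round trip in plain Python. The Affine class and its fields are hypothetical, and the (children, aux_data) layout is an assumption borrowed from the convention used by jax.tree_util.register_pytree_node; the actual tuple that Layer.flatten produces may be laid out differently.

```python
from typing import Any, Tuple


class Affine:
    """Hypothetical container used only for illustration (not an Oryx class)."""

    def __init__(self, weight: float, bias: float, name: str):
        self.weight = weight
        self.bias = bias
        self.name = name

    def flatten(self) -> Tuple[Tuple[Any, ...], Tuple[Any, ...]]:
        # Children are the values a tree transformation may touch;
        # auxiliary data is static metadata carried alongside them.
        return (self.weight, self.bias), (self.name,)

    @classmethod
    def unflatten(cls, aux_data, children) -> "Affine":
        # Rebuild the object from the static metadata and the children.
        weight, bias = children
        (name,) = aux_data
        return cls(weight, bias, name)


layer = Affine(2.0, 0.5, name="affine")
children, aux = layer.flatten()
# Transform only the children, then rebuild an object of the same type.
doubled = Affine.unflatten(aux, tuple(2 * c for c in children))
```

Keeping the name in the auxiliary data means it survives the round trip unchanged while the numeric children are free to be mapped over.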

update

Calls the Layer's call_and_update and returns the second result.
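The relationship between call, update, and call_and_update can be sketched as follows. The Counter class is a hypothetical stand-in, not an Oryx class; it only shows how the two convenience methods select one element of the pair and how the updated object is threaded through a loop.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class Counter:
    """Hypothetical stand-in for a stateful layer (not an Oryx class)."""
    total: int = 0

    def call_and_update(self, x: int) -> Tuple[int, "Counter"]:
        new_total = self.total + x
        return new_total, Counter(total=new_total)

    def call(self, x: int) -> int:
        # First result: the output only; the new state is discarded.
        return self.call_and_update(x)[0]

    def update(self, x: int) -> "Counter":
        # Second result: the updated object only; the output is discarded.
        return self.call_and_update(x)[1]


counter = Counter()
for x in [1, 2, 3]:
    counter = counter.update(x)  # thread the new state through the loop

final_output = counter.call(10)  # leaves `counter` unchanged
```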

variables

Returns the variables dictionary for this Layer.

__call__

Emulates a regular function call.

A Module's __call__ ensures that state is updated after the function call: it calls assign on the updated state before returning the function's output.

Args
*args The arguments to the module.
**kwargs The keyword arguments to the module.

Returns
The output of the module.
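The "call, assign the new state, then return the output" behavior described for __call__ can be sketched with a small mutable wrapper. Everything here is hypothetical: StatefulCounter is not an Oryx class, and its assign method is a plain in-place write standing in for whatever assign does in Oryx.

```python
from typing import Tuple


class StatefulCounter:
    """Hypothetical mutable wrapper used only for illustration."""

    def __init__(self, total: int = 0):
        self.total = total

    def call_and_update(self, x: int) -> Tuple[int, int]:
        # Pure step: returns the output and the new state, mutating nothing.
        new_total = self.total + x
        return new_total, new_total

    def assign(self, new_total: int) -> None:
        # Hypothetical in-place state write (an assumption, not Oryx's assign).
        self.total = new_total

    def __call__(self, *args, **kwargs) -> int:
        output, new_state = self.call_and_update(*args, **kwargs)
        self.assign(new_state)  # update state before returning the output
        return output


counter = StatefulCounter()
first = counter(2)
second = counter(3)
```

The caller uses the object like an ordinary function, and the second call sees the state left behind by the first.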