oryx.experimental.nn.Flatten

Flattens the inputs, collapsing all trailing dimensions.

Inherits From: Layer, Module, Pytree
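To give a rough picture of the shape change, the sketch below reproduces the reshape with plain jax.numpy. This page does not say which axes are kept, so the sketch assumes the first (batch) axis is preserved and every axis after it is merged into one. The helper name flatten_trailing is hypothetical and not part of Oryx.

```python
import jax.numpy as jnp


def flatten_trailing(x):
    """Hypothetical helper: keep axis 0, merge all remaining axes into one."""
    # -1 lets reshape infer the product of the trailing dimensions.
    return x.reshape((x.shape[0], -1))


x = jnp.arange(24.0).reshape((2, 3, 4))
y = flatten_trailing(x)
print(y.shape)  # (2, 12)
```

Using -1 instead of computing the product by hand keeps the helper valid for inputs of any rank of two or more.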

Attributes
info Returns the info for this Layer.
params Returns the parameters of this Layer.
state Returns the state of this Layer.

Methods

call

Calls the Layer's call_and_update and returns the first result.

call_and_update

Uses the layer_cau primitive to call self._call_and_update. Its first result is what call returns and its second result is what update returns.

flatten

Converts the Layer to a tuple suitable for PyTree flattening; unflatten performs the reverse conversion.

initialize

Initializes the Flatten layer.

Args
rng Random key.
in_spec Input Spec.

Returns
Tuple with the output shape and the LayerParams.
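Since flattening is a pure reshape, the interesting part of initialization is the output shape. One way to check that shape without allocating any data is jax.eval_shape applied to a jax.ShapeDtypeStruct, as sketched below. This does not call Oryx's initialize; it only illustrates the shape arithmetic, under the assumption that the leading axis is kept. The helper flatten_trailing is hypothetical.

```python
import math

import jax
import jax.numpy as jnp


def flatten_trailing(x):
    # Hypothetical helper: keep axis 0, merge the rest.
    return x.reshape((x.shape[0], -1))


# Describe the input by shape and dtype only; no array is allocated.
in_spec = jax.ShapeDtypeStruct((8, 28, 28, 1), jnp.float32)
out_spec = jax.eval_shape(flatten_trailing, in_spec)

# The same result by hand: batch size, then the product of the rest.
expected = (in_spec.shape[0], math.prod(in_spec.shape[1:]))
print(out_spec.shape, expected)  # (8, 784) (8, 784)
```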

new

Creates a Layer from a LayerParams namedtuple.

Args
layer_params LayerParams namedtuple that defines the Layer.
name A string name for the Layer.

Returns
A Layer object.

replace

Returns a copy of the layer with replaced properties.
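Returning a copy with replaced properties is the same idiom Python's standard library offers for immutable records. The sketch below shows it with dataclasses.replace on a hypothetical frozen ToyLayer; it illustrates the pattern only and is not how Oryx implements replace.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class ToyLayer:
    # Hypothetical stand-in for a layer; not an Oryx class.
    name: str
    state: int = 0


layer = ToyLayer(name="flatten")
# Build a new object with one field changed; the original is unchanged.
renamed = dataclasses.replace(layer, name="flatten_1")

print(layer.name, renamed.name)  # flatten flatten_1
```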

spec

unflatten

Reconstructs the Layer from its flattened form.
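flatten and unflatten correspond to the two halves of JAX's PyTree protocol: one splits an object into array leaves plus static auxiliary data, the other rebuilds the object from them. The sketch below uses jax.tree_util.register_pytree_node_class on a hypothetical ToyLayer; Oryx's Pytree base class may wire this up differently.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class ToyLayer:
    """Hypothetical layer holding one array leaf and a static name."""

    def __init__(self, weights, name):
        self.weights = weights
        self.name = name

    def tree_flatten(self):
        # Children (leaves JAX may transform) first, then static aux data.
        return (self.weights,), self.name

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild the layer from what tree_flatten produced.
        (weights,) = children
        return cls(weights, aux_data)


layer = ToyLayer(jnp.ones((2, 3)), name="toy")
leaves, treedef = jax.tree_util.tree_flatten(layer)
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)

# Tree utilities now see through the layer, e.g. scaling every leaf.
doubled = jax.tree_util.tree_map(lambda w: 2 * w, layer)
print(len(leaves), rebuilt.name, float(doubled.weights.sum()))  # 1 toy 12.0
```

Keeping the name in the auxiliary data rather than the leaves means JAX transformations never try to trace it.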

update

Calls the Layer's call_and_update and returns the second result.
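Per this page, call returns the first result of call_and_update and update returns the second. The plain-Python sketch below mirrors that relationship with a hypothetical CountingLayer whose second result is a new layer with an incremented counter; the counter is invented purely for illustration.

```python
class CountingLayer:
    """Hypothetical layer: doubles its input and counts how often it ran."""

    def __init__(self, count=0):
        self.count = count

    def call_and_update(self, x):
        # Return (output, updated layer) without mutating self.
        return 2 * x, CountingLayer(self.count + 1)

    def call(self, x):
        # First result only: the output.
        return self.call_and_update(x)[0]

    def update(self, x):
        # Second result only: the updated layer.
        return self.call_and_update(x)[1]


layer = CountingLayer()
print(layer.call(3))          # 6
print(layer.update(3).count)  # 1
print(layer.count)            # 0, the original is untouched
```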

variables

Returns the variables dictionary for this Layer.

__call__

Emulates a regular function call.

A Module's __call__ ensures that state is updated after the function call: it calls assign on the updated state before returning the function's output.

Args
*args The arguments to the module.
**kwargs The keyword arguments to the module.

Returns
The output of the module.
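The __call__ behaviour described above can be pictured as a thin wrapper: compute the output and an updated module, write the new state back with assign, then return the output. The sketch below is a hypothetical, framework-free rendering of that sequence; the StatefulCounter class and its assign method are assumptions made for illustration, not Oryx's Module implementation.

```python
class StatefulCounter:
    """Hypothetical module whose state is a single call counter."""

    def __init__(self, count=0):
        self.count = count

    def call_and_update(self, x):
        # Pure step: returns the output and a new module with updated state.
        return x + 1, StatefulCounter(self.count + 1)

    def assign(self, other):
        # Copy the updated state from `other` into this module in place.
        self.count = other.count

    def __call__(self, *args, **kwargs):
        output, updated = self.call_and_update(*args, **kwargs)
        self.assign(updated)  # state is written back before returning
        return output


module = StatefulCounter()
print(module(10), module(20), module.count)  # 11 21 2
```

Separating the pure call_and_update step from the in-place assign keeps the core computation functional, which is what JAX transformations expect, while still letting callers use the object like an ordinary stateful callable.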