Module for single-dispatch functions for handling state.
This module defines single-dispatch functions that are used to construct
and use Modules. The main functions are init, call_and_update and spec.
They are all single-dispatch functions, meaning that the implementation that runs depends on the type of their first argument. These implementations can be provided from outside this library, so the functions act as a general API for handling state.
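As a rough illustration of the single-dispatch pattern, here is a minimal plain-Python sketch built on the standard library's functools.singledispatch. It assumes nothing about this library's internals: the generic call_and_update below is a hypothetical stand-in, and Counter is a hypothetical user-defined class whose implementation is registered by code outside the "library".

```python
import functools


# Hypothetical generic function standing in for a library-side API.
# Its default implementation runs for any type without a registration.
@functools.singledispatch
def call_and_update(obj, *args):
    raise NotImplementedError(
        f"call_and_update is not registered for {type(obj).__name__}")


# The code below could live in a separate user package: registering an
# implementation does not require editing the generic function.
class Counter:
    """Hypothetical stateful object that keeps a running total."""

    def __init__(self, total=0):
        self.total = total


@call_and_update.register(Counter)
def _(obj, x):
    new_total = obj.total + x
    # Return the output together with a new object holding the updated
    # state; the input object itself is left unchanged.
    return new_total, Counter(new_total)


counter = Counter()
out, counter = call_and_update(counter, 5)
out, counter = call_and_update(counter, 2)
print(out)  # 7
```

Dispatch looks only at the type of the first argument, so any unregistered type falls through to the default, which raises.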
Methods
init
init converts an input object into an "initializer" function, i.e. one that
takes a random PRNGKey and a set of inputs and returns a Module.
function.py registers this transformation for Python functions; another
potential application is neural network layers.
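The neural-network-layer use case can be sketched in plain Python. Everything here is hypothetical: Dense describes a layer, DenseModule stands in for a Module, and an integer seed passed to random.Random replaces a JAX PRNGKey. The registered init returns an initializer that takes the key and an input and sizes the parameters from that input.

```python
import functools
import random


class Dense:
    """Hypothetical layer description: it only knows its output size."""

    def __init__(self, units):
        self.units = units


class DenseModule:
    """Hypothetical stand-in for an initialized Module holding parameters."""

    def __init__(self, w, b):
        self.w, self.b = w, b

    def variables(self):
        return {'w': self.w, 'b': self.b}


@functools.singledispatch
def init(obj):
    raise NotImplementedError(f"init is not registered for {type(obj).__name__}")


@init.register(Dense)
def _(layer):
    # init returns an initializer: a function of (key, inputs).
    def initializer(key, x):
        # An integer seed for random.Random stands in for a JAX PRNGKey,
        # so the same key always yields the same parameters.
        rng = random.Random(key)
        n_in = len(x)
        w = [[rng.gauss(0.0, 1.0) for _ in range(n_in)]
             for _ in range(layer.units)]
        b = [0.0] * layer.units
        return DenseModule(w, b)
    return initializer


module = init(Dense(3))(0, [1.0] * 5)
print(sorted(module.variables()))  # ['b', 'w']
```

Deferring parameter creation until the inputs are seen lets the layer description stay independent of the input size.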
call_and_update
call_and_update executes the computation associated with an input
object, returning the output and a copy of the object with updated state.
For a Module, call_and_update(module, ...) runs module.call_and_update, but
this behavior can be defined for arbitrary objects. For instance,
registrations.py provides default registrations for various Python data
structures such as lists and tuples.
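A plain-Python sketch of call_and_update with container registrations follows. The stand-in generic function and the AddBias class are hypothetical, and the container behavior (run the elements in order, feed each output to the next element, then rebuild a list or tuple of the updated elements) is an assumption made for illustration, not a description of what registrations.py actually does.

```python
import functools


@functools.singledispatch
def call_and_update(obj, *args):
    raise NotImplementedError(
        f"call_and_update is not registered for {type(obj).__name__}")


class AddBias:
    """Hypothetical stateful layer: adds a bias and counts its calls."""

    def __init__(self, bias, calls=0):
        self.bias, self.calls = bias, calls


@call_and_update.register(AddBias)
def _(obj, x):
    return x + obj.bias, AddBias(obj.bias, obj.calls + 1)


# Assumed container semantics (illustration only): run the elements in
# order, feed each output into the next element, and rebuild a container
# of the same type from the updated elements.
@call_and_update.register(list)
@call_and_update.register(tuple)
def _(objs, x):
    new_objs = []
    for obj in objs:
        # Recursive dispatch, so nested lists and tuples also work.
        x, new_obj = call_and_update(obj, x)
        new_objs.append(new_obj)
    return x, type(objs)(new_objs)


layers = [AddBias(1), AddBias(10)]
out, new_layers = call_and_update(layers, 0)
print(out, [layer.calls for layer in new_layers])  # 11 [1, 1]
```

Because each registration returns new objects instead of mutating its inputs, the original list still holds the un-updated layers.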
We also provide call and update functions, which are wrappers around
call_and_update.
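One natural way to build such wrappers, assuming call keeps only the output and update keeps only the updated object, is to index into the pair returned by call_and_update. The generic function and the RunningMax class below are hypothetical plain-Python stand-ins.

```python
import functools


@functools.singledispatch
def call_and_update(obj, *args, **kwargs):
    raise NotImplementedError(
        f"call_and_update is not registered for {type(obj).__name__}")


def call(obj, *args, **kwargs):
    """Hypothetical wrapper: keep only the output."""
    return call_and_update(obj, *args, **kwargs)[0]


def update(obj, *args, **kwargs):
    """Hypothetical wrapper: keep only the updated object."""
    return call_and_update(obj, *args, **kwargs)[1]


class RunningMax:
    """Hypothetical stateful object tracking the largest value seen."""

    def __init__(self, best=float('-inf')):
        self.best = best


@call_and_update.register(RunningMax)
def _(obj, x):
    best = max(obj.best, x)
    return best, RunningMax(best)


tracker = update(RunningMax(), 3)  # the new state remembers 3
print(call(tracker, 1))            # 3 (the updated state is discarded)
```

Keeping a single primitive and deriving the two projections from it means each new type only needs one registration.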
spec
spec is called the same way as init but, instead of a Module, returns the
shape and dtype of the output that would result from calling the input object.
Example:
```python
import jax.numpy as np
from jax import random
from oryx.core.state import api, module

# f creates a variable named 'w' and then assigns it a new value.
def f(x, init_key=None):
  w = module.variable(random.normal(init_key, x.shape), name='w')
  w = module.assign(w + 1., name='w')
  return np.dot(w, x)

api.spec(f)(random.PRNGKey(0), np.ones(5))  # ==> ArraySpec((), np.float32)
m = api.init(f)(random.PRNGKey(0), np.ones(5))
m.variables()  # ==> {'w': ...}
output, new_module = api.call_and_update(m, np.ones(5))
```
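The spec transformation can be approximated in plain Python by propagating shapes instead of values. In this sketch, ArraySpec is a simplified frozen dataclass standing in for the library's class, Dense is a hypothetical layer, and the key argument is accepted and ignored, mirroring the api.spec call in the example above. No weights are created and no arithmetic is performed.

```python
import functools
from dataclasses import dataclass


@dataclass(frozen=True)
class ArraySpec:
    """Simplified stand-in for the library's ArraySpec: shape and dtype only."""
    shape: tuple
    dtype: str


class Dense:
    """Hypothetical layer description: it only knows its output size."""

    def __init__(self, units):
        self.units = units


@functools.singledispatch
def spec(obj):
    raise NotImplementedError(f"spec is not registered for {type(obj).__name__}")


@spec.register(Dense)
def _(layer):
    def shape_fn(key, x_spec):
        # The key is accepted for signature parity and ignored: only the
        # input's shape and dtype are inspected, so nothing is computed.
        *batch, _n_in = x_spec.shape
        return ArraySpec(tuple(batch) + (layer.units,), x_spec.dtype)
    return shape_fn


print(spec(Dense(4))(0, ArraySpec((2, 5), 'float32')))
# ArraySpec(shape=(2, 4), dtype='float32')
```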
Classes
class ArraySpec: Encapsulates shape and dtype of an abstract array.
Functions
init(...): Transforms an object into a function that initializes a module.
spec(...): A general purpose transformation for getting the output shape and dtype of an object.