Module for higher-order primitives.
Classes
class FlatPrimitive
: Contains default implementations of transformations.
class HigherOrderPrimitive
: A primitive that appears in traces through transformations.
Functions
call_bind(...)
: Binds a primitive to a function call.
tie_all(...)
: An identity function that ties arguments together in a JAX trace.
tie_in(...)
: A reimplementation of jax.tie_in that handles pytrees.