A primitive that appears in traces through transformations.
`oryx.core.HigherOrderPrimitive(name)`
In JAX, when functions composed of primitives are traced, only the primitives appear in the trace. A HigherOrderPrimitive (HOP) can be bound to a function using `call_bind`, which traces the function and surfaces its Jaxpr in the trace as part of the HOP's params.
A HOP also appears in the traces of transformed functions. Specifically, unlike `jax.custom_transforms` functions, which no longer appear in a trace once a transformation such as `jax.grad` or `jax.vmap` is applied, a HOP creates another HOP, bound to the transformed function, which appears in the trace after the transformation.
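The contrast can be sketched with a toy model in plain Python. A trace is a hypothetical list of `(name, params)` equations, and a "transformation" merely prefixes each primitive name with a tag (loosely imitating how `jax.vmap` rewrites each primitive). The helpers `transform_inlining` and `transform_hop` are invented for illustration: the first drops call-like wrappers, as the text describes for `jax.custom_transforms`, while the second re-emits the HOP bound to the transformed sub-trace. Keeping the HOP's name unchanged is a simplification of this sketch.

```python
from typing import Any, Dict, List, Tuple

# Hypothetical toy traces: lists of (name, params) equations. An equation whose
# params contain "sub_trace" plays the role of a HOP.
Eqn = Tuple[str, Dict[str, Any]]


def transform_inlining(eqns: List[Eqn], tag: str) -> List[Eqn]:
    """Transform a trace by inlining call-like equations: the wrapper is
    dropped and only the transformed primitives remain."""
    out: List[Eqn] = []
    for name, params in eqns:
        if "sub_trace" in params:
            out.extend(transform_inlining(params["sub_trace"], tag))
        else:
            out.append((f"{tag}_{name}", params))
    return out


def transform_hop(eqns: List[Eqn], tag: str) -> List[Eqn]:
    """Transform a trace by re-emitting each HOP equation, now bound to the
    transformed sub-trace, so the HOP survives in the output trace."""
    out: List[Eqn] = []
    for name, params in eqns:
        if "sub_trace" in params:
            new_sub = transform_hop(params["sub_trace"], tag)
            out.append((name, dict(params, sub_trace=new_sub)))  # copy, no mutation
        else:
            out.append((f"{tag}_{name}", params))
    return out


traced = [("my_hop", {"sub_trace": [("mul", {}), ("add", {})]})]
print(transform_inlining(traced, "batched"))
# [('batched_mul', {}), ('batched_add', {})]
print(transform_hop(traced, "batched"))
# [('my_hop', {'sub_trace': [('batched_mul', {}), ('batched_add', {})]})]
```

Because `transform_hop` recurses into `sub_trace`, nested HOPs are also preserved at every level in this toy.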
Methods
abstract_eval

`abstract_eval(*args, **params)`
bind

`bind(fun, *args, **params)`
bind_with_trace

`bind_with_trace(trace, args, params)`
def_abstract_eval

`def_abstract_eval(abstract_eval)`
def_custom_bind

`def_custom_bind(bind)`
def_effectful_abstract_eval

`def_effectful_abstract_eval(effectful_abstract_eval)`
def_impl

`def_impl(impl)`
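For eager evaluation, the registered implementation receives the bound function followed by its arguments, matching the `impl(f, *args, **params)` and `bind(fun, *args, **params)` signatures listed here. The following toy class is hypothetical and omits tracing entirely: with no active trace, its `bind` simply forwards to the rule registered with `def_impl`, and because `multiple_results` is `True` the outputs always come back as a list.

```python
from typing import Any, Callable, List


class ToyHOP:
    """Hypothetical toy, not Oryx's HigherOrderPrimitive: no tracing at all."""

    multiple_results = True  # bind always returns a list of outputs

    def __init__(self, name: str) -> None:
        self.name = name

    def def_impl(self, impl: Callable) -> Callable:
        # Store the rule on the instance so bind can call it.
        self.impl = impl
        return impl

    def bind(self, fun: Callable, *args: Any, **params: Any) -> List[Any]:
        # With no active trace, binding evaluates eagerly via the impl rule.
        outs = self.impl(fun, *args, **params)
        return list(outs) if self.multiple_results else outs


my_hop = ToyHOP("my_hop")


@my_hop.def_impl
def _my_hop_impl(f, *args, **params):
    del params  # this toy rule ignores params and just calls f
    return f(*args)


print(my_hop.bind(lambda x, y: (x + y, x * y), 3, 4, tag="demo"))  # [7, 12]
```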
get_bind_params

`get_bind_params(params)`
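When an equation that carries a HOP is re-bound (for example while interpreting a trace), the staged sub-computation has to be turned back into a callable and separated from the remaining params. The sketch below is modeled loosely on how call primitives are rebound in JAX; the helper `toy_get_bind_params`, the `sub_trace` key, and the two-op mini language are all hypothetical.

```python
from typing import Any, Callable, Dict, List, Tuple

# Hypothetical toy sub-trace: ("add" | "mul", constant) ops applied in order.
Op = Tuple[str, float]


def eval_sub_trace(ops: List[Op], x: float) -> List[float]:
    """Replay a toy sub-trace on x; returns a list (multiple results)."""
    for name, c in ops:
        x = x + c if name == "add" else x * c
    return [x]


def toy_get_bind_params(params: Dict[str, Any]) -> Tuple[List[Callable], Dict[str, Any]]:
    """Hypothetical analogue of get_bind_params: pop the staged sub-trace out of
    a copy of the params, rebuild a callable from it, and return
    (positional bind args, remaining params)."""
    new_params = dict(params)
    ops = new_params.pop("sub_trace")

    def sub_fun(x: float) -> List[float]:
        return eval_sub_trace(ops, x)

    return [sub_fun], new_params


eqn_params = {"sub_trace": [("mul", 2.0), ("add", 1.0)], "tag": "demo"}
bind_args, rest = toy_get_bind_params(eqn_params)
print(rest)               # {'tag': 'demo'}
print(bind_args[0](3.0))  # [7.0]
```

In this toy, the pieces would then be handed back to a HOP-style `bind` as `hop.bind(*bind_args, *args, **rest)`.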
impl

`impl(f, *args, **params)`
subcall

`subcall(name)`
Class Variables

| Name | Value |
|---|---|
| `call_primitive` | `True` |
| `map_primitive` | `False` |
| `multiple_results` | `True` |