oryx.core.HigherOrderPrimitive

A primitive that remains visible in traces even after transformations are applied.

When JAX traces a function composed of primitives, only the primitives appear in the trace; the boundary of the function itself does not. A HigherOrderPrimitive (HOP) can be bound to a function with call_bind, which traces the function and surfaces its Jaxpr in the enclosing trace as one of the HOP's params.

A HOP also appears in the traces of transformed functions. Functions defined with jax.custom_transforms no longer appear in a trace once a transformation such as jax.grad or jax.vmap has been applied. A HOP, by contrast, responds to a transformation by creating another HOP, bound to the transformed function, which then appears in the resulting trace.
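The contrast can be sketched with another toy, pure-Python model. It is not Oryx's API. Here a trace is just a list of equations, and a HOP equation keeps its wrapped function's sub-trace under an illustrative `"jaxpr"` params key. `batch` is a hypothetical stand-in for a transformation like jax.vmap. It rewrites each first-order primitive, and when it meets a HOP it re-emits a HOP bound to the transformed sub-trace. `batch_inlined` roughly models the other behavior, in which the wrapped function is flattened away and no longer visible.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List

# Toy model only (hypothetical names; not Oryx's implementation).

@dataclass
class Eqn:
    prim: str
    params: Dict[str, Any] = field(default_factory=dict)

def is_hop(eqn: Eqn) -> bool:
    # In this toy, an equation is a HOP if it carries a sub-trace.
    return "jaxpr" in eqn.params

def batch(trace: List[Eqn]) -> List[Eqn]:
    """Toy transformation that keeps HOPs visible: each HOP is re-emitted as a
    new HOP equation (same name here) bound to the transformed sub-trace."""
    out = []
    for eqn in trace:
        if is_hop(eqn):
            new_params = dict(eqn.params, jaxpr=batch(eqn.params["jaxpr"]))
            out.append(Eqn(eqn.prim, new_params))
        else:
            out.append(Eqn("batched_" + eqn.prim, eqn.params))
    return out

def batch_inlined(trace: List[Eqn]) -> List[Eqn]:
    """Contrast: the wrapped function's equations are spliced into the outer
    trace, so no trace of the wrapper survives the transformation."""
    out = []
    for eqn in trace:
        if is_hop(eqn):
            out.extend(batch_inlined(eqn.params["jaxpr"]))
        else:
            out.append(Eqn("batched_" + eqn.prim, eqn.params))
    return out

inner = [Eqn("mul"), Eqn("add")]
program = [Eqn("sin"), Eqn("my_hop", {"jaxpr": inner})]

kept = batch(program)
print([e.prim for e in kept])  # ['batched_sin', 'my_hop']
print([e.prim for e in kept[1].params["jaxpr"]])  # ['batched_mul', 'batched_add']
print([e.prim for e in batch_inlined(program)])  # ['batched_sin', 'batched_mul', 'batched_add']
```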

Methods

abstract_eval

bind

bind_with_trace

def_abstract_eval

def_custom_bind

def_effectful_abstract_eval

def_impl

get_bind_params

impl

subcall

Class Variables

call_primitive = True

map_primitive = False

multiple_results = True