oryx.core.trace_util.stage

Returns a function that stages a given function to a TypedJaxpr, along with the pytree structures of its inputs and outputs.
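
To make "staging" concrete, here is a minimal sketch built only on public JAX calls (`jax.make_jaxpr` and `jax.tree_util.tree_structure`). It is not Oryx's implementation: `stage_sketch`, `f`, and `params` are hypothetical names introduced for illustration, and the `(jaxpr, (in_tree, out_tree))` return shape is an assumption about what such a wrapper might hand back. Note also that recent JAX versions call the traced object a `ClosedJaxpr` rather than a `TypedJaxpr`.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


def stage_sketch(f):
    """Hypothetical stand-in: trace `f` to a jaxpr plus pytree structures."""

    def wrapped(*args):
        # Pytree structure of the positional arguments as a single tuple.
        in_tree = tree_util.tree_structure(args)
        # Trace `f` on abstract values; return_shape=True also returns a
        # pytree of ShapeDtypeStruct describing the output.
        closed_jaxpr, out_shape = jax.make_jaxpr(f, return_shape=True)(*args)
        # Pytree structure of the output.
        out_tree = tree_util.tree_structure(out_shape)
        return closed_jaxpr, (in_tree, out_tree)

    return wrapped


def f(params, x):
    # Toy function with nested (dict) inputs and a dict output.
    return {"y": params["w"] * x + params["b"]}


params = {"w": jnp.array(2.0), "b": jnp.array(1.0)}
closed, (in_tree, out_tree) = stage_sketch(f)(params, jnp.ones(3))

print(closed)     # the staged computation over flattened inputs
print(in_tree)    # structure of (params, x)
print(out_tree)   # structure of {"y": ...}
```

The jaxpr itself operates on flat lists of arrays, so the two tree structures are what allow a caller to flatten nested arguments before evaluating it and to rebuild the nested result afterwards.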