An evaluating trace that dispatches to a dynamic context.
oryx.core.interpreters.harvest.HarvestTrace(
main: MainTrace, sublevel: Sublevel
) -> None
Attributes
level
main
sublevel
Methods
default_process_primitive
View source
default_process_primitive(
primitive: jax_core.Primitive,
tracers: List[HarvestTracer],
params: Dict[str, Any]
) -> Union[HarvestTracer, List[HarvestTracer]]
full_raise
full_raise(
val
) -> TracerType
lift
View source
lift(
val: Value
) -> HarvestTracer
post_process_call
View source
post_process_call(
call_primitive, out_tracers, params
)
post_process_custom_jvp_call
View source
post_process_custom_jvp_call(
out_tracers, jvp_was_run
)
post_process_custom_vjp_call
View source
post_process_custom_vjp_call(
out_tracers, params
)
post_process_custom_vjp_call_fwd
View source
post_process_custom_vjp_call_fwd(
out_tracers, out_trees
)
post_process_map
View source
post_process_map(
call_primitive, out_tracers, params
)
process_call
View source
process_call(
call_primitive: jax_core.Primitive,
f: Any,
tracers: List[HarvestTracer],
params: Dict[str, Any]
)
process_custom_jvp_call
View source
process_custom_jvp_call(
primitive, fun, jvp, tracers, *, symbolic_zeros
)
process_custom_transpose
process_custom_transpose(
prim, call, tracers, **params
)
process_custom_vjp_call
View source
process_custom_vjp_call(
primitive, fun, fwd, bwd, tracers, out_trees, symbolic_zeros
)
process_map
View source
process_map(
call_primitive: jax_core.Primitive,
f: Any,
tracers: List[HarvestTracer],
params: Dict[str, Any]
)
process_primitive
View source
process_primitive(
primitive: jax_core.Primitive,
tracers: List[HarvestTracer],
params: Dict[str, Any]
) -> Union[HarvestTracer, List[HarvestTracer]]
pure
View source
pure(
val: Value
) -> HarvestTracer
sublift
View source
sublift(
tracer: HarvestTracer
) -> HarvestTracer