Module: oryx.core.trace_util

Module for JAX tracing utility functions.

Functions

get_dynamic_context(...): Returns the currently active dynamic context for a trace.

get_shaped_aval(...): Converts a JAX value type into a shaped abstract value.

new_dynamic_context(...): Creates a dynamic context for a trace.

pv_like(...): Converts a JAX value type into a JAX PartialVal.

stage(...): Returns a function that stages a function to a ClosedJaxpr.

trees(...): Returns a function that determines input and output pytrees from inputs.
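get_shaped_aval reduces a value to a shaped abstract value, keeping only its shape and dtype. Public JAX exposes the same information through jax.ShapeDtypeStruct, and the in_avals of a ClosedJaxpr produced by jax.make_jaxpr show the shaped abstract values JAX itself assigns to traced inputs. The helper shaped_like below is hypothetical and only approximates the idea. It is not the Oryx function.

```python
import jax
import jax.numpy as jnp
import numpy as np


def shaped_like(x):
  """Hypothetical helper: keep only the shape and dtype of `x`."""
  # jnp.asarray accepts Python scalars and NumPy arrays and applies JAX's
  # dtype canonicalization (e.g. float64 -> float32 unless x64 is enabled).
  x = jnp.asarray(x)
  return jax.ShapeDtypeStruct(x.shape, x.dtype)


spec = shaped_like(np.zeros((2, 3), dtype=np.float32))
print(spec)

# The abstract values JAX attaches to traced inputs carry the same data.
avals = jax.make_jaxpr(lambda a, b: a + b)(jnp.ones((2, 3)), 1.0).in_avals
print(avals)
```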
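stage and trees both trace a Python function: stage produces a ClosedJaxpr, and trees reports the pytree structure of the inputs and outputs. The sketch below approximates that behavior using only public JAX APIs (jax.make_jaxpr with return_shape=True, and jax.tree_util). The names stage_sketch and trees_sketch are hypothetical, keyword arguments are omitted for brevity, and Oryx's own functions may differ in signature and return values.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


def stage_sketch(f):
  """Hypothetical analogue of stage: trace `f` to a ClosedJaxpr.

  Returns the closed jaxpr plus the input and output treedefs.
  """
  def wrapped(*args):
    # return_shape=True also yields a pytree of ShapeDtypeStructs
    # describing the outputs, so `f` is traced only once.
    closed_jaxpr, out_shapes = jax.make_jaxpr(f, return_shape=True)(*args)
    in_tree = tree_util.tree_structure(args)
    out_tree = tree_util.tree_structure(out_shapes)
    return closed_jaxpr, (in_tree, out_tree)
  return wrapped


def trees_sketch(f):
  """Hypothetical analogue of trees: return only (in_tree, out_tree)."""
  def wrapped(*args):
    return stage_sketch(f)(*args)[1]
  return wrapped


def f(params, x):
  # Takes a dict pytree and an array; returns a tuple pytree.
  return params["w"] * x + params["b"], jnp.sum(x)


params = {"w": jnp.ones(3), "b": jnp.zeros(3)}
x = jnp.arange(3.0)
jaxpr, (in_tree, out_tree) = stage_sketch(f)(params, x)
print(jaxpr)
print(in_tree, out_tree)
```

The jaxpr itself operates on flattened leaves, which is why the treedefs are returned alongside it: they are what let a caller map flat inputs and outputs back to the original nested structures.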