Contains logic for handling UnzipTracers when tracing a function.
oryx.core.interpreters.unzip.UnzipTrace(main: "MainTrace", sublevel: "Sublevel") -> None
The UnzipTrace is very similar to jax.interpreters.partial_eval.JaxprTrace, but it also attaches recipes to its tracers that track the variables produced while tracing. Variables are defined as outputs of the variable primitive that are also tagged as "keys". Inputs to the trace are designated as keys using trace.new_arg, and if all the inputs to a primitive are "keys", its outputs are also "keys".
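The key rule above can be illustrated with a small pure-Python sketch. It does not use Oryx or JAX: `ToyTracer`, `apply_primitive`, `variable`, and `recorded_variables` are hypothetical stand-ins chosen only to show how "key" status flows through a trace and when an output counts as a variable.

```python
from dataclasses import dataclass
from typing import Callable, List


@dataclass
class ToyTracer:
    """Hypothetical stand-in for an UnzipTracer: a concrete value plus a key flag."""
    value: float
    is_key: bool


# Names of outputs recognized as variables during this toy trace.
recorded_variables: List[str] = []


def new_arg(value: float, key: bool) -> ToyTracer:
    # Like trace.new_arg, the caller decides which inputs are keys.
    return ToyTracer(value, key)


def apply_primitive(fn: Callable[..., float], *tracers: ToyTracer) -> ToyTracer:
    # Outputs are keys only if every input to the primitive is a key.
    out_is_key = all(t.is_key for t in tracers)
    return ToyTracer(fn(*(t.value for t in tracers)), out_is_key)


def variable(tracer: ToyTracer, name: str) -> ToyTracer:
    # Hypothetical identity-like "variable" primitive: its output counts as a
    # variable only when that output is also a key.
    out = apply_primitive(lambda v: v, tracer)
    if out.is_key:
        recorded_variables.append(name)
    return out


key = new_arg(1.0, key=True)
x = new_arg(3.0, key=False)

w = variable(apply_primitive(lambda k: 2.0 * k, key), "w")  # depends only on the key
y = apply_primitive(lambda a, b: a * b, w, x)                # mixes in a non-key input
z = variable(y, "z")                                         # not a key, so not a variable

print(recorded_variables)  # ['w']
print(y.is_key)            # False
```

In this sketch `w` is recorded because everything upstream of it is a key, while `z` is not, since `x` was never marked as a key.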
Methods
default_process_primitive
default_process_primitive(primitive, tracers, params)
Partially evaluates primitives and saves variable recipes.
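The "partially evaluate" half of this method can be pictured with the known/unknown split used by partial evaluation in JAX: a primitive whose inputs are all known is computed immediately, and a primitive with any unknown input is staged out as an equation. The pure-Python sketch below assumes that split; `PVal`, `make_known`, `make_unknown`, and `ToyPartialEvaluator` are hypothetical names, and variable-recipe bookkeeping is left out.

```python
from dataclasses import dataclass, field
from typing import Callable, List, Optional, Tuple


@dataclass
class PVal:
    """Hypothetical partial value: `known` is a concrete value, or None if unknown."""
    known: Optional[float]
    name: str


def make_known(value: float) -> PVal:
    return PVal(value, repr(value))


def make_unknown(name: str) -> PVal:
    return PVal(None, name)


@dataclass
class ToyPartialEvaluator:
    # Equations staged out as (output name, primitive name, input names).
    equations: List[Tuple[str, str, Tuple[str, ...]]] = field(default_factory=list)
    counter: int = 0

    def process_primitive(self, prim: str, fn: Callable[..., float], *inputs: PVal) -> PVal:
        if all(p.known is not None for p in inputs):
            # All inputs are known: compute eagerly and record nothing.
            return make_known(fn(*(p.known for p in inputs)))
        # Some input is unknown: stage out an equation and return an unknown result.
        self.counter += 1
        out_name = f"t{self.counter}"
        self.equations.append((out_name, prim, tuple(p.name for p in inputs)))
        return make_unknown(out_name)


pe = ToyPartialEvaluator()
a = pe.process_primitive("mul", lambda p, q: p * q, make_known(2.0), make_known(3.0))
b = pe.process_primitive("add", lambda p, q: p + q, a, make_unknown("x"))
print(a.known)       # 6.0
print(pe.equations)  # [('t1', 'add', ('6.0', 'x'))]
```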
full_raise
full_raise(val) -> "Tracer"
handle_call_primitive
handle_call_primitive(call_primitive, f, tracers, params, is_map)
Handler for call_primitives, like jit or layer_call.
When an UnzipTracer hits a call primitive, one of two cases applies: either the called function contains a variable, in which case it must be unzipped into two functions, or it contains no variables, in which case the call primitive is recorded in the trace as-is.

We use unzip_eval_wrapper, which reports whether the unzip was successful. If it was, we record two new Jaxprs into the trace (one for init, one for apply). Otherwise, we just record the Jaxpr corresponding to the function call.
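The branch described above can be sketched in pure Python. This is not Oryx's implementation: `ToyEqn`, `toy_unzip` (a stand-in for unzip_eval_wrapper), and `handle_call` are hypothetical, a function body is modeled as a flat list of equations, and the real unzip's threading of variables from init into apply is omitted.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class ToyEqn:
    """Hypothetical equation inside a called function."""
    prim: str
    key_only: bool            # True if it depends only on key inputs
    is_variable: bool = False


def toy_unzip(body: List[ToyEqn]) -> Tuple[bool, List[ToyEqn], List[ToyEqn]]:
    # Plays the role of unzip_eval_wrapper: report success plus the two halves.
    if not any(e.is_variable for e in body):
        return False, [], []
    init_eqns = [e for e in body if e.key_only]
    apply_eqns = [e for e in body if not e.key_only]
    return True, init_eqns, apply_eqns


def handle_call(name: str, body: List[ToyEqn], trace: List[Tuple[str, List[str]]]) -> None:
    success, init_eqns, apply_eqns = toy_unzip(body)
    if success:
        # A variable was found: record one init function and one apply function.
        trace.append((f"{name}_init", [e.prim for e in init_eqns]))
        trace.append((f"{name}_apply", [e.prim for e in apply_eqns]))
    else:
        # No variables: record the call unchanged.
        trace.append((name, [e.prim for e in body]))


trace: List[Tuple[str, List[str]]] = []
layer_body = [
    ToyEqn("normal", key_only=True),
    ToyEqn("variable", key_only=True, is_variable=True),
    ToyEqn("dot", key_only=False),
]
handle_call("layer_call", layer_body, trace)
handle_call("jit", [ToyEqn("sin", key_only=False)], trace)
print(trace)
# [('layer_call_init', ['normal', 'variable']), ('layer_call_apply', ['dot']), ('jit', ['sin'])]
```

The call containing a variable is split into two recorded entries, while the variable-free call is recorded once, unchanged.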
Args | |
---|---|
call_primitive | a call primitive like xla_call |
f | a jax.linear_util wrapped function to be called |
tracers | inputs to the function |
params | parameters of the primitive |
is_map | whether or not the primitive is a map primitive (e.g. xla_pmap) |
Returns | |
---|---|
A list of output tracers |
instantiate_const
instantiate_const(tracer)
instantiate_const_abstracted
instantiate_const_abstracted(tracer)
lift
lift(val)
new_arg
new_arg(pval, key)
new_const
new_const(val)
new_instantiated_const
new_instantiated_const(val)
new_instantiated_literal
new_instantiated_literal(val)
post_process_call
post_process_call(call_primitive, out_tracers, params)
process_call
process_call(call_primitive, f, tracers, params)
process_custom_jvp_call
process_custom_jvp_call(primitive, fun, jvp, tracers)
process_custom_vjp_call
process_custom_vjp_call(primitive, fun, fwd, bwd, tracers, out_trees)
process_map
process_map(call_primitive, f, tracers, params)
process_primitive
process_primitive(primitive, tracers, params)
pure
pure(val)
sublift
sublift(val)
Class Variables | |
---|---|
level | |
main | |
sublevel | |