oryx.core.interpreters.unzip.UnzipTrace

Contains logic for handling UnzipTracers when tracing a function.

The UnzipTrace is very similar to jax.interpreters.partial_eval.JaxprTrace, except that it adds recipes to the tracers that track the variables produced while tracing. Variables are defined as outputs of the variable primitive that are also tagged as "keys". Inputs to the trace are designated as keys using trace.new_arg, and if all the inputs to a primitive are keys, its outputs are also keys.
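The key-propagation rule above can be illustrated with a small, self-contained toy model. This is a sketch under stated assumptions, not oryx's implementation: ToyTracer, ToyTrace, and bind are hypothetical names, and a "variable" is modeled simply as a key output of a primitive named "variable".

```python
# Hypothetical toy model (NOT oryx's API): ToyTracer, ToyTrace, and bind are
# illustrative names. It mimics the rule described above for "key" status.
from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyTracer:
    name: str
    is_key: bool


@dataclass
class ToyTrace:
    variables: List[str] = field(default_factory=list)

    def new_arg(self, name, is_key=True):
        # Trace inputs are designated as keys when they are created.
        return ToyTracer(name, is_key)

    def bind(self, prim, *inputs, out_name):
        # A primitive's outputs are keys only if every input is a key.
        is_key = all(t.is_key for t in inputs)
        out = ToyTracer(out_name, is_key)
        if prim == "variable" and is_key:
            # A key output of the variable primitive counts as a variable.
            self.variables.append(out_name)
        return out


trace = ToyTrace()
key = trace.new_arg("key")
x = trace.new_arg("x", is_key=False)
w = trace.bind("variable", key, out_name="w")  # input is a key -> variable
z = trace.bind("variable", x, out_name="z")    # input is not a key -> ignored
y = trace.bind("mul", w, x, out_name="y")      # mixed inputs -> not a key
print(trace.variables, y.is_key)  # ['w'] False
```

Only w is recorded as a variable: z comes from the variable primitive but its input is not a key, and y mixes a key with a non-key input.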

Methods

default_process_primitive

Partially evaluates primitives and saves variable recipes.
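The partial-evaluation half of this behavior, in the style of JAX's JaxprTrace, can be sketched as follows. This is a hypothetical toy, not oryx's code: Known, Unknown, ToyPartialEval, and the PRIMS table are illustrative assumptions. A primitive whose inputs are all known is evaluated immediately; otherwise its application is saved as a recipe to run later.

```python
# Hypothetical toy partial evaluator (NOT oryx's implementation): values are
# either Known (concrete) or Unknown (symbolic).
import operator
from dataclasses import dataclass
from typing import Any, List, Tuple

# Assumed primitive table for the toy; real traces bind JAX primitives.
PRIMS = {"add": operator.add, "mul": operator.mul}


@dataclass
class Known:
    value: Any


@dataclass
class Unknown:
    name: str


class ToyPartialEval:
    def __init__(self):
        # Each recipe is (output name, primitive name, input values).
        self.recipes: List[Tuple[str, str, tuple]] = []
        self._count = 0

    def process_primitive(self, prim, *args):
        if all(isinstance(a, Known) for a in args):
            # Every input is known: compute the result now.
            return Known(PRIMS[prim](*(a.value for a in args)))
        # At least one input is unknown: save a recipe instead.
        self._count += 1
        out = Unknown(f"v{self._count}")
        self.recipes.append((out.name, prim, args))
        return out


pe = ToyPartialEval()
a = pe.process_primitive("add", Known(2), Known(3))  # evaluated eagerly
b = pe.process_primitive("mul", a, Unknown("x"))     # deferred as a recipe
print(a.value, b.name, len(pe.recipes))  # 5 v1 1
```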

full_raise

handle_call_primitive

Handler for call_primitives, like jit or layer_call.

When an UnzipTracer hits a call primitive, there are two cases. Either the called function contains a variable, in which case it must be unzipped into two functions, or it contains no variables, in which case the call primitive is recorded in the trace as-is.

We use unzip_eval_wrapper, which reports whether the unzip succeeded. If it did, we record two new Jaxprs in the trace (one for init, one for apply). Otherwise, we record only the Jaxpr corresponding to the function call.

Args
call_primitive: a call primitive, such as xla_call
f: a jax.linear_util wrapped function to be called
tracers: inputs to the function
params: parameters of the primitive
is_map whether or not the primitive is a map primitive (e.g. xla_pmap)

Returns
A list of output tracers
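The unzip-or-record decision described above can be sketched with a toy model. Everything here is a hypothetical assumption, not oryx's unzip_eval_wrapper or its Jaxpr handling: a called function is modeled as a list of (output, primitive, inputs) equations, the "init" part is every equation computable from keys alone, and the unzip counts as successful only if that part contains a variable.

```python
# Hypothetical sketch (NOT oryx's code) of the branch in handle_call_primitive.
from typing import List, Tuple

Eqn = Tuple[str, str, Tuple[str, ...]]


def toy_unzip(eqns: List[Eqn], key_inputs):
    """Split eqns into init (key-only) and apply parts; report success."""
    keys = set(key_inputs)
    init, apply = [], []
    for out, prim, ins in eqns:
        if all(i in keys for i in ins):
            # All inputs are keys, so the output is a key too.
            keys.add(out)
            init.append((out, prim, ins))
        else:
            apply.append((out, prim, ins))
    success = any(prim == "variable" for _, prim, _ in init)
    return success, init, apply


def toy_handle_call(trace_log, name, eqns, key_inputs):
    success, init, apply = toy_unzip(eqns, key_inputs)
    if success:
        # Record two new sub-programs: one for init, one for apply.
        trace_log.append((f"{name}_init", init))
        trace_log.append((f"{name}_apply", apply))
    else:
        # No variables: record the call as-is.
        trace_log.append((name, eqns))
    return trace_log


dense = [("w", "variable", ("key",)), ("y", "mul", ("w", "x"))]
relu = [("y", "relu", ("x",))]
log = []
toy_handle_call(log, "dense", dense, ["key"])
toy_handle_call(log, "relu", relu, ["key"])
print([entry[0] for entry in log])  # ['dense_init', 'dense_apply', 'relu']
```

The dense call contains a variable computed from the key, so it is split in two; the relu call has no variables and is recorded unchanged.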

instantiate_const

instantiate_const_abstracted

lift

new_arg

new_const

new_instantiated_const

new_instantiated_literal

post_process_call

process_call

process_custom_jvp_call

process_custom_vjp_call

process_map

process_primitive

pure

sublift

level

main

sublevel