Propagates cells in a Jaxpr using a set of rules.
oryx.core.interpreters.propagate.propagate(
    cell_type: Type[oryx.core.interpreters.propagate.Cell],
    rules: Dict[jax_core.Primitive, PropagationRule],
    jaxpr: pe.Jaxpr,
    constcells: List[oryx.core.interpreters.propagate.Cell],
    incells: List[oryx.core.interpreters.propagate.Cell],
    outcells: List[oryx.core.interpreters.propagate.Cell]
) -> oryx.core.interpreters.propagate.Environment
Args:
cell_type | The Cell subclass used to instantiate literals into cells.
rules | Maps JAX primitives to propagation rule functions.
jaxpr | The Jaxpr used to construct the propagation graph.
constcells | Cells used to populate the Jaxpr's constvars.
incells | Cells used to populate the Jaxpr's invars.
outcells | Cells used to populate the Jaxpr's outvars.
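The constcells, incells, and outcells lists line up positionally with a Jaxpr's constvars, invars, and outvars. The sketch below uses JAX's public `jax.make_jaxpr` to inspect those variable lists for a small function. The strings are hypothetical placeholders standing in for real Cell instances, and the function `f` is only an illustration.

```python
import jax
import jax.numpy as jnp


def f(x, y):
    return jnp.sin(x) * y


# make_jaxpr returns a ClosedJaxpr: a Jaxpr plus the values of its constvars.
closed = jax.make_jaxpr(f)(1.0, 2.0)
jaxpr = closed.jaxpr

# One placeholder "cell" per variable, in the same order as the Jaxpr's variable lists.
# Real code would construct instances of a Cell subclass instead of strings.
constcells = [f"const:{v}" for v in jaxpr.constvars]
incells = [f"in:{v}" for v in jaxpr.invars]
outcells = [f"out:{v}" for v in jaxpr.outvars]

# The primitives used by the equations are the keys a `rules` dict would need to cover.
primitives = [eqn.primitive.name for eqn in jaxpr.eqns]
print(len(incells), len(outcells), primitives)
```

Each entry in `primitives` names a JAX primitive that the `rules` mapping would have to handle for propagation to make progress through that equation.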
Returns:
The Jaxpr environment after propagation has terminated.
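To make the roles of the arguments concrete, here is a minimal pure-Python sketch of cell propagation to a fixed point. It is not Oryx's implementation. `KnownCell`, `add_rule`, `neg_rule`, and the `(primitive, invars, outvars)` tuple format are hypothetical stand-ins for Cell subclasses, PropagationRule functions, and a Jaxpr's equations. Each rule receives the current input and output cells of one equation and returns refined versions. The loop joins the results into an environment until nothing changes, then returns that variable-to-cell mapping, loosely analogous to the returned Environment.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple


@dataclass(frozen=True)
class KnownCell:
    """Hypothetical toy cell: unknown (value is None) or holding a known number."""
    value: Optional[float] = None

    def is_known(self) -> bool:
        return self.value is not None

    def join(self, other: "KnownCell") -> "KnownCell":
        # Lattice join: unknown is the bottom element; two known values must agree.
        if not self.is_known():
            return other
        if other.is_known() and other.value != self.value:
            raise ValueError(f"conflicting values {self.value} and {other.value}")
        return self


Cells = List[KnownCell]
Rule = Callable[[Cells, Cells], Tuple[Cells, Cells]]
Eqn = Tuple[str, List[str], List[str]]  # (primitive name, input vars, output vars)


def add_rule(incells: Cells, outcells: Cells) -> Tuple[Cells, Cells]:
    # Bidirectional rule for out = a + b: any two known cells determine the third.
    (a, b), (out,) = incells, outcells
    if a.is_known() and b.is_known():
        out = KnownCell(a.value + b.value)
    elif a.is_known() and out.is_known():
        b = KnownCell(out.value - a.value)
    elif b.is_known() and out.is_known():
        a = KnownCell(out.value - b.value)
    return [a, b], [out]


def neg_rule(incells: Cells, outcells: Cells) -> Tuple[Cells, Cells]:
    # out = -x can be inverted in either direction.
    (x,), (out,) = incells, outcells
    if x.is_known():
        out = KnownCell(-x.value)
    elif out.is_known():
        x = KnownCell(-out.value)
    return [x], [out]


def propagate(rules: Dict[str, Rule], eqns: List[Eqn], constvars: List[str],
              invars: List[str], outvars: List[str], constcells: Cells,
              incells: Cells, outcells: Cells) -> Dict[str, KnownCell]:
    """Applies rules to every equation until no cell changes; returns var -> cell."""
    env: Dict[str, KnownCell] = {}
    # Seed the environment from the const, input, and output cells.
    for var, cell in zip(constvars + invars + outvars, constcells + incells + outcells):
        env[var] = env.get(var, KnownCell()).join(cell)
    changed = True
    while changed:  # terminates: each change turns an unknown cell into a known one
        changed = False
        for prim, eqn_in, eqn_out in eqns:
            ins = [env.get(v, KnownCell()) for v in eqn_in]
            outs = [env.get(v, KnownCell()) for v in eqn_out]
            new_ins, new_outs = rules[prim](ins, outs)
            for var, old, new in zip(eqn_in + eqn_out, ins + outs, new_ins + new_outs):
                joined = old.join(new)
                if joined != old:
                    env[var] = joined
                    changed = True
    return env


# Program: c = neg(a); out = add(c, k), where k is a constant.
rules = {"neg": neg_rule, "add": add_rule}
eqns = [("neg", ["a"], ["c"]), ("add", ["c", "k"], ["out"])]
env = propagate(rules, eqns, ["k"], ["a"], ["out"],
                [KnownCell(3.0)], [KnownCell()], [KnownCell(10.0)])
print(env["a"].value, env["c"].value)
```

Because the output cells are seeded as well as the inputs, information can flow backward through the graph. Knowing only `out = 10` and the constant `k = 3` is enough to recover `c = 7` and then `a = -7`.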