Propagates cells in a Jaxpr using a set of rules.
oryx.core.interpreters.propagate.propagate(
    cell_type: Type[oryx.core.interpreters.propagate.Cell],
    rules: Dict[jax_core.Primitive, PropagationRule],
    jaxpr: pe.Jaxpr,
    constcells: List[oryx.core.interpreters.propagate.Cell],
    incells: List[oryx.core.interpreters.propagate.Cell],
    outcells: List[oryx.core.interpreters.propagate.Cell],
    reducer: Callable[[Environment, Equation, State, State], State] = identity_reducer,
    initial_state: State = None
) -> Tuple[Environment, State]
Args:
  cell_type: Used to instantiate literals into cells.
  rules: Maps JAX primitives to propagation rule functions.
  jaxpr: Used to construct the propagation graph.
  constcells: Used to populate the Jaxpr's constvars.
  incells: Used to populate the Jaxpr's invars.
  outcells: Used to populate the Jaxpr's outvars.
  reducer: An optional callable (default: identity_reducer) used to reduce
    over the state at each equation in the Jaxpr. reducer takes
    (env, eqn, state, new_state) as arguments and should return an updated
    state. The new_state value is provided by each equation.
  initial_state: The initial state value passed to the reducer (default: None).
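A minimal sketch of the reducer contract described above, in plain Python with no oryx dependency. It assumes that the default identity_reducer simply returns the current state unchanged (its body is not shown on this page). The summing_reducer, the fold_states helper, and the string "equations" are hypothetical stand-ins; the real propagate calls the reducer internally as it visits each equation.

```python
from functools import reduce
from typing import Any, List, Tuple


def identity_reducer(env, eqn, state, new_state):
    """Assumed behavior of the default: ignore new_state and keep the current state."""
    del env, eqn, new_state
    return state


def summing_reducer(env, eqn, state, new_state):
    """Hypothetical reducer: add each equation's numeric new_state to a running total."""
    del env, eqn
    return state + new_state


def fold_states(reducer, steps: List[Tuple[Any, Any]], initial_state):
    """Hypothetical helper: thread state through (eqn, new_state) pairs, one call per equation."""
    env = {}  # placeholder environment; these reducers never inspect it
    return reduce(lambda s, step: reducer(env, step[0], s, step[1]), steps, initial_state)


steps = [("eqn0", 2), ("eqn1", 5), ("eqn2", 1)]  # hypothetical per-equation states
print(fold_states(summing_reducer, steps, 0))     # 8
print(fold_states(identity_reducer, steps, None)) # None
```

With initial_state left at its default of None, the assumed identity reducer leaves the returned state as None, so a caller that needs a meaningful state should pass both a reducer and an initial_state.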
Returns:
  A tuple (env, state): the Jaxpr environment after propagation has
  terminated, and the final state accumulated by reducer.
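To make the (env, state) return value concrete, here is a self-contained toy in plain Python that mimics the overall shape of this function: rules keyed by primitive name refine cells from their neighbors until no cell changes, the reducer folds each equation's state, and a tuple comes back. Everything in it (KnownCell, Eqn, add_rule, toy_propagate, completed_reducer) is hypothetical and is not oryx's API; real oryx cells, rules, and Jaxpr equations have their own interfaces. It is also a naive fixpoint loop that rescans every equation on each pass, which a real implementation need not do.

```python
from dataclasses import dataclass
from typing import Any, Optional, Tuple


@dataclass(frozen=True)
class KnownCell:
    """Hypothetical cell: unknown when value is None, otherwise a concrete value."""
    value: Optional[Any] = None

    def known(self) -> bool:
        return self.value is not None


@dataclass(frozen=True)
class Eqn:
    """Hypothetical stand-in for a Jaxpr equation (not JAX's equation type)."""
    primitive: str
    invars: Tuple[str, ...]
    outvars: Tuple[str, ...]


def add_rule(incells, outcells):
    """Hypothetical rule for z = x + y that can fill in any one missing cell."""
    x, y = incells
    (z,) = outcells
    if x.known() and y.known() and not z.known():
        z = KnownCell(x.value + y.value)
    elif z.known() and y.known() and not x.known():
        x = KnownCell(z.value - y.value)
    elif z.known() and x.known() and not y.known():
        y = KnownCell(z.value - x.value)
    done = x.known() and y.known() and z.known()  # per-equation state
    return [x, y], [z], done


def toy_propagate(rules, eqns, env, reducer, initial_state):
    """Apply rules until no cell changes; return (environment, reduced state)."""
    env = dict(env)  # do not mutate the caller's mapping
    state = initial_state
    changed = True
    while changed:
        changed = False
        for eqn in eqns:
            incells = [env.get(v, KnownCell()) for v in eqn.invars]
            outcells = [env.get(v, KnownCell()) for v in eqn.outvars]
            new_in, new_out, new_state = rules[eqn.primitive](incells, outcells)
            for var, cell in zip(eqn.invars + eqn.outvars, new_in + new_out):
                if cell != env.get(var, KnownCell()):
                    env[var] = cell
                    changed = True
            state = reducer(env, eqn, state, new_state)
    return env, state


def completed_reducer(env, eqn, state, new_state):
    """Hypothetical reducer: collect outputs of equations whose rule reported completion."""
    return state | {eqn.outvars[0]} if new_state else state


# c = a + b runs forward; e = c + d runs backward to recover d.
eqns = [Eqn("add", ("a", "b"), ("c",)), Eqn("add", ("c", "d"), ("e",))]
env, state = toy_propagate(
    {"add": add_rule}, eqns,
    {"a": KnownCell(1), "b": KnownCell(2), "e": KnownCell(10)},
    completed_reducer, frozenset())
print(env["c"].value, env["d"].value, sorted(state))  # 3 7 ['c', 'e']
```

Because information can flow from outputs back to inputs (d is recovered from e and c), the loop keeps sweeping until a full pass makes no change, which is why the result is described as the environment "after propagation has terminated".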