Used to select the outputs of JAX primitives with multiple outputs.
Inherits From: JaxExpression, Expression, Pattern
oryx.experimental.matching.jax_rewrite.Part(
operand: Expr, index: int
)
When a JAX primitive has multiple_results = True, it returns several outputs
when called. To represent multiple outputs in an expression tree, we wrap
the output of a multiple-output primitive with Part, using one Part per
output, each with that output's index. Part is primarily used with
CallPrimitives, which always have multiple outputs.
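The wrapping described above can be sketched in plain Python. This is a minimal illustration, not Oryx's implementation: MultiOutput and PartSketch are hypothetical stand-ins invented for this example, and the sketch assumes that evaluating a Part amounts to evaluating its operand and then indexing the result, i.e. operand[i].

```python
from dataclasses import dataclass
from typing import Any, Dict, Tuple


@dataclass(frozen=True)
class MultiOutput:
    """Hypothetical stand-in for an expression that evaluates to several outputs."""
    names: Tuple[str, ...]

    def evaluate(self, env: Dict[str, Any]) -> Tuple[Any, ...]:
        # Look up every name in the environment and return all values at once.
        return tuple(env[name] for name in self.names)


@dataclass(frozen=True)
class PartSketch:
    """Hypothetical stand-in for Part: selects one output of `operand`."""
    operand: Any
    index: int

    def evaluate(self, env: Dict[str, Any]) -> Any:
        # Assumption: evaluate the operand, then pick the requested output.
        return self.operand.evaluate(env)[self.index]


# One multi-output expression, wrapped once per output.
call = MultiOutput(names=("x", "y"))
parts = [PartSketch(call, i) for i in range(len(call.names))]
results = [part.evaluate({"x": 1.0, "y": 2.0}) for part in parts]
```

Each PartSketch shares the same operand and differs only in its index, so downstream expressions can refer to a single output without duplicating the multi-output node.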
Attributes | |
---|---|
operand | An expression that can be indexed into with an integer, i.e. operand[i]. |
index | The index that is used when accessing the operand. |
dtype | |
shape | |
Methods
evaluate
evaluate(
env: oryx.experimental.matching.jax_rewrite.Bindings
) -> Any
match
match(
expr: Expr,
bindings: oryx.experimental.matching.jax_rewrite.Bindings,
succeed: oryx.experimental.matching.jax_rewrite.Continuation
) -> oryx.experimental.matching.jax_rewrite.Success
tree_children
tree_children() -> Iterator[Expr]
tree_map
tree_map(
fn
) -> 'Part'
__eq__
__eq__(
other
)