Used to select the outputs of JAX primitives with multiple outputs.
Inherits From: JaxExpression, Expression, Pattern
oryx.experimental.matching.jax_rewrite.Part(
    operand: Expr, index: int
)
When a JAX primitive has multiple_results = True, calling it returns several
outputs. To represent these outputs in an expression tree, we wrap the output
of a multiple-output primitive in Part, creating one Part per output, each
carrying that output's index. Part is primarily used with CallPrimitives,
which always have multiple outputs.
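To make the idea concrete, here is a minimal, self-contained sketch in plain Python. It does not use oryx: the classes Var, MultiOutputCall, and Part and the functions evaluate and split_outputs are hypothetical stand-ins that only model the behavior described above, with divmod playing the role of a primitive whose multiple_results is True.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Tuple


# Hypothetical toy expression nodes; these are NOT the oryx classes.
@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class MultiOutputCall:
    """Stands in for a primitive with multiple_results = True."""
    fn: Callable[..., Tuple[Any, ...]]
    operands: Tuple[Any, ...]
    num_outputs: int


@dataclass(frozen=True)
class Part:
    operand: Any  # an expression whose value supports value[i]
    index: int


def evaluate(expr: Any, env: Dict[str, Any]) -> Any:
    if isinstance(expr, Var):
        return env[expr.name]
    if isinstance(expr, MultiOutputCall):
        args = [evaluate(op, env) for op in expr.operands]
        return expr.fn(*args)  # returns a tuple of outputs
    if isinstance(expr, Part):
        # Evaluate the multi-output operand, then select one output.
        return evaluate(expr.operand, env)[expr.index]
    return expr  # anything else is treated as a literal


def split_outputs(call: MultiOutputCall):
    # Wrap the call once per output, each Part tagged with its index.
    return [Part(call, i) for i in range(call.num_outputs)]


call = MultiOutputCall(divmod, (Var("a"), Var("b")), num_outputs=2)
quotient, remainder = split_outputs(call)
env = {"a": 17, "b": 5}
print(evaluate(quotient, env), evaluate(remainder, env))  # 3 2
```

Both Parts share the same operand, so the expression tree still records that the two values come from a single call.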
| Attributes | |
|---|---|
| operand | An expression that can be indexed into with an integer, i.e. operand[i]. |
| index | The index used when accessing the operand. |
| dtype | |
| shape | |
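The table leaves dtype and shape undocumented. One plausible reading, which is an assumption rather than something this page states, is that a Part reports the dtype and shape of the single output it selects, given an operand whose dtype and shape hold one entry per output. The toy sketch below illustrates that reading; ToyMultiOutput is hypothetical, not an oryx class.

```python
from dataclasses import dataclass
from typing import Any, Tuple


# Hypothetical multi-output expression carrying one shape and one dtype
# per output, stored as tuples.
@dataclass(frozen=True)
class ToyMultiOutput:
    shape: Tuple[Tuple[int, ...], ...]
    dtype: Tuple[str, ...]


@dataclass(frozen=True)
class Part:
    operand: Any
    index: int

    @property
    def shape(self):
        # Assumption: a Part reports the shape of the output it selects.
        return self.operand.shape[self.index]

    @property
    def dtype(self):
        # Assumption: likewise for the dtype.
        return self.operand.dtype[self.index]


expr = ToyMultiOutput(shape=((3, 4), ()), dtype=("float32", "int32"))
first, second = Part(expr, 0), Part(expr, 1)
print(first.shape, first.dtype)    # (3, 4) float32
print(second.shape, second.dtype)  # () int32
```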
Methods
evaluate
evaluate(
    env: oryx.experimental.matching.jax_rewrite.Bindings
) -> Any
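A minimal sketch of how evaluate could behave, assuming Bindings acts like a mapping from variable names to values: the operand is evaluated in env first, and then the output at index is selected. Var, the Bindings alias, and this Part class are hypothetical stand-ins, not the oryx implementation.

```python
from dataclasses import dataclass
from typing import Any, Dict

Bindings = Dict[str, Any]  # assumption: a plain name -> value mapping


@dataclass(frozen=True)
class Var:
    name: str

    def evaluate(self, env: Bindings) -> Any:
        return env[self.name]


@dataclass(frozen=True)
class Part:
    operand: Any
    index: int

    def evaluate(self, env: Bindings) -> Any:
        # Evaluate the operand first, then select one of its outputs.
        return self.operand.evaluate(env)[self.index]


env = {"outs": (10.0, 20.0, 30.0), "nested": ((1, 2), (3, 4))}
print(Part(Var("outs"), 2).evaluate(env))             # 30.0
print(Part(Part(Var("nested"), 1), 0).evaluate(env))  # 3
```

Because evaluate recurses through the operand, Parts nest naturally when an output is itself a collection.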
match
match(
    expr: Expr,
    bindings: oryx.experimental.matching.jax_rewrite.Bindings,
    succeed: oryx.experimental.matching.jax_rewrite.Continuation
) -> oryx.experimental.matching.jax_rewrite.Success
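The signature suggests a continuation-passing style: on success, the matcher hands the (possibly extended) bindings to succeed. The sketch below assumes that Success is an iterator of bindings and that a Continuation maps bindings to such an iterator. Var, the module-level match helper, and this Part class are hypothetical; they only illustrate how a Part pattern might match another Part by matching its index and then its operand.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterator

Bindings = Dict[str, Any]
Success = Iterator[Bindings]                  # assumption
Continuation = Callable[[Bindings], Success]  # assumption


@dataclass(frozen=True)
class Var:
    """Hypothetical pattern variable that binds whatever it matches."""
    name: str

    def match(self, expr: Any, bindings: Bindings,
              succeed: Continuation) -> Success:
        if self.name in bindings:
            # Already bound: only succeed if the value is consistent.
            if bindings[self.name] == expr:
                yield from succeed(bindings)
            return
        yield from succeed({**bindings, self.name: expr})


def match(pattern: Any, expr: Any, bindings: Bindings,
          succeed: Continuation) -> Success:
    # Objects with a match method drive matching; other values must be equal.
    if hasattr(pattern, "match"):
        yield from pattern.match(expr, bindings, succeed)
    elif pattern == expr:
        yield from succeed(bindings)


@dataclass(frozen=True)
class Part:
    operand: Any
    index: Any

    def match(self, expr: Any, bindings: Bindings,
              succeed: Continuation) -> Success:
        if not isinstance(expr, Part):
            return
        # Match the index first, then the operand, threading bindings along.
        yield from match(
            self.index, expr.index, bindings,
            lambda b: match(self.operand, expr.operand, b, succeed))


expr = Part("call_expr", 1)
pattern = Part(Var("x"), Var("i"))
results = list(match(pattern, expr, {}, lambda b: iter([b])))
print(results)  # [{'i': 1, 'x': 'call_expr'}]
```

Passing a continuation instead of returning a single result lets a matcher produce zero, one, or many sets of bindings without special-casing failure.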
tree_children
tree_children() -> Iterator[Expr]
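A sketch of tree_children, assuming the operand is the only child expression (index is a plain int rather than an expression). The count_nodes helper is hypothetical and shows how a generic traversal can walk a tree through tree_children alone.

```python
from dataclasses import dataclass
from typing import Any, Iterator


@dataclass(frozen=True)
class Part:
    operand: Any
    index: int

    def tree_children(self) -> Iterator[Any]:
        # Assumption: the operand is the only child expression; index is a
        # plain int, not an expression, so it is not yielded.
        yield self.operand


def count_nodes(expr: Any) -> int:
    """Hypothetical helper: count nodes by walking tree_children."""
    children = expr.tree_children() if hasattr(expr, "tree_children") else ()
    return 1 + sum(count_nodes(child) for child in children)


expr = Part(Part("x", 0), 1)
print(list(expr.tree_children()))  # [Part(operand='x', index=0)]
print(count_nodes(expr))           # 3
```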
tree_map
tree_map(
    fn
) -> 'Part'
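tree_map returns a new Part, as its 'Part' return annotation indicates. The sketch below assumes fn is applied to the operand while index is carried over unchanged; uppercase_leaves is a hypothetical rewrite that uses tree_map to rebuild a tree bottom-up without mutating the original.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class Part:
    operand: Any
    index: int

    def tree_map(self, fn: Callable[[Any], Any]) -> "Part":
        # Assumption: fn is applied to the child (the operand) and a new
        # Part is built with the same index; self is left unchanged.
        return Part(fn(self.operand), self.index)


def uppercase_leaves(expr: Any) -> Any:
    """Hypothetical rewrite: uppercase string leaves via tree_map."""
    if isinstance(expr, str):
        return expr.upper()
    return expr.tree_map(uppercase_leaves)


original = Part(Part("x", 0), 1)
rewritten = uppercase_leaves(original)
print(rewritten)  # Part(operand=Part(operand='X', index=0), index=1)
print(original)   # Part(operand=Part(operand='x', index=0), index=1)
```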
__eq__
__eq__(
    other
)
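A sketch of structural equality, assuming two Parts are equal when their operands and indices are equal. The explicit __hash__ is an assumption added so that hashing stays consistent with __eq__; this Part class is a hypothetical stand-in, not the oryx source.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True, eq=False)
class Part:
    operand: Any
    index: int

    def __eq__(self, other: Any) -> bool:
        # Assumption: structural equality on operand and index.
        if not isinstance(other, Part):
            return NotImplemented
        return (self.operand, self.index) == (other.operand, other.index)

    def __hash__(self) -> int:
        # Equal Parts must hash equally.
        return hash((Part, self.operand, self.index))


a = Part("call", 0)
print(a == Part("call", 0))  # True
print(a == Part("call", 1))  # False
print(a == ("call", 0))      # False
```

Returning NotImplemented for non-Part values lets Python fall back to its default comparison, which is why comparing against a plain tuple yields False rather than raising.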