Represents JAX expressions with closed-over constants.
Inherits From: JaxExpression, Expression, Pattern
oryx.experimental.matching.jax_rewrite.BoundExpression(
expressions: Sequence[Expr],
consts: oryx.experimental.matching.jax_rewrite.Bindings
)
A BoundExpression enables pinning JaxVars in an expression to fixed values, removing the need to bind them to values when the BoundExpression is evaluated. Conceptually this is equivalent to a jax.core.ClosedJaxpr.
Methods
evaluate
evaluate(
env: oryx.experimental.matching.jax_rewrite.Bindings
) -> Any
Evaluates using an environment augmented with constants.
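The idea of evaluating with an environment augmented with constants can be illustrated with a small, library-independent sketch. Everything below (Var, Add, Bound, and the dict-based Bindings alias) is a hypothetical stand-in written for illustration and is not the Oryx API. The sketch also assumes that caller-supplied entries take precedence over constants when names collide; the source does not state how BoundExpression resolves such a conflict.

```python
from dataclasses import dataclass
from typing import Any, Dict

# Hypothetical stand-in for the library's Bindings type: name -> value.
Bindings = Dict[str, Any]


@dataclass(frozen=True)
class Var:
    """A named variable looked up in the environment (stand-in for a JaxVar)."""
    name: str

    def evaluate(self, env: Bindings) -> Any:
        return env[self.name]


@dataclass(frozen=True)
class Add:
    """Adds the values of two sub-expressions."""
    left: Any
    right: Any

    def evaluate(self, env: Bindings) -> Any:
        return self.left.evaluate(env) + self.right.evaluate(env)


@dataclass
class Bound:
    """Wraps an expression together with constants pinned to some variables."""
    expression: Any
    consts: Bindings

    def evaluate(self, env: Bindings) -> Any:
        # Augment the caller's environment with the pinned constants.
        # Assumption of this sketch: entries in env win on a name collision.
        merged = dict(self.consts)
        merged.update(env)
        return self.expression.evaluate(merged)


# "bias" is pinned once; callers only need to supply "x".
expr = Add(Var("x"), Var("bias"))
bound = Bound(expr, consts={"bias": 10})
result = bound.evaluate({"x": 5})
```

Evaluating expr directly with only {"x": 5} would raise a KeyError in this sketch, because "bias" is unbound; wrapping it in Bound removes that requirement for callers.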
match
match(
expr: Expr,
bindings: oryx.experimental.matching.jax_rewrite.Bindings,
succeed: oryx.experimental.matching.jax_rewrite.Continuation
) -> oryx.experimental.matching.jax_rewrite.Success
tree_children
tree_children() -> Iterator[Expr]
tree_map
tree_map(
fn
) -> 'BoundExpression'
__eq__
__eq__(
other
)