Encapsulates JAX CallPrimitives such as jit and pmap.
Inherits From: JaxExpression, Expression, Pattern
oryx.experimental.matching.jax_rewrite.CallPrimitive(
    primitive: jax_core.Primitive,
    operands: Sequence[Any],
    expression: Any,
    params: oryx.experimental.matching.jax_rewrite.Params,
    variable_names: Sequence[str]
)
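The five constructor arguments can be read as the call primitive itself, the operands passed into the call, the body expression the call wraps, the primitive's parameters, and the names under which the body refers to its inputs. The toy model below mirrors those fields in plain Python to make the shape of the data concrete. ToyCall, the string stand-in for jax_core.Primitive, the tuple encoding of the body, and the one-name-per-operand check are all hypothetical assumptions, not oryx's actual representation.

```python
from dataclasses import dataclass
from typing import Any, Dict, Tuple


@dataclass(frozen=True)
class ToyCall:
    """Hypothetical stand-in for CallPrimitive; not oryx's implementation."""
    primitive: str                   # e.g. "jit"; a real node holds a jax_core.Primitive
    operands: Tuple[Any, ...]        # values (or sub-expressions) passed into the call
    expression: Any                  # body of the call, written over variable_names
    params: Dict[str, Any]           # primitive parameters (hypothetical contents)
    variable_names: Tuple[str, ...]  # names the body uses for its inputs

    def __post_init__(self):
        # Assumption: each operand is bound to exactly one body variable.
        if len(self.operands) != len(self.variable_names):
            raise ValueError("operands and variable_names must have equal length")


# Roughly models jit(lambda a, b: a + b)(x, y): the body ("add", "a", "b")
# refers to its inputs by name, and the operands are the outer values x, y.
call = ToyCall(
    primitive="jit",
    operands=("x", "y"),
    expression=("add", "a", "b"),
    params={"name": "add_fn"},
    variable_names=("a", "b"),
)
```

Keeping the body's variable names separate from the operands lets the body be matched or rewritten independently of what is passed in from outside.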
Methods
evaluate
evaluate(env: oryx.experimental.matching.jax_rewrite.Bindings) -> Any
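evaluate takes an environment of Bindings and returns a concrete value. One plausible reading, sketched below as an assumption rather than a description of oryx's code, is: evaluate the operands in the outer environment, bind the results to variable_names in a fresh inner environment, and evaluate the body there. Because a call primitive such as jit does not change what the wrapped function computes, the toy simply returns the body's value. The nested-tuple expression encoding, OPS, evaluate and evaluate_call are hypothetical.

```python
import operator
from typing import Any, Dict, Tuple

# Hypothetical expression encoding: a str is a variable looked up in the
# environment, a tuple (op, *args) is an operation, anything else is a literal.
OPS = {"add": operator.add, "mul": operator.mul}


def evaluate(expr: Any, env: Dict[str, Any]) -> Any:
    if isinstance(expr, str):
        return env[expr]
    if isinstance(expr, tuple):
        op, *args = expr
        return OPS[op](*(evaluate(a, env) for a in args))
    return expr


def evaluate_call(operands: Tuple[Any, ...], expression: Any,
                  variable_names: Tuple[str, ...], env: Dict[str, Any]) -> Any:
    # 1. Evaluate the operands in the caller's environment.
    args = [evaluate(o, env) for o in operands]
    # 2. Bind them to the body's variable names in a fresh inner environment,
    #    so names inside the call cannot clash with names outside it.
    inner_env = dict(zip(variable_names, args))
    # 3. Evaluate the body in the inner environment; the toy "primitive"
    #    leaves the result unchanged, as jit does semantically.
    return evaluate(expression, inner_env)


# Models jit(lambda a, b: a * b + a)(x, 3) with x = 2 in the outer environment.
result = evaluate_call(("x", 3), ("add", ("mul", "a", "b"), "a"), ("a", "b"), {"x": 2})
print(result)  # 8
```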
match
match(
    expr: Expr,
    bindings: oryx.experimental.matching.jax_rewrite.Bindings,
    succeed: oryx.experimental.matching.jax_rewrite.Continuation
) -> oryx.experimental.matching.jax_rewrite.Success
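The match signature follows a continuation-passing style: it receives the expression to match, the bindings collected so far, and a succeed continuation, and returns a Success. The sketch below assumes, without confirmation from this page, that Bindings is a dict from pattern-variable names to values, that a Continuation maps Bindings to an iterator, and that Success is that iterator, so a failed match simply yields nothing. A CallPrimitive pattern would presumably match its operands, body and params against the target's in this child-by-child way. Var, match and the tuple encoding are hypothetical.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterator

Bindings = Dict[str, Any]
Continuation = Callable[[Bindings], Iterator[Bindings]]


@dataclass(frozen=True)
class Var:
    """Hypothetical pattern variable that binds whatever it matches."""
    name: str


def match(pattern: Any, expr: Any, bindings: Bindings,
          succeed: Continuation) -> Iterator[Bindings]:
    if isinstance(pattern, Var):
        if pattern.name in bindings:
            # A repeated variable must match the same value every time.
            if bindings[pattern.name] == expr:
                yield from succeed(bindings)
        else:
            yield from succeed({**bindings, pattern.name: expr})
    elif isinstance(pattern, tuple):
        if not isinstance(expr, tuple) or len(pattern) != len(expr):
            return  # structural mismatch: no results
        # Match children left to right, threading bindings through continuations.
        def match_from(i: int, b: Bindings) -> Iterator[Bindings]:
            if i == len(pattern):
                yield from succeed(b)
            else:
                yield from match(pattern[i], expr[i], b,
                                 lambda b2: match_from(i + 1, b2))
        yield from match_from(0, bindings)
    elif pattern == expr:
        yield from succeed(bindings)


# Match a ("jit", operands, body) triple, capturing the operands and the body.
pattern = ("jit", Var("ops"), Var("body"))
expr = ("jit", ("x", "y"), ("add", "a", "b"))
results = list(match(pattern, expr, {}, lambda b: iter([b])))
print(results)  # [{'ops': ('x', 'y'), 'body': ('add', 'a', 'b')}]
```

Passing succeed down instead of returning a single result lets a matcher report zero, one or many sets of bindings through the same interface.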
tree_children
tree_children() -> Iterator[Expr]
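tree_children yields a node's immediate sub-expressions, which is the hook generic traversals build on. The sketch below assumes a call node's children are its operands followed by its body expression; whether oryx also yields params is not stated on this page. Node and walk are hypothetical names.

```python
from dataclasses import dataclass
from typing import Any, Iterator, Tuple


@dataclass(frozen=True)
class Node:
    """Hypothetical call-like node: operands plus a body expression."""
    operands: Tuple[Any, ...]
    expression: Any

    def tree_children(self) -> Iterator[Any]:
        # Assumption: children are the operands, then the body.
        yield from self.operands
        yield self.expression


def walk(expr: Any) -> Iterator[Any]:
    """Pre-order traversal built only on tree_children."""
    yield expr
    if isinstance(expr, Node):
        for child in expr.tree_children():
            yield from walk(child)


inner = Node(operands=("a",), expression="a_body")
outer = Node(operands=("x", inner), expression="body")
print([e for e in walk(outer) if not isinstance(e, Node)])
# ['x', 'a', 'a_body', 'body']
```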
tree_map
tree_map(fn) -> 'CallPrimitive'
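tree_map returns a 'CallPrimitive', which suggests it builds a new node rather than mutating the existing one, as is usual for immutable expression trees used in rewriting. A plausible behaviour, assumed here rather than confirmed, is to apply fn to every child (operands and body) while copying metadata such as the primitive and variable_names unchanged. Call and the use of str.upper as fn are hypothetical.

```python
from dataclasses import dataclass, replace
from typing import Any, Callable, Tuple


@dataclass(frozen=True)
class Call:
    """Hypothetical immutable call node."""
    primitive: str
    operands: Tuple[Any, ...]
    expression: Any
    variable_names: Tuple[str, ...]

    def tree_map(self, fn: Callable[[Any], Any]) -> "Call":
        # Build a new node: fn is applied to each child, metadata is copied.
        return replace(self,
                       operands=tuple(fn(o) for o in self.operands),
                       expression=fn(self.expression))


original = Call("jit", ("x", "y"), "body", ("a", "b"))
mapped = original.tree_map(str.upper)
print(mapped)
# Call(primitive='jit', operands=('X', 'Y'), expression='BODY', variable_names=('a', 'b'))
print(original.operands)  # ('x', 'y'): the original node is unchanged
```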
__eq__
__eq__(other)
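For expression nodes, equality is naturally structural: two separately built but identical calls should compare equal, which is what lets a rewriter detect that a rule produced no change. The sketch below assumes that behaviour, compares only the primitive, operands and body, and treats objects of other types as unequal; which fields oryx's __eq__ actually compares is not stated on this page. Call is a hypothetical name.

```python
from typing import Any, Tuple


class Call:
    """Hypothetical call node with an explicit structural __eq__."""

    def __init__(self, primitive: str, operands: Tuple[Any, ...], expression: Any):
        self.primitive = primitive
        self.operands = tuple(operands)
        self.expression = expression

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Call):
            # Python then tries the reflected comparison and finally falls
            # back to identity, which gives False.
            return NotImplemented
        return (self.primitive, self.operands, self.expression) == (
            other.primitive, other.operands, other.expression)

    # Defining __eq__ sets __hash__ to None; restore a hash consistent with it.
    def __hash__(self) -> int:
        return hash((self.primitive, self.operands, self.expression))


a = Call("jit", ("x", "y"), ("add", "x", "y"))
b = Call("jit", ["x", "y"], ("add", "x", "y"))  # built separately
print(a == b, a is b)                                    # True False
print(a == Call("pmap", ("x", "y"), ("add", "x", "y")))  # False
print(a == "jit")                                        # False
```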