oryx.experimental.matching.jax_rewrite.CallPrimitive

Encapsulates JAX CallPrimitives like jit and pmap.

Inherits From: JaxExpression, Expression, Pattern

primitive: A JAX call primitive.
operands: A sequence of expressions that are evaluated and passed as inputs to the primitive when the CallPrimitive is evaluated.
expression: The expression that corresponds to the body of the call primitive. The operands are bound to the variable_names in an environment, and the expression is evaluated in that environment.
params: A Params object corresponding to the parameters of the call primitive.
variable_names: A sequence of string names used as keys for the operands in the environment in which expression is evaluated.
dtype

shape
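The evaluation rule described by the attributes above (evaluate the operands, bind the results to variable_names in a fresh environment, then evaluate expression in that environment) can be sketched in plain Python. This is a toy model under stated assumptions, not Oryx's implementation: ToyCallPrimitive, Var, Lit, Add, and toy_jit are hypothetical stand-ins, params is a plain dict rather than a Params object, and handing the body to primitive as a callable is an assumption made for illustration.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Tuple

Env = Dict[str, Any]


@dataclass
class Var:
    """Hypothetical leaf expression: reads a name from the environment."""
    name: str

    def evaluate(self, env: Env) -> Any:
        return env[self.name]


@dataclass
class Lit:
    """Hypothetical constant expression."""
    value: Any

    def evaluate(self, env: Env) -> Any:
        return self.value


@dataclass
class Add:
    """Hypothetical binary expression: lhs + rhs."""
    lhs: Any
    rhs: Any

    def evaluate(self, env: Env) -> Any:
        return self.lhs.evaluate(env) + self.rhs.evaluate(env)


@dataclass
class ToyCallPrimitive:
    """Toy model of CallPrimitive evaluation (not Oryx's implementation)."""
    primitive: Callable[..., Any]
    operands: Tuple[Any, ...]
    expression: Any
    params: Dict[str, Any]  # assumption: a plain dict instead of Params
    variable_names: Tuple[str, ...]

    def evaluate(self, env: Env) -> Any:
        # Evaluate the operands in the caller's environment.
        args = [op.evaluate(env) for op in self.operands]

        def body(*vals: Any) -> Any:
            # Bind the values to variable_names in a fresh environment, so
            # the body cannot see names from the caller's environment.
            inner_env = dict(zip(self.variable_names, vals))
            return self.expression.evaluate(inner_env)

        # Hand the body and the evaluated operands to the primitive.
        return self.primitive(body, *args, **self.params)


def toy_jit(fun: Callable[..., Any], *args: Any, name: str = "") -> Any:
    """Hypothetical stand-in for jit: just calls the body on its inputs."""
    return fun(*args)


call = ToyCallPrimitive(
    primitive=toy_jit,
    operands=(Var("x"), Lit(2.0)),
    expression=Add(Var("a"), Var("b")),
    params={"name": "add"},
    variable_names=("a", "b"),
)
print(call.evaluate({"x": 3.0}))  # 5.0
```

Because the body runs in a fresh environment, an expression that refers to the caller's x instead of one of the variable_names would raise a KeyError in this sketch.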

Methods

evaluate

match

tree_children

tree_map

__eq__