oryx.experimental.matching.jax_rewrite.rewrite

Rewrites a JAX function according to a rewrite rule.

Args:
  f: A function to be transformed.
  rule: A function that transforms a rules.Expression into another rules.Expression.

Returns: A function that, when called with the original arguments to f, executes the body of f rewritten according to the provided rule.
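The contract above can be illustrated with a toy model. The sketch below is plain Python with no Oryx or JAX dependency. Var, Prim, exp_log_rule, apply_bottom_up, and toy_rewrite are hypothetical names, not Oryx APIs. The real rules.Expression types, and the way rewrite traces f into an expression, differ from this sketch. It shows a rule as a function from expression to expression that returns its input unchanged when it does not match. The rule is applied to every subexpression of a function body, and the result is wrapped in a function that evaluates the rewritten body when called.

```python
import math
from dataclasses import dataclass
from typing import Callable, Tuple, Union

# Hypothetical toy expression language (NOT Oryx's rules.Expression).
@dataclass(frozen=True)
class Var:
    name: str

@dataclass(frozen=True)
class Prim:
    op: str                   # primitive name, e.g. "exp" or "log"
    args: Tuple["Expr", ...]  # operand expressions

Expr = Union[Var, Prim, float]
Rule = Callable[[Expr], Expr]

def exp_log_rule(e: Expr) -> Expr:
    """Rewrite exp(log(x)) -> x; return any other expression unchanged."""
    if (isinstance(e, Prim) and e.op == "exp"
            and isinstance(e.args[0], Prim) and e.args[0].op == "log"):
        return e.args[0].args[0]
    return e

def apply_bottom_up(rule: Rule, e: Expr) -> Expr:
    """Apply `rule` to every subexpression, children before parents.

    This traversal strategy is an assumption of the toy model.
    """
    if isinstance(e, Prim):
        e = Prim(e.op, tuple(apply_bottom_up(rule, a) for a in e.args))
    return rule(e)

OPS = {"exp": math.exp, "log": math.log}

def evaluate(e: Expr, env: dict) -> float:
    """Evaluate an expression, looking up variables in `env`."""
    if isinstance(e, Var):
        return env[e.name]
    if isinstance(e, Prim):
        return OPS[e.op](*(evaluate(a, env) for a in e.args))
    return e  # a literal constant

def toy_rewrite(body: Expr, params: Tuple[str, ...], rule: Rule):
    """Return a function that evaluates `body` after rewriting it with `rule`."""
    rewritten = apply_bottom_up(rule, body)

    def wrapped(*args):
        # Bind positional arguments to parameter names, then run the new body.
        return evaluate(rewritten, dict(zip(params, args)))

    return wrapped

# Usage: the body of f(x) = exp(log(x)), written out as a toy expression.
x = Var("x")
body = Prim("exp", (Prim("log", (x,)),))
g = toy_rewrite(body, ("x",), exp_log_rule)
print(g(2.0))   # 2.0
print(g(-1.0))  # -1.0 (evaluating the unrewritten body raises ValueError)
```

The last line shows that a rewrite can change behavior: the original exp(log(x)) fails for x = -1.0 because math.log rejects negative inputs, while the rewritten body simply returns x.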