
Module: oryx.core

Contains Oryx's core transformations and functionality.


interpreters module: Contains function transformations implemented using JAX tracing machinery.

ppl module: Module for probabilistic programming features.

primitive module: Module for higher order primitives.

pytree module: Contains the Pytree class.

serialize module: Contains logic for serializing and deserializing PytreeTypes.

state module: Module for stateful functions.

trace_util module: Module for JAX tracing utility functions.


class FlatPrimitive: Contains default implementations of transformations.

class HigherOrderPrimitive: A primitive that appears in traces through transformations.

class NonInvertibleError: Raised by a custom inverse definition when values are unknown.

class Pytree: Class that registers objects as JAX pytree nodes.


call_bind(...): Binds a primitive to a function call.

custom_inverse(...): Decorates a function to enable defining a custom inverse.

harvest(...): Transforms a function into a "functionalized" version.



inverse_and_ildj(...): Inverse and ILDJ (inverse log-determinant Jacobian) function transformation.

log_prob(...): LogProb function transformation.

nest(...): Wraps a function to create a new scope for harvested values.

plant(...): Injects tagged values into a function.

reap(...): Collects tagged values from a function.

sow(...): Marks a value with a name and a tag.

tie_all(...): An identity function that ties arguments together in a JAX trace.

tie_in(...): A reimplementation of jax.tie_in that handles pytrees.

unzip(...): Unzip function transformation.

ildj_registry: Instance of oryx.core.interpreters.inverse.core.InverseDict

 random_variable: <function random_variable_log_prob>

 sow: <function sow_unzip>
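A plausible reading of how `ildj_registry`, `inverse_and_ildj`, and `log_prob` fit together rests on the change-of-variables formula. If an inverse rule and its ILDJ are known for a function f, the log density of Y = f(X) follows as log p_Y(y) = log p_X(f^{-1}(y)) + log|d f^{-1}(y)/dy|. The sketch below uses a plain dict keyed by Python functions as a hypothetical stand-in for such a registry. It is not Oryx's `InverseDict`, and `TOY_ILDJ_REGISTRY`, `toy_inverse_and_ildj`, and `toy_transformed_log_prob` are illustrative names only.

```python
import math
from typing import Callable, Dict, Tuple

# Hypothetical registry: forward function -> (inverse, ildj), where
# ildj(y) = log |d f^{-1}(y) / dy| (inverse log-determinant Jacobian).
TOY_ILDJ_REGISTRY: Dict[Callable, Tuple[Callable, Callable]] = {
    # f = exp: f^{-1} = log, d/dy log(y) = 1/y, so ILDJ = -log(y) for y > 0.
    math.exp: (math.log, lambda y: -math.log(y)),
    # f = log: f^{-1} = exp, d/dy exp(y) = exp(y), so ILDJ = y.
    math.log: (math.exp, lambda y: y),
}


def toy_inverse_and_ildj(f: Callable) -> Callable[[float], Tuple[float, float]]:
    """Looks up f's rule and returns y -> (f^{-1}(y), ILDJ(y))."""
    inverse, ildj = TOY_ILDJ_REGISTRY[f]
    return lambda y: (inverse(y), ildj(y))


def std_normal_log_prob(x: float) -> float:
    """Log density of a standard normal distribution."""
    return -0.5 * x * x - 0.5 * math.log(2.0 * math.pi)


def toy_transformed_log_prob(base_log_prob: Callable, f: Callable) -> Callable:
    """log p_Y(y) = log p_X(f^{-1}(y)) + ILDJ(y)."""
    inv_and_ildj = toy_inverse_and_ildj(f)

    def log_prob(y: float) -> float:
        x, ildj = inv_and_ildj(y)
        return base_log_prob(x) + ildj

    return log_prob


# exp of a standard normal is log-normal.
log_normal = toy_transformed_log_prob(std_normal_log_prob, math.exp)
print(log_normal(1.0))  # about -0.9189, i.e. -0.5 * log(2 * pi)
```

The result matches the closed-form log-normal density, -log y - 0.5 log(2π) - 0.5 (log y)^2, which is a quick way to check a registered ILDJ rule.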