Oryx ist eine Bibliothek für probabilistische Programmierung und Deep Learning, die auf JAX aufbaut.

import oryx
import jax.numpy as jnp
ppl = oryx.core.ppl
tfd = oryx.distributions

# Define sampling function
def sample(key):
  x = ppl.random_variable(tfd.Normal(0., 1.))(key)
  return jnp.exp(x / 2.) + 2.

# Transform sampling function into a log-density function
ppl.log_prob(sample)(1.)  # ==> -0.9189
Oryx 'Ansatz besteht darin, eine Reihe von Funktionstransformationen verfügbar zu machen, die zusammengesetzt und in die vorhandenen Transformationen von JAX integriert werden. Um Oryx zu installieren, können Sie Folgendes ausführen:
 pip install --upgrade oryx