Oryx ist eine Bibliothek für probabilistische Programmierung und Deep Learning, die auf JAX aufbaut.
import oryx import jax.numpy as jnp ppl = oryx.core.ppl tfd = oryx.distributions # Define sampling function def sample(key): x = ppl.random_variable(tfd.Normal(0., 1.))(key) return jnp.exp(x / 2.) + 2. # Transform sampling function into a log-density function ppl.log_prob(sample)(1.) # ==> -0.9189
Oryx 'Ansatz besteht darin, eine Reihe von Funktionstransformationen verfügbar zu machen, die zusammengesetzt und in die vorhandenen Transformationen von JAX integriert werden. Um Oryx zu installieren, können Sie Folgendes ausführen:
pip install --upgrade oryx