Hari Komunitas ML adalah 9 November! Bergabung dengan kami untuk update dari TensorFlow, JAX, dan lebih Pelajari lebih lanjut

Oryx adalah pustaka untuk pemrograman probabilistik dan pembelajaran mendalam yang dibangun di atas JAX.

import oryx
import jax.numpy as jnp
ppl = oryx.core.ppl
tfd = oryx.distributions

# Define sampling function
def sample(key):
  x = ppl.random_variable(tfd.Normal(0., 1.))(key)
  return jnp.exp(x / 2.) + 2.

# Transform sampling function into a log-density function
ppl.log_prob(sample)(1.)  # ==> -0.9189
Pendekatan Oryx adalah untuk mengekspos sekumpulan transformasi fungsi yang menyusun dan terintegrasi dengan transformasi JAX yang ada. Untuk menginstal Oryx, Anda dapat menjalankan:
 pip install --upgrade oryx