Oryx là một thư viện dành cho lập trình xác suất và học sâu được xây dựng dựa trên JAX.

import oryx
import jax.numpy as jnp
ppl = oryx.core.ppl
tfd = oryx.distributions

# Define sampling function
def sample(key):
  x = ppl.random_variable(tfd.Normal(0., 1.))(key)
  return jnp.exp(x / 2.) + 2.

# Transform sampling function into a log-density function
ppl.log_prob(sample)(1.)  # ==> -0.9189
Cách tiếp cận của Oryx là đưa ra một tập hợp các phép biến đổi hàm tạo và tích hợp với các phép biến đổi hiện có của JAX. Để cài đặt Oryx, bạn có thể chạy:
 pip install --upgrade oryx