Attend the Women in ML Symposium on December 7 Register now
透過集合功能整理內容 你可以依據偏好儲存及分類內容。

Oryx 是以 JAX 為基礎打造的程式庫,專門用於機率程式設計和深度學習。

import oryx
import jax.numpy as jnp
ppl = oryx.core.ppl
tfd = oryx.distributions

# Define sampling function
def sample(key):
  x = ppl.random_variable(tfd.Normal(0., 1.))(key)
  return jnp.exp(x / 2.) + 2.

# Transform sampling function into a log-density function
ppl.log_prob(sample)(1.)  # ==> -0.9189
Oryx 的做法是公開可撰寫及整合 JAX 現有轉換的一組函式轉換。如要安裝 Oryx,可以執行下列指令:
 pip install --upgrade oryx