Avoir une question? Connectez-vous avec la communauté sur le forum TensorFlow Visiter le forum

Oryx est une bibliothèque basée sur JAX, conçue pour la programmation probabiliste et le deep learning.

import oryx
import jax.numpy as jnp
ppl = oryx.core.ppl
tfd = oryx.distributions

# Define sampling function
def sample(key):
  x = ppl.random_variable(tfd.Normal(0., 1.))(key)
  return jnp.exp(x / 2.) + 2.

# Transform sampling function into a log-density function
ppl.log_prob(sample)(1.)  # ==> -0.9189
L'approche d'Oryx consiste à exposer un ensemble de transformations de fonction qui composent les transformations actuelles de JAX et s'intègrent dans celles-ci. Pour installer Oryx, vous pouvez exécuter la commande suivante :
 pip install --upgrade oryx