Probabilistic Programming in Oryx

pip install -q -U jax jaxlib
pip install -q -Uq oryx -I
pip install -q tfp-nightly --upgrade
from functools import partial

import matplotlib.pyplot as plt
import seaborn as sns

import jax
import jax.numpy as jnp
from jax import jit, vmap, grad
from jax import random

from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions

import oryx

Probabilistic programming is the idea that we can express probabilistic models using features from a programming language. Tasks like Bayesian inference or marginalization are then provided as language features and can potentially be automated.

Oryx provides a probabilistic programming system in which probabilistic programs are just expressed as Python functions; these programs are then transformed via composable function transformations like those in JAX! The idea is to start with simple programs (like sampling from a random normal) and compose them together to form models (like a Bayesian neural network). A key goal of Oryx's PPL design is for programs to look like functions you'd already write and use in JAX, with annotations that make transformations aware of them.

Let's first import Oryx's core PPL functionality.

from oryx.core.ppl import random_variable
from oryx.core.ppl import log_prob
from oryx.core.ppl import joint_sample
from oryx.core.ppl import joint_log_prob
from oryx.core.ppl import block
from oryx.core.ppl import intervene
from oryx.core.ppl import conditional
from oryx.core.ppl import graph_replace
from oryx.core.ppl import nest

What are probabilistic programs in Oryx?

In Oryx, probabilistic programs are just pure Python functions that operate on JAX values and pseudorandom keys and return a random sample. By design, they are compatible with transformations like jit and vmap. However, the Oryx probabilistic programming system provides tools that enable you to annotate your functions in useful ways.

Following the JAX philosophy of pure functions, an Oryx probabilistic program is a Python function that takes a JAX PRNGKey as its first argument and any number of subsequent conditioning arguments. The output of the function is called a "sample" and the same restrictions that apply to jit-ed and vmap-ed functions apply to probabilistic programs (e.g. no data-dependent control flow, no side effects, etc.). This differs from many imperative probabilistic programming systems in which a "sample" is the entire execution trace, including values internal to the program's execution. We will see later how Oryx can access internal values using the joint_sample transformation.

Program :: PRNGKey -> ... -> Sample

Here is a "hello world" program that samples from a log-normal distribution.

def log_normal(key):
  return jnp.exp(random_variable(tfd.Normal(0., 1.))(key))

sns.distplot(jit(vmap(log_normal))(random.split(random.PRNGKey(0), 10000)))


The log_normal function is a thin wrapper around a TensorFlow Probability (TFP) distribution, but instead of calling tfd.Normal(0., 1.).sample, we've used random_variable. As we'll see later, random_variable enables us to convert objects into probabilistic programs, along with other useful functionality.

We can convert log_normal into a log-density function using the log_prob transformation:

x = jnp.linspace(0., 5., 1000)
plt.plot(x, jnp.exp(vmap(log_prob(log_normal))(x)))


Because we've annotated the function with random_variable, log_prob is aware that there was a call to tfd.Normal(0., 1.).sample and uses tfd.Normal(0., 1.).log_prob to compute the base distribution log prob. To handle the jnp.exp, log_prob automatically computes densities through bijective functions, keeping track of volume changes in the change-of-variable computation.
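The change-of-variable computation described above can be sketched by hand. The following is a plain-Python illustration using only the standard library, not Oryx's implementation; the helper names normal_log_prob, log_normal_log_prob and log_normal_pdf_closed_form are hypothetical and exist only for this example.

```python
import math

def normal_log_prob(z, loc=0.0, scale=1.0):
    """Log-density of Normal(loc, scale) evaluated at z."""
    return (-0.5 * ((z - loc) / scale) ** 2
            - math.log(scale) - 0.5 * math.log(2 * math.pi))

def log_normal_log_prob(y):
    """Log-density of y = exp(z) with z ~ Normal(0, 1), via change of variables."""
    z = math.log(y)                 # invert the exp
    inverse_log_det = -math.log(y)  # log |dz/dy| = log(1 / y)
    return normal_log_prob(z) + inverse_log_det

def log_normal_pdf_closed_form(y):
    """Textbook log-normal density, used only as a cross-check."""
    return math.exp(-0.5 * math.log(y) ** 2) / (y * math.sqrt(2 * math.pi))

# The two routes agree up to floating-point error.
y = 1.5
difference = abs(math.exp(log_normal_log_prob(y)) - log_normal_pdf_closed_form(y))
```

The base log-density is evaluated at the inverted value, and the inverse log-det Jacobian accounts for how exp stretches volume.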

In Oryx, we can take programs and transform them using function transformations -- for example, jax.jit or log_prob. Oryx can't do this with just any program though; it requires sampling functions that have registered their log density function with Oryx. Fortunately, Oryx automatically registers TensorFlow Probability (TFP) distributions in its system.

Oryx's probabilistic programming tools

Oryx has several function transformations geared towards probabilistic programming. We'll go over most of them and provide some examples. At the end, we'll put it all together into an MCMC case study. You can also refer to the documentation for core.ppl.transformations for more details.


random_variable

random_variable has two main pieces of functionality, both focused on annotating Python functions with information that can be used in transformations.

  1. random_variable operates as the identity function by default, but can use type-specific registrations to convert objects into probabilistic programs.

    For callable types (Python functions, lambdas, functools.partials, etc.) and arbitrary objects (like JAX DeviceArrays) it will just return its input.

    random_variable(x: object) == x
    random_variable(f: Callable[...]) == f

    Oryx automatically registers TensorFlow Probability (TFP) distributions, which are converted into probabilistic programs that call the distribution's sample method.

    random_variable(tfd.Normal(0., 1.))(random.PRNGKey(0)) # ==> -0.20584235

    Oryx additionally embeds information about the TFP distribution into JAX traces that enables automatically computing log densities.

  2. random_variable can tag values with names, making them useful for downstream transformations, by providing an optional name keyword argument to random_variable. When we pass an array into random_variable along with a name (e.g. random_variable(x, name='x')), it just tags the value and returns it. If we pass in a callable or TFP distribution, random_variable returns a program that tags its output sample with name.

These annotations do not change the semantics of the program when executed, but only when transformed (i.e. the program will return the same value with or without the use of random_variable).

Let's go over an example where we use both pieces of functionality together.

def latent_normal(key):
  z_key, x_key = random.split(key)
  z = random_variable(tfd.Normal(0., 1.), name='z')(z_key)
  return random_variable(tfd.Normal(z, 1e-1), name='x')(x_key)

In this program we've tagged the intermediates z and x, which makes the transformations joint_sample, intervene, conditional and graph_replace aware of the names 'z' and 'x'. We'll go over exactly how each transformation uses names later.


log_prob

The log_prob function transformation converts an Oryx probabilistic program into its log-density function. This log-density function takes a potential sample from the program as input and returns its log-density under the underlying sampling distribution.

log_prob :: Program -> (Sample -> LogDensity)

Like random_variable, it works via a registry of types where TFP distributions are automatically registered, so log_prob(tfd.Normal(0., 1.)) calls tfd.Normal(0., 1.).log_prob. For Python functions, however, log_prob traces the program using JAX and looks for sampling statements. The log_prob transformation works on most programs that return random variables, directly or via invertible transformations, but not on programs that internally sample values that aren't returned. If it cannot invert the necessary operations in the program, log_prob will throw an error.

Here are some examples of log_prob applied to various programs.

  1. log_prob works on programs that directly sample from TFP distributions (or other registered types) and return their values.

def normal(key):
  return random_variable(tfd.Normal(0., 1.))(key)

  2. log_prob is able to compute log-densities of samples from programs that transform random variates using bijective functions (e.g. jnp.exp, jnp.tanh, jnp.split).

def log_normal(key):
  return 2 * jnp.exp(random_variable(tfd.Normal(0., 1.))(key))

To compute the log-density of a sample from log_normal, we first need to invert the transformation, dividing the sample by 2 and taking its log, and then add a volume-change correction using the inverse log-det Jacobian of the transformation (see the change of variable formula from Wikipedia).

  3. log_prob works with programs that output structures of samples, like Python dictionaries or tuples.

def normal_2d(key):
  x = random_variable(
    tfd.MultivariateNormalDiag(jnp.zeros(2), jnp.ones(2)))(key)
  x1, x2 = jnp.split(x, 2, 0)
  return dict(x1=x1, x2=x2)
sample = normal_2d(random.PRNGKey(0))
print(sample)
{'x1': DeviceArray([-0.7847661], dtype=float32), 'x2': DeviceArray([0.8564447], dtype=float32)}

  4. log_prob walks the traced computation graph of the function, computing both forward and inverse values (and their log-det Jacobians) when necessary in an attempt to connect returned values with their base sampled values via a well-defined change of variables. Take the following example program:

def complex_program(key):
  k1, k2 = random.split(key)
  z = random_variable(tfd.Normal(0., 1.))(k1)
  x = random_variable(tfd.Normal(jax.nn.relu(z), 1.))(k2)
  return jnp.exp(z), jax.nn.sigmoid(x)
sample = complex_program(random.PRNGKey(0))
print(sample)
(DeviceArray(1.1547576, dtype=float32), DeviceArray(0.24830955, dtype=float32))

In this program, we sample x conditionally on z, meaning we need the value of z before we can compute the log-density of x. However, the program returns jnp.exp(z) rather than z itself, so to recover z we first have to invert the jnp.exp. Thus, in order to compute the log-densities of x and z, log_prob needs to first invert the first output, and then pass the recovered z forward through the jax.nn.relu to compute the mean of p(x | z).

For more information about log_prob, you can refer to core.interpreters.log_prob. In implementation, log_prob is closely based on the inverse JAX transformation; to learn more about inverse, see core.interpreters.inverse.
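To make the forward and inverse steps for complex_program concrete, here is a hand-written sketch of the log-density it implies, in plain Python with only the standard library. It illustrates the computation under the model as written, not Oryx's internals; normal_log_prob and complex_program_log_prob are hypothetical names introduced for this example.

```python
import math

def normal_log_prob(value, loc, scale):
    """Log-density of Normal(loc, scale) evaluated at value."""
    return (-0.5 * ((value - loc) / scale) ** 2
            - math.log(scale) - 0.5 * math.log(2 * math.pi))

def complex_program_log_prob(a, b):
    """Log-density of the outputs (a, b) = (exp(z), sigmoid(x))."""
    # Invert the exp on the first output to recover z; log|dz/da| = -log(a).
    z = math.log(a)
    z_log_det = -math.log(a)
    # Invert the sigmoid on the second output (the logit function) to
    # recover x; log|dx/db| = -log(b) - log(1 - b).
    x = math.log(b) - math.log(1.0 - b)
    x_log_det = -math.log(b) - math.log(1.0 - b)
    # Pass the recovered z forward through relu to get the mean of p(x | z).
    x_loc = max(z, 0.0)
    return (normal_log_prob(z, 0.0, 1.0) + z_log_det
            + normal_log_prob(x, x_loc, 1.0) + x_log_det)

# Outputs (1.0, 0.5) correspond to z = 0 and x = 0.
value = complex_program_log_prob(1.0, 0.5)
```

Note the ordering: z has to be recovered by inversion before the forward relu can be evaluated, which is why the computation mixes inverse and forward passes.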


joint_sample

To define more complex and interesting programs, we'll use some latent random variables, i.e. random variables with unobserved values. Let's refer to the latent_normal program that samples a random value z that is used as the mean of another random value x.

def latent_normal(key):
  z_key, x_key = random.split(key)
  z = random_variable(tfd.Normal(0., 1.), name='z')(z_key)
  return random_variable(tfd.Normal(z, 1e-1), name='x')(x_key)

In this program, z is latent so if we were to just call latent_normal(random.PRNGKey(0)) we would not know the actual value of z that is responsible for generating x.

joint_sample is a transformation that transforms a program into another program that returns a dictionary mapping string names (tags) to their values. For this to work, we need to tag the latent variables so that they appear in the transformed function's output.

joint_sample(latent_normal)(random.PRNGKey(0))
{'x': DeviceArray(0.01873656, dtype=float32),
 'z': DeviceArray(0.14389044, dtype=float32)}

Note that joint_sample transforms a program into another program that samples the joint distribution over its latent values, so we can further transform it. For algorithms like MCMC and VI, it's common to compute the log probability of the joint distribution as part of the inference procedure. log_prob(latent_normal) doesn't work because it requires marginalizing out z, but we can use log_prob(joint_sample(latent_normal)).

print(log_prob(joint_sample(latent_normal))(dict(x=0., z=1.)))
print(log_prob(joint_sample(latent_normal))(dict(x=0., z=-10.)))

Because this is such a common pattern, Oryx also has a joint_log_prob transformation which is just the composition of log_prob and joint_sample.

print(joint_log_prob(latent_normal)(dict(x=0., z=1.)))
print(joint_log_prob(latent_normal)(dict(x=0., z=-10.)))
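For latent_normal, the joint log-density factorizes as log p(z) + log p(x | z), so the quantities printed above can be checked by hand. The sketch below uses only the Python standard library; normal_log_prob and latent_normal_joint_log_prob are hypothetical helper names, and the results are float64, so they may differ from Oryx's float32 output in the last digits.

```python
import math

def normal_log_prob(value, loc, scale):
    """Log-density of Normal(loc, scale) evaluated at value."""
    return (-0.5 * ((value - loc) / scale) ** 2
            - math.log(scale) - 0.5 * math.log(2 * math.pi))

def latent_normal_joint_log_prob(sample):
    """log p(z, x) = log Normal(z; 0, 1) + log Normal(x; z, 0.1)."""
    z, x = sample['z'], sample['x']
    return normal_log_prob(z, 0.0, 1.0) + normal_log_prob(x, z, 0.1)

near = latent_normal_joint_log_prob(dict(x=0., z=1.))    # roughly -50.04
far = latent_normal_joint_log_prob(dict(x=0., z=-10.))   # roughly -5049.54
```

The second value is far lower because x = 0 lies 100 standard deviations away from a mean of z = -10 under the 0.1-scale likelihood.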


block

The block transformation takes in a program and a sequence of names and returns a program that behaves identically except that in downstream transformations (like joint_sample), the provided names are ignored. An example of where block is handy is converting a joint distribution into a prior over the latent variables by "blocking" the values sampled in the likelihood. For example, take latent_normal, which first draws a z ~ N(0, 1) then an x | z ~ N(z, 1e-1). block(latent_normal, names=['x']) is a program that hides the x name, so if we do joint_sample(block(latent_normal, names=['x'])), we obtain a dictionary with just z in it.

blocked = block(latent_normal, names=['x'])
joint_sample(blocked)(random.PRNGKey(0))
{'z': DeviceArray(0.14389044, dtype=float32)}


intervene

The intervene transformation clobbers samples in a probabilistic program with values from the outside. Going back to our latent_normal program, let's say we were interested in running the same program but wanted z to be fixed to 4. Rather than writing a new program, we can use intervene to override the value of z.

intervened = intervene(latent_normal, z=4.)
sns.distplot(vmap(intervened)(random.split(random.PRNGKey(0), 10000)));


The intervened function samples from p(x | do(z = 4)), which is just a normal distribution centered at 4 with standard deviation 0.1 (the scale of x in latent_normal). When we intervene on a particular value, that value is no longer considered a random variable. This means that a z value will not be tagged while executing intervened.
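As a sanity check on what the intervened program should produce, here is a plain-Python stand-in that fixes z and samples x directly. It uses only the standard library and is not how Oryx implements intervene; intervened_by_hand is a hypothetical name for this illustration.

```python
import random as py_random
import statistics

def intervened_by_hand(rng):
    """Plain-Python stand-in for intervene(latent_normal, z=4.)."""
    z = 4.0                   # z is supplied from the outside, not sampled
    return rng.gauss(z, 0.1)  # x | z ~ Normal(z, 1e-1), as in latent_normal

rng = py_random.Random(0)
samples = [intervened_by_hand(rng) for _ in range(10_000)]
sample_mean = statistics.fmean(samples)   # close to 4
sample_std = statistics.pstdev(samples)   # close to 0.1
```

The histogram of these samples should resemble the plot produced by the intervened program: a narrow bell curve around 4.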


conditional

conditional transforms a program that samples latent values into one that conditions on those latent values. Returning to our latent_normal program, which samples p(x) with a latent z, we can convert it into a conditional program p(x | z).

cond_program = conditional(latent_normal, 'z')
print(cond_program(random.PRNGKey(0), 100.))
print(cond_program(random.PRNGKey(0), 50.))
sns.distplot(vmap(lambda key: cond_program(key, 1.))(random.split(random.PRNGKey(0), 10000)))
sns.distplot(vmap(lambda key: cond_program(key, 2.))(random.split(random.PRNGKey(0), 10000)))



nest

When we start composing probabilistic programs to build more complex ones, it's common to reuse functions that have some important logic. For example, if we'd like to build a Bayesian neural network, there might be an important dense program that samples weights and executes a forward pass.

If we reuse functions, however, we might end up with duplicate tagged values in the final program, which is disallowed by transformations like joint_sample. We can use nest to create tag "scopes", where any samples inside of a named scope will be inserted into a nested dictionary.

def f(key):
  return random_variable(tfd.Normal(0., 1.), name='x')(key)

def g(key):
  k1, k2 = random.split(key)
  return nest(f, scope='x1')(k1) + nest(f, scope='x2')(k2)
joint_sample(g)(random.PRNGKey(0))
{'x1': {'x': DeviceArray(0.14389044, dtype=float32)},
 'x2': {'x': DeviceArray(-1.2515389, dtype=float32)}}

Case study: Bayesian neural network

Let's try our hand at training a Bayesian neural network for classifying the classic Fisher Iris dataset. It's relatively small and low-dimensional so we can try directly sampling the posterior with MCMC.

First, let's import the dataset and some additional utilities from Oryx.

from sklearn import datasets
iris = datasets.load_iris()
features, labels = iris['data'], iris['target']

num_features = features.shape[-1]
num_classes = len(iris.target_names)

from oryx.experimental import mcmc
from oryx.util import summary, get_summaries

We begin by implementing a dense layer, which will have normal priors over the weights and bias. To do this, we first define a dense higher order function that takes in the desired output dimension and activation function. The dense function returns a probabilistic program that represents a conditional distribution p(h | x) where h is the output of a dense layer and x is its input. It first samples the weight and bias and then applies them to x.

def dense(dim_out, activation=jax.nn.relu):
  def forward(key, x):
    dim_in = x.shape[-1]
    w_key, b_key = random.split(key)
    # Standard normal priors over the weights and bias, tagged by name.
    w = random_variable(
          tfd.Sample(tfd.Normal(0., 1.), sample_shape=(dim_out, dim_in)),
          name='w')(w_key)
    b = random_variable(
          tfd.Sample(tfd.Normal(0., 1.), sample_shape=(dim_out,)),
          name='b')(b_key)
    return activation(jnp.dot(w, x) + b)
  return forward

To compose several dense layers together, we will implement an mlp (multilayer perceptron) higher order function which takes in a list of hidden sizes and a number of classes. It returns a program that repeatedly calls dense using the appropriate hidden_size and finally returns logits for each class in the final layer. Note the use of nest which creates name scopes for each layer.

def mlp(hidden_sizes, num_classes):
  num_hidden = len(hidden_sizes)
  def forward(key, x):
    keys = random.split(key, num_hidden + 1)
    for i, (subkey, hidden_size) in enumerate(zip(keys[:-1], hidden_sizes)):
      x = nest(dense(hidden_size), scope=f'layer_{i + 1}')(subkey, x)
    logits = nest(dense(num_classes, activation=lambda x: x),
                  scope=f'layer_{num_hidden + 1}')(keys[-1], x)
    return logits
  return forward

To implement the full model, we'll need to model the labels as categorical random variables. We'll define a predict function which takes in a dataset of xs (the features) which are then passed into an mlp using vmap. When we use vmap(partial(mlp, mlp_key)), we sample a single set of weights, but map the forward pass over all the input xs. This produces a set of logits which parameterizes independent categorical distributions.

def predict(mlp):
  def forward(key, xs):
    mlp_key, label_key = random.split(key)
    logits = vmap(partial(mlp, mlp_key))(xs)
    return random_variable(
        tfd.Independent(tfd.Categorical(logits=logits), 1), name='y')(label_key)
  return forward
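The log-density that tfd.Independent(tfd.Categorical(logits=logits), 1) assigns to a vector of labels is the sum, over data points, of each label's log-softmax probability. A minimal standard-library sketch of that likelihood follows; log_softmax and independent_categorical_log_prob are hypothetical helper names, and the logits are toy values rather than outputs of the model.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax of a list of logits."""
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(l - m) for l in logits))
    return [l - log_norm for l in logits]

def independent_categorical_log_prob(logits, labels):
    """Sum of per-example categorical log-probabilities."""
    return sum(log_softmax(row)[label] for row, label in zip(logits, labels))

# Toy example: two data points, three classes, uniform logits.
toy_logits = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
toy_labels = [0, 2]
toy_log_lik = independent_categorical_log_prob(toy_logits, toy_labels)
```

With uniform logits every label has probability 1/3, so the toy log-likelihood is 2 * log(1/3).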

That's the full model! Let's use MCMC to sample the posterior of the BNN weights given data; first we construct a BNN "template" using mlp.

bnn = mlp([200, 200], num_classes)

To construct a starting point for our Markov chain, we can use joint_sample with a dummy input.

weights = joint_sample(bnn)(random.PRNGKey(0), jnp.ones(num_features))
print(weights.keys())
dict_keys(['layer_1', 'layer_2', 'layer_3'])

Computing the joint distribution log probability is sufficient for many inference algorithms. Let's now say we observe x and want to sample the posterior p(z | x). For complex distributions, we won't be able to marginalize out z to obtain the normalizing constant p(x) (though for latent_normal we can), but we can compute an unnormalized log density log p(z, x) where x is fixed to a particular value. We can use the unnormalized log probability with MCMC to sample the posterior. In our model, z corresponds to the network weights and x to the observed labels y. Let's write this "pinned" log prob function.

def target_log_prob(weights):
  return joint_log_prob(predict(bnn))(dict(weights, y=labels), features)

Now we can use tfp.mcmc to sample the posterior using our unnormalized log density function. Note that we'll have to use a "flattened" version of our nested weights dictionary to be compatible with tfp.mcmc, so we use JAX's tree utilities to flatten and unflatten.

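def run_chain(key, weights):
  # Flatten the nested weights dictionary into a list of arrays for tfp.mcmc.
  flat_state, sample_tree = jax.tree_util.tree_flatten(weights)

  def flat_log_prob(*states):
    return target_log_prob(jax.tree_util.tree_unflatten(sample_tree, states))

  def trace_fn(_, results):
    return results.inner_results.accepted_results.target_log_prob

  # Assumption: the DualAveragingStepSizeAdaptation wrapper, num_results and
  # num_burnin_steps are illustrative choices consistent with the 9000
  # adaptation steps, target_accept_prob=0.7 and trace_fn used here.
  flat_states, log_probs = tfp.mcmc.sample_chain(
      num_results=1000,
      num_burnin_steps=9000,
      kernel=tfp.mcmc.DualAveragingStepSizeAdaptation(
          tfp.mcmc.HamiltonianMonteCarlo(flat_log_prob, 1e-3, 100),
          9000, target_accept_prob=0.7),
      trace_fn=trace_fn,
      current_state=flat_state,
      seed=key)
  samples = jax.tree_util.tree_unflatten(sample_tree, flat_states)
  return samples, log_probs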
posterior_weights, log_probs = run_chain(random.PRNGKey(0), weights)


We can use our samples to take a Bayesian model averaging (BMA) estimate of the training accuracy. To compute it, we can use intervene with bnn to "inject" posterior weights in place of the ones that are sampled from the key. To compute logits for each data point for each posterior sample, we can double vmap over posterior_weights and features.

output_logits = vmap(lambda weights: vmap(lambda x: intervene(bnn, **weights)(
    random.PRNGKey(0), x))(features))(posterior_weights)
output_probs = jax.nn.softmax(output_logits)
print('Average sample accuracy:', (
    output_probs.argmax(axis=-1) == labels[None]).mean())
print('BMA accuracy:', (
    output_probs.mean(axis=0).argmax(axis=-1) == labels[None]).mean())
Average sample accuracy: 0.9874067
BMA accuracy: 0.99333334
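The two numbers above measure different things: the average sample accuracy scores each posterior sample's predictions separately and then averages, while the BMA accuracy first averages the predicted class probabilities across samples and then takes one argmax per data point. The standard-library sketch below shows the difference on made-up probabilities (toy values, not results from the model); all function names are hypothetical.

```python
def argmax(values):
    """Index of the largest entry in a list."""
    return max(range(len(values)), key=lambda i: values[i])

def average_sample_accuracy(probs, labels):
    """Mean over posterior samples of each sample's own accuracy."""
    per_sample = []
    for sample_probs in probs:
        correct = [argmax(p) == y for p, y in zip(sample_probs, labels)]
        per_sample.append(sum(correct) / len(labels))
    return sum(per_sample) / len(per_sample)

def bma_accuracy(probs, labels):
    """Accuracy of predictions made from sample-averaged probabilities."""
    num_samples = len(probs)
    correct = 0
    for n, y in enumerate(labels):
        num_classes = len(probs[0][n])
        mean_probs = [sum(probs[s][n][c] for s in range(num_samples)) / num_samples
                      for c in range(num_classes)]
        correct += argmax(mean_probs) == y
    return correct / len(labels)

# Toy data indexed as probs[sample][data point][class].
toy_probs = [
    [[0.9, 0.1], [0.2, 0.8]],  # sample 0: both data points correct
    [[0.4, 0.6], [0.3, 0.7]],  # sample 1: first data point wrong
    [[0.8, 0.2], [0.6, 0.4]],  # sample 2: second data point wrong
]
toy_labels = [0, 1]
avg_acc = average_sample_accuracy(toy_probs, toy_labels)
bma_acc = bma_accuracy(toy_probs, toy_labels)
```

In this toy case individual samples disagree, but their averaged probabilities classify both points correctly, so the BMA accuracy is higher than the average sample accuracy.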


Conclusion

In Oryx, probabilistic programs are just JAX functions that take in (pseudo-)randomness as an input. Because of Oryx's tight integration with JAX's function transformation system, we can write and manipulate probabilistic programs like we're writing JAX code. This results in a simple but flexible system for building complex models and doing inference.