Module: oryx.core.ppl.transformations

Module for probabilistic programming transformations.

Probabilistic programs

A probabilistic program is defined as a JAX function that takes a `jax.random.PRNGKey` as its first input, followed by any number of conditioning arguments. The output of the function is the output of the probabilistic program.

A simple program:

from jax import random

def f(key):
  return random.normal(key)

In this program, we sample a random normal variable and return it. Conceptually, this program represents the distribution p(x) = Normal(0, 1).

A conditional program:

def f(key, z):
  return z * random.normal(key)

In this program, we sample from a distribution conditional on z (i.e. a distribution p(x | z)).
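As a worked check of this conditional program, assuming z ≠ 0 and reading Normal(mean, standard deviation) as in p(x) = Normal(0, 1) above: scaling a standard normal draw by z gives a zero-mean normal with standard deviation |z|, so p(x | z) has a closed form.

```latex
x = z\,\varepsilon, \quad \varepsilon \sim \mathrm{Normal}(0, 1)
\;\Longrightarrow\;
x \mid z \sim \mathrm{Normal}(0, |z|),
\qquad
\log p(x \mid z) = -\tfrac{1}{2}\log(2\pi) - \log|z| - \frac{x^2}{2z^2}.
```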

Function transformations

The goal of the probabilistic programming package is to enable writing simple programs and to use program transformations to create complexity. Here we outline some of the transformations available in the module.

random_variable

random_variable is a general-purpose function that can be used to 1) tag values for use in downstream transforms and 2) convert objects into probabilistic programs. It is implemented as a single-dispatch function with implementations already registered for values and for functions. By default, it tags a value with a name and works only on JAX types (e.g. DeviceArrays and tracers). The implementation registered for function types returns the original function unchanged or, when a name is provided, a version of the function whose output is tagged with that name. The registry also lets objects such as TensorFlow Probability distributions register themselves as random-variable-like with Oryx.
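The dispatch pattern described above can be sketched with Python's functools.singledispatch. This is a toy under stated assumptions, not Oryx's implementation: tagging here just writes into a hypothetical global TAGS dict (Oryx tags values inside JAX traces), programs take a random.Random instance instead of a PRNGKey, and ToyNormal is a hypothetical class standing in for a TensorFlow Probability distribution.

```python
import functools
import random
import types

# Hypothetical tag store; Oryx's real tagging goes through JAX tracing instead.
TAGS = {}


@functools.singledispatch
def random_variable(obj, *, name=None):
    """Default case: treat `obj` as a value and tag it under `name`."""
    if name is not None:
        TAGS[name] = obj
    return obj


@random_variable.register(types.FunctionType)
def _(f, *, name=None):
    """Function case: return `f` itself, or a wrapper that tags its output."""
    if name is None:
        return f

    @functools.wraps(f)
    def tagged(*args, **kwargs):
        return random_variable(f(*args, **kwargs), name=name)

    return tagged


class ToyNormal:
    """Hypothetical distribution-like object (a stand-in for a TFP distribution)."""

    def __init__(self, loc, scale):
        self.loc, self.scale = loc, scale

    def sample(self, rng):
        return rng.gauss(self.loc, self.scale)


@random_variable.register(ToyNormal)
def _(dist, *, name=None):
    # Convert the object into a program (a plain function of an RNG),
    # then reuse the function case to tag that program's output.
    return random_variable(lambda rng: dist.sample(rng), name=name)


program = random_variable(ToyNormal(0.0, 1.0), name='x')
value = program(random.Random(0))
print(TAGS['x'] == value)  # True
```

Registering a new type only adds a conversion to a function; the tagging logic stays in the function and value cases.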

Tagging a value in a probabilistic program as a random variable enables it to be used by downstream transforms described below, such as joint_sample, conditional, intervene, and graph_replace.

log_prob

log_prob takes a probabilistic program and returns a function that computes the log probability of a sample. It relies on certain sampling primitives having been registered with the transformation. Specifically, the returned function takes an output of the program and attempts to compute the log-probability of every random sample drawn while producing that output.

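Examples (random_normal here denotes a standard-normal sampling primitive with a registered log-probability rule; it is not one of this module's functions):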

def f1(key):
  return random_normal(key)
log_prob(f1)(0.)  # ==> -0.9189385
log_prob(f1)(1.)  # ==> -1.4189385

def f2(key):
  k1, k2 = random.split(key)
  return [random_normal(k1), random_normal(k2)]
log_prob(f2)([0., 0.])  # ==> -1.837877
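The values in these examples can be checked by hand: the standard-normal log density is log p(x) = -½ log(2π) - x²/2, and the log density of independent samples is the sum of their individual log densities. A plain-Python check of that arithmetic (independent of Oryx):

```python
import math


def std_normal_log_prob(x):
    """Log density of Normal(0, 1) at x: -0.5 * log(2 * pi) - x**2 / 2."""
    return -0.5 * math.log(2 * math.pi) - 0.5 * x * x


print(std_normal_log_prob(0.0))  # ≈ -0.9189385  (matches log_prob(f1)(0.))
print(std_normal_log_prob(1.0))  # ≈ -1.4189385  (matches log_prob(f1)(1.))
# Independent samples: log densities add (matches log_prob(f2)([0., 0.])).
print(std_normal_log_prob(0.0) + std_normal_log_prob(0.0))  # ≈ -1.837877
```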

For programs that sample variables that are not returned as part of the program's output (latent variables), log_prob raises an error, because the output alone does not carry enough information to compute the log probabilities of all the random samples in the program.

def f(key):
  k1, k2 = random.split(key)
  z = random_normal(k1)
  return random_normal(k2) + z
log_prob(f)(0.)  # ==> Error!

In this case, if the latent variables are tagged with random_variable, we can use joint_sample to transform the program into one that returns the values of all of them, and log_prob of the transformed program then computes the joint log-probability.

log_prob is also able to invert bijective functions and compute the change-of-variables formula for probabilistic programs. For more details, see oryx.core.interpreters.log_prob.

def f(key):
  return np.exp(random_normal(key))
log_prob(f)(np.exp(0.))  # ==> -0.9189385
log_prob(f)(np.exp(1.))  # ==> -2.4189386
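The change-of-variables computation for this example can be done by hand: since y = exp(x), the inverse is x = log(y) and the log-Jacobian term is log|dx/dy| = -log(y), so log p(y) = log Normal(log y; 0, 1) - log y. At y = e this gives -1.4189385 - 1 = -2.4189385, matching the second example up to float32 rounding. A plain-Python version of that arithmetic (not Oryx's inversion machinery):

```python
import math


def std_normal_log_prob(x):
    """Log density of Normal(0, 1) at x."""
    return -0.5 * math.log(2 * math.pi) - 0.5 * x * x


def exp_normal_log_prob(y):
    """Log density of Y = exp(X) with X ~ Normal(0, 1), for y > 0.

    Change of variables: x = log(y) inverts exp, and
    log |dx/dy| = log(1 / y) = -log(y).
    """
    x = math.log(y)
    return std_normal_log_prob(x) - math.log(y)


print(exp_normal_log_prob(math.exp(0.0)))  # ≈ -0.9189385
print(exp_normal_log_prob(math.exp(1.0)))  # ≈ -2.4189385
```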

joint_sample

joint_sample takes a probabilistic program and returns a new program that outputs a dictionary mapping the names of latent variables (tagged by random_variable) to the values they took during execution.

Example:

def f(key):
  k1, k2 = random.split(key)
  z = random_variable(random.normal, name='z')(k1)
  return z + random_variable(random.normal, name='x')(k2)
joint_sample(f)(random.PRNGKey(0))  # ==> {'z': -0.0495, 'x': 0.193}

joint_log_prob

joint_log_prob takes a probabilistic program and returns a function that computes the log probability of a dictionary mapping names to the values of the random variables sampled during the program's execution. It is the composition of log_prob and joint_sample, i.e. log_prob(joint_sample(f)).

Example:

def f(key):
  k1, k2 = random.split(key)
  z = random_variable(random.normal, name='z')(k1)
  return z + random_variable(random.normal, name='x')(k2)
joint_log_prob(f)({'z': 0., 'x': 0.})  # ==> -1.837877

intervene

intervene takes a probabilistic program and a dictionary mapping names to values of intervened random variables, and returns a new probabilistic program. The new program runs the original, but when sampling a tagged random variable whose name is present in the dictionary, it instead substitutes in the provided value.

def f1(key):
  return random_variable(random.normal, name='x')(key)
intervene(f1, x=1.)(random.PRNGKey(0))  # ==> 1.

def f2(key):
  k1, k2 = random.split(key)
  z = random_variable(random.normal, name='z')(k1)
  return z + random_variable(random.normal, name='x')(k2)
intervene(f2, z=1., x=1.)(random.PRNGKey(0))  # ==> 2.

conditional

conditional is similar to intervene, except that instead of a dictionary of values it takes a list of names, and it returns a conditional probabilistic program that takes one additional argument per name, supplying the value of the random variable with that name.

Example:

def f(key):
  k1, k2 = random.split(key)
  z = random_variable(random.normal, name='z')(k1)
  return z + random_variable(random.normal, name='x')(k2)
conditional(f, ['z'])(random.PRNGKey(0), 0.)  # ==> -1.25153887
conditional(f, ['z'])(random.PRNGKey(0), 1.)  # ==> -0.25153887
conditional(f, ['z', 'x'])(random.PRNGKey(0), 1., 2.)  # ==> 3.

graph_replace

graph_replace is a transformation that executes the original program but with new inputs and outputs specified by random variable names. Input names allow injecting values for random variables in the program, and the values of random variables corresponding to output names are returned.

Example:

def f(key):
  k1, k2, k3 = random.split(key, 3)
  z = random_variable(random_normal, name='z')(k1)
  x = random_variable(lambda key: random_normal(key) + z, name='x')(k2)
  y = random_variable(lambda key: random_normal(key) + x, name='y')(k3)
  return y
graph_replace(f, 'z', 'y') # returns a program p(y | z) with a latent variable x
graph_replace(f, 'z', 'x') # returns a program p(x | z)
graph_replace(f, 'x', 'y') # returns a program p(y | x)

Functions

conditional(...): Conditions a probabilistic program on random variables.

graph_replace(...): Transforms a program to one with new inputs and outputs.

intervene(...): Transforms a program into one where provided random variables are fixed.

joint_log_prob(...): Returns a function that computes the log probability of all of a program's random variables.

joint_sample(...): Returns a program that outputs a dictionary of latent random variable samples.

log_prob(...): Returns a function that computes the log probability of a sample.

nest(...): Wraps a function to create a new scope for harvested values.

random_variable(...): A single-dispatch function used to tag values and the outputs of programs.

rv(...): A single-dispatch function used to tag values and the outputs of programs.