Module for probabilistic programming transformations.
Probabilistic programs
A probabilistic program is defined as a JAX function that takes in a
jax.random.PRNGKey as its first input, and any number of subsequent
conditioning arguments. The output of the function is the output of the
probabilistic program.
A simple program:
from jax import random

def f(key):
  return random.normal(key)
In this program, we sample a random normal variable and return it. Conceptually,
this program represents the distribution p(x) = Normal(0, 1).
A conditional program:
def f(key, z):
  return z * random.normal(key)
In this program we sample from a distribution conditional on z (i.e. a
distribution p(x | z)).
Function transformations
The goal of the probabilistic programming package is to enable writing simple programs and to use program transformations to create complexity. Here we outline some of the transformations available in the module.
random_variable
random_variable is a general purpose function that can be used to 1) tag
values for use in downstream transforms and 2) convert objects into
probabilistic programs. In implementation, random_variable is a
single-dispatch function whose implementation for functions and objects is
already registered. By default, it will tag a value with a name and will only
work on JAX types (e.g. DeviceArrays and tracers). We also register an
implementation for function types, where it returns the original function but,
when provided a name, tags the output of the function. The registry enables
objects such as TensorFlow Probability distributions to register as
random-variable-like objects with Oryx.
Tagging a value in a probabilistic program as a random variable enables it to
be used by downstream transforms described below, such as joint_sample,
conditional, intervene, and graph_replace.
log_prob
log_prob takes a probabilistic program and returns a function that computes
the log probability of a sample. It relies on the fact that certain sampling
primitives have been registered with the transformation. Specifically, it
returns a program that, when provided an output from the program, attempts to
compute the log-probability of all random samples in the program.
Examples:
def f1(key):
  return random_normal(key)

log_prob(f1)(0.)  # ==> -0.9189385
log_prob(f1)(1.)  # ==> -1.4189385

def f2(key):
  k1, k2 = random.split(key)
  return [random_normal(k1), random_normal(k2)]

log_prob(f2)([0., 0.])  # ==> -1.837877
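The expected values in these examples can be checked by hand, because every sample involved is a standard normal draw. The following is a minimal sketch in plain Python that does not call Oryx at all; `std_normal_logpdf` is a hypothetical helper introduced only for this illustration.

```python
import math

def std_normal_logpdf(x):
  """Log-density of Normal(0, 1) at x (hypothetical helper, not an Oryx API)."""
  return -0.5 * x * x - 0.5 * math.log(2.0 * math.pi)

# f1 returns a single standard normal draw.
lp_at_0 = std_normal_logpdf(0.0)  # about -0.9189385
lp_at_1 = std_normal_logpdf(1.0)  # about -1.4189385

# f2 returns two independent draws, so their log-densities add.
lp_pair = std_normal_logpdf(0.0) + std_normal_logpdf(0.0)  # about -1.837877

print(lp_at_0, lp_at_1, lp_pair)
```

The pair value is simply twice the single-draw value at 0, which is what independence of the two samples implies.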
For programs that sample variables that aren't returned as part of the output of
the program (latent variables), the log_prob of the program will error,
because there is insufficient information from the output of the program to
compute the log probabilities of all the random samples in the program.
def f(key):
  k1, k2 = random.split(key)
  z = random_normal(k1)
  return random_normal(k2) + z

log_prob(f)(0.)  # ==> Error!
In this case, we can use joint_sample to transform the program into one that
returns values for all the latent variables, and log_prob will compute the
joint log-probability of the function.
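To see why the output alone is not enough, note that scoring only the sum would require integrating the latent z out of the joint density. The sketch below does that integration numerically in plain Python for this specific two-normal program and compares it with the known closed form Normal(0, sqrt(2)). It illustrates the missing marginalization step and is not a description of what Oryx does; `normal_logpdf` and `marginal_density` are hypothetical helpers.

```python
import math

def normal_logpdf(x, scale=1.0):
  """Log-density of Normal(0, scale) at x (hypothetical helper)."""
  return (-0.5 * (x / scale) ** 2 - math.log(scale)
          - 0.5 * math.log(2.0 * math.pi))

def marginal_density(out, n=4001, lo=-10.0, hi=10.0):
  """Integrates p(z) * p(out - z) over z with the trapezoidal rule."""
  step = (hi - lo) / (n - 1)
  total = 0.0
  for i in range(n):
    z = lo + i * step
    weight = 0.5 if i in (0, n - 1) else 1.0
    # Joint density of the latent z and the second draw (out - z).
    total += weight * math.exp(normal_logpdf(z) + normal_logpdf(out - z))
  return total * step

numeric = math.log(marginal_density(0.0))
closed_form = normal_logpdf(0.0, scale=math.sqrt(2.0))
print(numeric, closed_form)
```

For a general program no such closed form or cheap integral is available, which is why returning the latent values and scoring the joint is the practical route.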
log_prob is also able to invert bijective functions and compute the
change-of-variables formula for probabilistic programs. For more details,
see oryx.core.interpreters.log_prob.
def f(key):
  return np.exp(random_normal(key))

log_prob(f)(np.exp(0.))  # ==> -0.9189385
log_prob(f)(np.exp(1.))  # ==> -2.4189386
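The two values above follow from the change-of-variables formula: for y = exp(x) with x ~ Normal(0, 1), log p(y) = log Normal(log y; 0, 1) - log y. Here is a minimal sketch in plain Python that reproduces them without Oryx; the helper names are hypothetical.

```python
import math

def std_normal_logpdf(x):
  """Log-density of Normal(0, 1) at x (hypothetical helper)."""
  return -0.5 * x * x - 0.5 * math.log(2.0 * math.pi)

def exp_normal_logpdf(y):
  """Log-density of y = exp(x) with x ~ Normal(0, 1), for y > 0."""
  x = math.log(y)           # invert the bijection
  log_det_jacobian = x      # log |dy/dx| = log(exp(x)) = x
  return std_normal_logpdf(x) - log_det_jacobian

lp_a = exp_normal_logpdf(math.exp(0.0))  # about -0.9189385
lp_b = exp_normal_logpdf(math.exp(1.0))  # about -2.4189385
print(lp_a, lp_b)
```

The second value differs from the one quoted above only in the last digit, which is consistent with float32 rounding in the original output.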
trace
trace takes a probabilistic program and returns a program that, when executed,
returns both the original program's output and a dictionary that includes the
latent random variables sampled during the original program's execution.
Example:
def f(key):
  k1, k2 = random.split(key)
  z = random_variable(random.normal, name='z')(k1)
  return z + random_variable(random.normal, name='x')(k2)

trace(f)(random.PRNGKey(0))  # ==> (0.1435, {'z': -0.0495, 'x': 0.193})
joint_sample
joint_sample takes a probabilistic program and returns another one that
returns a dictionary mapping named latent variables (tagged by
random_variable) to their latent values during execution.
Example:
def f(key):
  k1, k2 = random.split(key)
  z = random_variable(random.normal, name='z')(k1)
  return z + random_variable(random.normal, name='x')(k2)

joint_sample(f)(random.PRNGKey(0))  # ==> {'z': -0.0495, 'x': 0.193}
joint_log_prob
joint_log_prob takes a probabilistic program and returns a function that
computes the log probability of a dictionary mapping names to values, where
the names correspond to random variables tagged during the program's
execution. It is the composition of log_prob and joint_sample.
Example:
def f(key):
  k1, k2 = random.split(key)
  z = random_variable(random.normal, name='z')(k1)
  return z + random_variable(random.normal, name='x')(k2)

joint_log_prob(f)({'z': 0., 'x': 0.})  # ==> -1.837877
block
block takes a probabilistic program and a sequence of string names and returns
the same program except that downstream transformations will ignore the provided
names.
Example:
def f(key):
  k1, k2 = random.split(key)
  z = random_variable(random.normal, name='z')(k1)
  return z + random_variable(random.normal, name='x')(k2)

joint_sample(block(f, names=['x']))(random.PRNGKey(0))  # ==> {'z': -0.0495}
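The relationship between tagging, trace, joint_sample, and block can be pictured with a toy recorder in plain Python. This is only a sketch of the observable behavior described above, under the assumption that a tagged draw is stored under its name unless that name is blocked. Oryx itself works through JAX function transformations, and `ToyTracer`, `toy_trace`, and the `blocked` argument are hypothetical names.

```python
import random

class ToyTracer:
  """Records named draws; a stand-in for tagging, not Oryx's mechanism."""

  def __init__(self, seed, blocked=()):
    self.rng = random.Random(seed)
    self.blocked = set(blocked)
    self.values = {}

  def rv(self, name):
    value = self.rng.gauss(0.0, 1.0)
    if name not in self.blocked:  # blocked names stay invisible downstream
      self.values[name] = value
    return value

def f(tracer):
  z = tracer.rv('z')
  return z + tracer.rv('x')

def toy_trace(program, seed, blocked=()):
  """Returns (output, recorded latents), mirroring the shape of trace's result."""
  tracer = ToyTracer(seed, blocked)
  out = program(tracer)
  return out, tracer.values

out, latents = toy_trace(f, seed=0)               # both 'z' and 'x' recorded
_, visible = toy_trace(f, seed=0, blocked=['x'])  # only 'z' recorded
print(out, latents, visible)
```

In this picture the joint_sample result corresponds to the recorded dictionary alone, and blocking a name changes what is recorded without changing the program's output.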
intervene
intervene takes a probabilistic program and a mapping from names to values of
intervened random variables (passed as keyword arguments in the examples
below), and returns a new probabilistic program. The new program runs the
original, but when sampling a tagged random variable whose name is present in
the mapping, it instead substitutes in the provided value.
def f1(key):
  return random_variable(random.normal, name='x')(key)

intervene(f1, x=1.)(random.PRNGKey(0))  # ==> 1.

def f2(key):
  k1, k2 = random.split(key)
  z = random_variable(random.normal, name='z')(k1)
  return z + random_variable(random.normal, name='x')(k2)

intervene(f2, z=1., x=1.)(random.PRNGKey(0))  # ==> 2.
conditional
conditional is similar to intervene, except instead of taking a dictionary
of observations, it takes a list of names and returns a conditional
probabilistic program. The returned program takes additional arguments, one per
listed name, that supply the values of the corresponding random variables.
Example:
def f(key):
  k1, k2 = random.split(key)
  z = random_variable(random.normal, name='z')(k1)
  return z + random_variable(random.normal, name='x')(k2)

conditional(f, ['z'])(random.PRNGKey(0), 0.)  # ==> -1.25153887
conditional(f, ['z'])(random.PRNGKey(0), 1.)  # ==> -0.25153887
conditional(f, ['z', 'x'])(random.PRNGKey(0), 1., 2.)  # ==> 3.
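The difference between intervene and conditional is mainly about when the values are supplied: up front as keyword arguments, or later as extra positional arguments. The toy sketch below mimics that difference in plain Python with an integer seed in place of a PRNG key. It is an illustration under those assumptions, not Oryx's implementation, and `toy_intervene` and `toy_conditional` are hypothetical names.

```python
import random

def f(seed, fixed):
  """Toy version of z + x; names present in `fixed` are not sampled."""
  rng = random.Random(seed)
  z_draw = rng.gauss(0.0, 1.0)
  x_draw = rng.gauss(0.0, 1.0)
  z = fixed.get('z', z_draw)
  x = fixed.get('x', x_draw)
  return z + x

def toy_intervene(program, **observations):
  """Fixes the named variables once, when the program is transformed."""
  return lambda seed: program(seed, observations)

def toy_conditional(program, names):
  """Defers the values: they become extra positional arguments."""
  def conditioned(seed, *values):
    return program(seed, dict(zip(names, values)))
  return conditioned

both_fixed = toy_intervene(f, z=1.0, x=1.0)(0)            # 2.0
given_z = toy_conditional(f, ['z'])
shift = given_z(0, 1.0) - given_z(0, 0.0)                 # close to 1.0
given_both = toy_conditional(f, ['z', 'x'])(0, 1.0, 2.0)  # 3.0
print(both_fixed, shift, given_both)
```

As in the example above, changing the conditioned value of z from 0 to 1 at a fixed seed moves the output by exactly that amount, because the draw for x is unchanged.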
graph_replace
graph_replace is a transformation that executes the original program but
with new inputs and outputs specified by random variable names. Input names
allow injecting values for random variables in the program, and the values of
random variables corresponding to output names are returned.
Example:
def f(key):
  k1, k2, k3 = random.split(key, 3)
  z = random_variable(random_normal, name='z')(k1)
  x = random_variable(lambda key: random_normal(key) + z, name='x')(k2)
  y = random_variable(lambda key: random_normal(key) + x, name='y')(k3)
  return y

graph_replace(f, 'z', 'y')  # returns a program p(y | z) with a latent variable x
graph_replace(f, 'z', 'x')  # returns a program p(x | z)
graph_replace(f, 'x', 'y')  # returns a program p(y | x)
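The three programs above can be pictured on the dependency chain z -> x -> y: an input name overrides one node, and an output name selects which node's value is returned. The following plain-Python sketch mimics that behavior for this specific chain. The calling convention of the returned program (a seed followed by the injected value) is an assumption made for the illustration, and `toy_chain` and `toy_graph_replace` are hypothetical names rather than Oryx APIs.

```python
import random

def toy_chain(seed, injected):
  """Evaluates z -> x -> y, using injected values where provided."""
  rng = random.Random(seed)
  # One independent noise draw per variable, like one key per variable.
  noise = {name: rng.gauss(0.0, 1.0) for name in ('z', 'x', 'y')}
  values = {}
  values['z'] = injected.get('z', noise['z'])
  values['x'] = injected.get('x', noise['x'] + values['z'])
  values['y'] = injected.get('y', noise['y'] + values['x'])
  return values

def toy_graph_replace(input_name, output_name):
  """Builds a program from an injected input node to a selected output node."""
  def program(seed, input_value):
    return toy_chain(seed, {input_name: input_value})[output_name]
  return program

y_given_z = toy_graph_replace('z', 'y')  # x stays latent
x_given_z = toy_graph_replace('z', 'x')
y_given_x = toy_graph_replace('x', 'y')  # z no longer influences y

print(y_given_z(0, 1.0), x_given_z(0, 1.0), y_given_x(0, 1.0))
```

Injecting x cuts the path from z to y, which is why the third program represents p(y | x) regardless of what z would have been.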
Functions
block(...): Returns a program that removes the provided names from transformations.
conditional(...): Conditions a probabilistic program on random variables.
graph_replace(...): Transforms a program to one with new inputs and outputs.
intervene(...): Transforms a program into one where provided random variables are fixed.
joint_log_prob(...): Returns a function that computes the log probability of all of a program's random variables.
joint_sample(...): Returns a program that outputs a dictionary of latent random variable samples.
log_prob(...): Returns a function that computes the log probability of a sample.
nest(...): Wraps a function to create a new scope for harvested values.
random_variable(...): A single-dispatch function used to tag values and the outputs of programs.
rv(...): A single-dispatch function used to tag values and the outputs of programs.