# Module: oryx.core.ppl.transformations

Module for probabilistic programming transformations.

## Probabilistic programs

A probabilistic program is defined as a JAX function that takes a `jax.random.PRNGKey` as its first input, followed by any number of conditioning arguments. The output of the function is the output of the probabilistic program.

### A simple program:

```python
from jax import random

def f(key):
  return random.normal(key)
```

In this program, we sample a random normal variable and return it. Conceptually, this program represents the distribution `p(x) = Normal(0, 1)`.

### A conditional program:

```python
from jax import random

def f(key, z):
  return z * random.normal(key)
```

In this program, we sample from a distribution conditional on `z` (i.e., a distribution `p(x | z)`).

## Function transformations

The goal of the probabilistic programming package is to let users write simple programs and build complexity through program transformations. Here we outline some of the transformations available in the module.

### `random_variable`

`random_variable` is a general-purpose function that can be used to 1) tag values for use in downstream transforms and 2) convert objects into probabilistic programs. It is implemented as a single-dispatch function, with implementations for plain values and for functions already registered. By default, it tags a value with a name and works only on JAX types (e.g., DeviceArrays and tracers). The implementation for function types returns the original function but, when a name is provided, tags the function's output. The registry also lets objects such as TensorFlow Probability distributions register as random-variable-like with Oryx.

Tagging a value in a probabilistic program as a random variable enables it to be used by downstream transforms described below, such as `joint_sample`, `conditional`, `intervene`, and `graph_replace`.
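
To illustrate the single-dispatch design described above, here is a minimal plain-Python sketch built on `functools.singledispatch`. Everything in it (`toy_random_variable`, the `TAGGED` dict, `ToyNormal`) is hypothetical and only mirrors the three cases the text names: plain values, functions, and a registered distribution-like object. It is not Oryx's implementation, which tags values for JAX transformations rather than writing them to a dict, and it uses Python's `random.Random` in place of a JAX key.

```python
import random
import types
from functools import singledispatch

# Hypothetical store for tagged values; a stand-in for Oryx's real tagging machinery.
TAGGED = {}


@singledispatch
def toy_random_variable(obj, *, name=None):
  """Default implementation: tag a plain value under `name` and return it."""
  if name is not None:
    TAGGED[name] = obj
  return obj


@toy_random_variable.register(types.FunctionType)
def _(fn, *, name=None):
  """Function implementation: return `fn` itself, or a wrapper that tags its output."""
  if name is None:
    return fn

  def tagged_fn(*args, **kwargs):
    return toy_random_variable(fn(*args, **kwargs), name=name)

  return tagged_fn


class ToyNormal:
  """Hypothetical distribution-like object, standing in for a TFP distribution."""

  def __init__(self, loc, scale):
    self.loc = loc
    self.scale = scale

  def sample(self, rng):
    return rng.gauss(self.loc, self.scale)


@toy_random_variable.register(ToyNormal)
def _(dist, *, name=None):
  """Distribution implementation: convert the object into a sampling program."""
  # Wrap in a lambda so the result dispatches to the function implementation.
  return toy_random_variable(lambda rng: dist.sample(rng), name=name)


x = toy_random_variable(ToyNormal(0., 1.), name='x')(random.Random(0))
print(TAGGED['x'] == x)  # True
```

Registering `ToyNormal` is the toy analogue of a library object opting in to random-variable behavior: the dispatcher converts it into a program, and the function case then handles the tagging.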

### `log_prob`

`log_prob` takes a probabilistic program and returns a function that computes the log probability of a sample. It relies on certain sampling primitives having been registered with the transformation. Specifically, the returned function takes an output of the program and attempts to compute the log-probability of all random samples in the program.

#### Examples:

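```python
from jax import random
from oryx.core.ppl.transformations import log_prob

# `random_normal` is assumed to be a standard-normal sampler whose log
# density is registered with `log_prob`; it is not defined in this module.

def f1(key):
  return random_normal(key)
log_prob(f1)(0.)  # ==> -0.9189385
log_prob(f1)(1.)  # ==> -1.4189385

def f2(key):
  k1, k2 = random.split(key)
  return [random_normal(k1), random_normal(k2)]
log_prob(f2)([0., 0.])  # ==> -1.837877
```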
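
The expected values in the examples above follow from the standard normal log density, `log p(x) = -x**2 / 2 - log(2 * pi) / 2`. The plain-Python check below involves no Oryx code; it simply reproduces the numbers. For `f2`, the two draws are independent, so their log probabilities add.

```python
import math


def std_normal_log_prob(x):
  """Log density of a standard normal: -x**2 / 2 - log(2 * pi) / 2."""
  return -0.5 * x * x - 0.5 * math.log(2 * math.pi)


print(std_normal_log_prob(0.))  # ≈ -0.9189385
print(std_normal_log_prob(1.))  # ≈ -1.4189385
# Two independent draws: the joint log probability is the sum.
print(std_normal_log_prob(0.) + std_normal_log_prob(0.))  # ≈ -1.837877
```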

For programs that sample variables that are not returned as part of the program's output (latent variables), `log_prob` raises an error, because the output alone does not contain enough information to compute the log probabilities of all the random samples in the program.

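```python
from jax import random
from oryx.core.ppl.transformations import log_prob

def f(key):
  k1, k2 = random.split(key)
  z = random_normal(k1)  # latent: not part of the output
  return random_normal(k2) + z
log_prob(f)(0.)  # ==> Error!
```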

In this case, we can use `joint_sample` to transform the program into one that returns values for all the latent variables; `log_prob` then computes the joint log-probability of the function.

`log_prob` is also able to invert bijective functions and compute the change-of-variables formula for probabilistic programs. For more details, see `oryx.core.interpreters.log_prob`.

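```python
import jax.numpy as np
from oryx.core.ppl.transformations import log_prob

def f(key):
  return np.exp(random_normal(key))  # `random_normal` as assumed above
log_prob(f)(np.exp(0.))  # ==> -0.9189385
log_prob(f)(np.exp(1.))  # ==> -2.4189386
```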
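
As a hand check of the change-of-variables numbers above: for `y = exp(x)` with `x ~ Normal(0, 1)`, inverting gives `x = log(y)`, and `log p(y) = log p_x(log y) + log|dx/dy| = log p_x(log y) - log(y)`. The plain-Python sketch below writes that formula out by hand; Oryx derives the inverse from the program automatically, which this sketch does not attempt.

```python
import math


def std_normal_log_prob(x):
  return -0.5 * x * x - 0.5 * math.log(2 * math.pi)


def exp_normal_log_prob(y):
  """Log density of y = exp(x) with x ~ Normal(0, 1), via change of variables."""
  x = math.log(y)  # invert the bijection y = exp(x)
  log_det_jacobian = -math.log(y)  # log |dx/dy| = log(1 / y)
  return std_normal_log_prob(x) + log_det_jacobian


print(exp_normal_log_prob(math.exp(0.)))  # ≈ -0.9189385
print(exp_normal_log_prob(math.exp(1.)))  # ≈ -2.41894
```

The small difference from `-2.4189386` in the Oryx example is float32 rounding; Python floats are float64.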

### `joint_sample`

`joint_sample` takes a probabilistic program and returns a new program that outputs a dictionary mapping the names of latent variables (tagged by `random_variable`) to the values they took during execution.

#### Example:

```python
from jax import random
from oryx.core.ppl.transformations import joint_sample, random_variable

def f(key):
  k1, k2 = random.split(key)
  z = random_variable(random.normal, name='z')(k1)
  return z + random_variable(random.normal, name='x')(k2)
joint_sample(f)(random.PRNGKey(0))  # ==> {'z': -0.0495, 'x': 0.193}
```

### `joint_log_prob`

`joint_log_prob` takes a probabilistic program and returns a function that computes the log probability of a dictionary mapping names to the values of the corresponding random variables during the program's execution. It is the composition of `log_prob` and `joint_sample`.

#### Example:

```python
from jax import random
from oryx.core.ppl.transformations import joint_log_prob, random_variable

def f(key):
  k1, k2 = random.split(key)
  z = random_variable(random.normal, name='z')(k1)
  return z + random_variable(random.normal, name='x')(k2)
joint_log_prob(f)({'z': 0., 'x': 0.})  # ==> -1.837877
```
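
Because `joint_log_prob(f)` is `log_prob` applied to `joint_sample(f)`, its result for the example above is the sum of the log densities of the tagged draws `z` and `x`, not the log density of the program's output `z + x`. A hand-written plain-Python equivalent for this particular program, useful as a sanity check (the helper names are illustrative):

```python
import math


def normal_log_prob(v, loc=0., scale=1.):
  """Log density of Normal(loc, scale) at v."""
  return (-0.5 * ((v - loc) / scale) ** 2
          - math.log(scale) - 0.5 * math.log(2 * math.pi))


def f_joint_log_prob(sample):
  """Hand-written joint log density for the program above.

  'z' and 'x' are independent Normal(0, 1) draws, so their log densities add.
  If 'x' were drawn around 'z', the second term would use loc=sample['z'].
  """
  return normal_log_prob(sample['z']) + normal_log_prob(sample['x'])


print(f_joint_log_prob({'z': 0., 'x': 0.}))  # ≈ -1.837877
```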

### `block`

`block` takes a probabilistic program and a sequence of string names and returns the same program except that downstream transformations will ignore the provided names.

#### Example:

```python
from jax import random
from oryx.core.ppl.transformations import block, joint_sample, random_variable

def f(key):
  k1, k2 = random.split(key)
  z = random_variable(random.normal, name='z')(k1)
  return z + random_variable(random.normal, name='x')(k2)
joint_sample(block(f, names=['x']))(random.PRNGKey(0))  # ==> {'z': -0.0495}
```

### `intervene`

`intervene` takes a probabilistic program and a dictionary mapping names to values of intervened random variables, and returns a new probabilistic program. The new program runs the original, but when sampling a tagged random variable whose name is present in the dictionary, it instead substitutes in the provided value.

```python
from jax import random
from oryx.core.ppl.transformations import intervene, random_variable

def f1(key):
  return random_variable(random.normal, name='x')(key)
intervene(f1, x=1.)(random.PRNGKey(0))  # => 1.

def f2(key):
  k1, k2 = random.split(key)
  z = random_variable(random.normal, name='z')(k1)
  return z + random_variable(random.normal, name='x')(k2)
intervene(f2, z=1., x=1.)(random.PRNGKey(0))  # => 2.
```

### `conditional`

`conditional` is similar to `intervene`, except that instead of taking a dictionary of observed values, it takes a list of names and returns a conditional probabilistic program that takes additional arguments, one for each of the named random variables.

#### Example:

```python
from jax import random
from oryx.core.ppl.transformations import conditional, random_variable

def f(key):
  k1, k2 = random.split(key)
  z = random_variable(random.normal, name='z')(k1)
  return z + random_variable(random.normal, name='x')(k2)
conditional(f, ['z'])(random.PRNGKey(0), 0.)  # => -1.25153887
conditional(f, ['z'])(random.PRNGKey(0), 1.)  # => -0.25153887
conditional(f, ['z', 'x'])(random.PRNGKey(0), 1., 2.)  # => 3.
```

### `graph_replace`

`graph_replace` is a transformation that executes the original program but with new inputs and outputs specified by random variable names. Input names allow injecting values for random variables in the program, and the values of random variables corresponding to output names are returned.

#### Example:

```python
from jax import random
from oryx.core.ppl.transformations import graph_replace, random_variable

# `random_normal` is assumed to be a standard-normal sampler (not defined here).
def f(key):
  k1, k2, k3 = random.split(key, 3)
  z = random_variable(random_normal, name='z')(k1)
  x = random_variable(lambda key: random_normal(key) + z, name='x')(k2)
  y = random_variable(lambda key: random_normal(key) + x, name='y')(k3)
  return y
graph_replace(f, 'z', 'y')  # returns a program p(y | z) with a latent variable x
graph_replace(f, 'z', 'x')  # returns a program p(x | z)
graph_replace(f, 'x', 'y')  # returns a program p(y | x)
```

## Functions

`block(...)`: Returns a program that removes the provided names from transformations.

`conditional(...)`: Conditions a probabilistic program on random variables.

`graph_replace(...)`: Transforms a program to one with new inputs and outputs.

`intervene(...)`: Transforms a program into one where provided random variables are fixed.

`joint_log_prob(...)`: Returns a function that computes the log probability of all of a program's random variables.

`joint_sample(...)`: Returns a program that outputs a dictionary of latent random variable samples.

`log_prob(...)`: Returns a function that computes the log probability of a sample.

`nest(...)`: Wraps a function to create a new scope for harvested values.

`random_variable(...)`: A single-dispatch function used to tag values and the outputs of programs.

`rv(...)`: A single-dispatch function used to tag values and the outputs of programs.
