Module: oryx.core.ppl.effect_handler

Enables writing custom effect handlers for probabilistic programs.

Background

Oryx's PPL system comes with built-in transformations, such as joint_sample and log_prob, that enable manipulating probabilistic programs. However, the built-in transformations do not support many kinds of program manipulation and transformation.

Consider, for example, a noncentering program transformation. It takes each location-scale random variable in a program and rewrites it to sample from the standardized member of its family (location 0, scale 1), then shifts and scales the resulting sample by the original location and scale. This transformation can be important for well-conditioned MCMC and VI, but it cannot be expressed as a combination of the built-in transformations: joint_sample and log_prob operate on tagged values in a program but cannot change what actually happens at a particular sample site.
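For a normal random variable, noncentering rests on a change of variables: sampling x ~ Normal(loc, scale) is equivalent to sampling z ~ Normal(0, 1) and setting x = loc + scale * z, with log p(x) = log p(z) - log(scale). The sketch below checks this identity numerically with `jax.scipy.stats.norm`. The helper names `centered_log_prob` and `noncentered` are hypothetical; they only illustrate what a noncentering transformation would rewrite at a sample site, not how Oryx implements one.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm

# Hypothetical helpers illustrating the reparameterization a noncentering
# transformation would perform for a normal random variable.

def centered_log_prob(x, loc, scale):
  # Centered: x itself is the sampled quantity, scored under Normal(loc, scale).
  return norm.logpdf(x, loc, scale)

def noncentered(z, loc, scale):
  # Noncentered: z ~ Normal(0, 1) is sampled; x is a deterministic shift and scale.
  x = loc + scale * z
  return x, norm.logpdf(z, 0.0, 1.0)

loc, scale, z = 3.0, 2.0, 0.7
x, log_prob_z = noncentered(z, loc, scale)
lhs = centered_log_prob(x, loc, scale)
rhs = log_prob_z - jnp.log(scale)  # change of variables: dx = scale * dz
print(lhs, rhs)  # equal up to floating-point error
```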

Custom interpreters for effect handling

A custom interpreter for a probabilistic program is the most general-purpose tool for building a transformation. A custom interpreter first traces a program into a JAXpr and then executes the JAXpr with modified execution rules. For example, a (silly) custom interpreter could execute a program exactly as written but add 1 to the result of every call to exp. A more complicated custom interpreter could apply a noncentering transformation to each location-scale family random variable.
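The (silly) add-1-to-exp interpreter can be sketched directly on top of `jax.make_jaxpr`: trace the function, then walk the equations, binding each primitive on concrete values and applying a modified rule for `exp`. This is a minimal sketch under stated assumptions: it handles only a single flat jaxpr (no control flow or nested jaxprs such as `cond`, `scan`, or `pjit`), and `add_one_to_exp` is a hypothetical name, not part of Oryx.

```python
import jax
import jax.numpy as jnp
from jax import lax

def add_one_to_exp(f):
  """Hypothetical custom interpreter: run `f` as written, but add 1 to every exp."""
  def wrapped(*args):
    args = [jnp.asarray(a) for a in args]
    closed = jax.make_jaxpr(f)(*args)  # trace the program into a jaxpr
    jaxpr = closed.jaxpr
    env = dict(zip(jaxpr.constvars, closed.consts))
    env.update(zip(jaxpr.invars, args))

    def read(v):
      # Literals (inline constants such as 2.0) carry their value in `.val`.
      return v.val if hasattr(v, "val") else env[v]

    for eqn in jaxpr.eqns:
      invals = [read(v) for v in eqn.invars]
      # Default execution rule: apply the primitive to concrete values.
      outvals = eqn.primitive.bind(*invals, **eqn.params)
      if eqn.primitive is lax.exp_p:
        outvals = outvals + 1.0  # the modified execution rule
      if not eqn.primitive.multiple_results:
        outvals = [outvals]
      env.update(zip(eqn.outvars, outvals))
    return [read(v) for v in jaxpr.outvars]
  return wrapped

def f(x):
  return lax.exp(x) * 2.0

print(add_one_to_exp(f)(2.0))  # [(exp(2) + 1) * 2], about 16.778
```

A real interpreter would also need rules for higher-order primitives that carry sub-jaxprs; skipping them keeps the core loop visible.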

Writing a custom interpreter from scratch for each possible transformation, however, can be tedious. Instead, this module provides a simpler API to a more restrictive, but still useful, set of custom interpreters. In particular, we're interested in intercepting certain operations (effects) and handling them, either by inserting new operations in their place (e.g. intercepting all normal sampling operations and replacing them with their noncentered versions) or by updating a running state (e.g. accumulating log_probs at each sample site). This technique is called effect handling and is used by libraries such as Edward2 and Pyro.

Usage

The main function is make_effect_handler, which takes a dictionary, handlers, mapping JAX primitives to handler functions. Each handler function takes the handler state followed by the usual arguments to its JAX primitive. It returns a pair: a new value, which replaces the result of the regular call to the primitive, and an updated state, which is passed to later handlers. When no handler is provided for a primitive, the effect handler by default executes the primitive normally and leaves the state unchanged.

The result of make_effect_handler is a function transformation. Applied to a function f, it returns a transformed function that takes an initial state as its first argument, followed by f's usual arguments, and returns f's output together with the final state.
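To make the handler and state-threading semantics concrete, here is a simplified, hypothetical re-implementation, `toy_make_effect_handler`, written as a jaxpr interpreter. It is a sketch of the behavior described above, not Oryx's code. It assumes handlers take `(state, *args)` and return `(value, new_state)`, does not pass primitive parameters to handlers, and supports only flat jaxprs without control flow.

```python
import jax
import jax.numpy as jnp
from jax import lax

def toy_make_effect_handler(handlers):
  """Hypothetical stand-in for make_effect_handler (illustrative only)."""
  def transform(f):
    def wrapped(state, *args):
      args = [jnp.asarray(a) for a in args]
      closed = jax.make_jaxpr(f)(*args)
      jaxpr = closed.jaxpr
      env = dict(zip(jaxpr.constvars, closed.consts))
      env.update(zip(jaxpr.invars, args))

      def read(v):
        # Literals carry their value in `.val`; variables are looked up in env.
        return v.val if hasattr(v, "val") else env[v]

      for eqn in jaxpr.eqns:
        invals = [read(v) for v in eqn.invars]
        if eqn.primitive in handlers:
          # Handled effect: the handler supplies the value and the next state.
          outvals, state = handlers[eqn.primitive](state, *invals)
        else:
          # Default: execute the primitive normally; state is unchanged.
          outvals = eqn.primitive.bind(*invals, **eqn.params)
        if not eqn.primitive.multiple_results:
          outvals = [outvals]
        env.update(zip(eqn.outvars, outvals))
      outs = [read(v) for v in jaxpr.outvars]
      return (outs[0] if len(outs) == 1 else outs), state
    return wrapped
  return transform

def add_rule(count, x, y):
  return x + y, count + 1

def f(x):
  return x + 1. + 2.

value, count = toy_make_effect_handler({lax.add_p: add_rule})(f)(0, 2.)
print(value, count)  # 5.0 2
```

Unhandled primitives fall through to `eqn.primitive.bind`, which mirrors the default behavior of executing the primitive normally and returning the state unchanged.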

Example: add 1 to exp

This handler adds 1 to the output of each call to exp.

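import jax.numpy as jnp
from jax import lax
from oryx.core.ppl.effect_handler import make_effect_handler

def exp_rule(state, value):
  # Replace exp(value) with exp(value) + 1. and leave the state unchanged.
  return jnp.exp(value) + 1., state
add_one_exp = make_effect_handler({lax.exp_p: exp_rule})

# We can transform functions with `add_one_exp`.
def f(x):
  return jnp.exp(x) * 2.
add_one_exp(f)(None, 2.) # Executes (exp(x) + 1.) * 2.; the state stays None.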

Example: count number of adds

This handler counts the number of times add is called.

from jax import lax
from oryx.core.ppl.effect_handler import make_effect_handler

def add_rule(count, x, y):
  # Compute the sum as usual and increment the count.
  return x + y, count + 1
count_adds = make_effect_handler({lax.add_p: add_rule})

# We can transform functions with `count_adds`. The first argument to
# `count_adds` is the input state, and it returns the final state as the
# second output.
def f(x):
  return x + 1. + 2.
count_adds(f)(0, 2.) # Returns (5., 2): two adds, so the count goes from 0 to 2.

Functions

make_effect_handler(...): Returns a function transformation that applies a provided set of handlers.