Enables writing custom effect handlers for probabilistic programs.
Background
Oryx's PPL system comes with built-in transformations such as joint_sample and
log_prob, which enable manipulating probabilistic programs. However, the
built-in transformations cannot express many kinds of program manipulation.
Consider, for example, a noncentering program transformation, which takes each
location-scale random variable in a program, samples it from the family's
standard member (location 0, scale 1), and then shifts and scales the resulting
sample. This transformation can be important for well-conditioned MCMC and VI,
but it cannot be expressed as a combination of the built-in ones:
transformations like joint_sample and log_prob operate using tagged values in a
program but cannot actually adjust what happens at a particular sample site.
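As a minimal sketch of what noncentering does at a single sample site, the following plain-Python example (standard library only, not Oryx or JAX) compares a centered normal draw with its noncentered equivalent. The function names are hypothetical and chosen only for illustration.

```python
import random

def sample_centered(rng, loc, scale):
  # Centered: draw directly from Normal(loc, scale).
  return rng.gauss(loc, scale)

def sample_noncentered(rng, loc, scale):
  # Noncentered: draw from Normal(0, 1), then shift and scale the sample.
  z = rng.gauss(0.0, 1.0)
  return loc + scale * z

# Seed both generators identically so the two draws can be compared.
centered = sample_centered(random.Random(0), 3.0, 2.0)
noncentered = sample_noncentered(random.Random(0), 3.0, 2.0)
```

Both functions define the same distribution. The difference is that in the noncentered form the random draw itself no longer depends on loc and scale, which is the property that can help MCMC and VI.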
Custom interpreters for effect handling
A custom interpreter for a probabilistic program is the most general purpose
tool for building a transformation. A custom interpreter first traces a program
into a JAXpr and then executes the JAXpr with modified execution rules -- for
example, a (silly) custom interpreter could execute a program exactly as written
but could add 1 to the result of every call to exp. A more complicated
custom interpreter could apply a noncentering transformation to each
location-scale family random variable.
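The idea of executing a traced program under modified rules can be sketched without JAX. In the toy example below, a list of equations stands in for a JAXpr; it is an assumption made for illustration, not the real JAXpr data structure, and PROGRAM, DEFAULT_RULES and interpret are hypothetical names.

```python
import math

# Toy stand-in for a traced program: each equation is
# (primitive name, input variable names, output variable name).
# It computes exp(x) * two.
PROGRAM = [
  ("exp", ("x",), "a"),
  ("mul", ("a", "two"), "out"),
]

# How each primitive is executed when no custom rule is given.
DEFAULT_RULES = {
  "exp": math.exp,
  "mul": lambda a, b: a * b,
}

def interpret(program, inputs, custom_rules=None):
  """Runs the equations in order, preferring custom rules over defaults."""
  rules = {**DEFAULT_RULES, **(custom_rules or {})}
  env = dict(inputs)
  for prim, in_names, out_name in program:
    args = [env[name] for name in in_names]
    env[out_name] = rules[prim](*args)
  return env["out"]

inputs = {"x": 0.0, "two": 2.0}
# Executes the program exactly as written: exp(0) * 2.
plain = interpret(PROGRAM, inputs)
# The silly interpreter: add 1 to the result of every exp, so (exp(0) + 1) * 2.
silly = interpret(PROGRAM, inputs, {"exp": lambda v: math.exp(v) + 1.0})
```

The program itself is never edited; only the rule used to execute one primitive changes.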
Writing a custom interpreter from scratch for each possible transformation,
however, can be tedious. Instead, in this module, we provide a simpler API to a
more restrictive, but still useful, set of custom interpreters. In particular,
we're interested in intercepting certain operations (effects) and
handling them by perhaps inserting new ones in their place (e.g. we could
intercept all normal sampling operations and replace them with their noncentered
versions) or updating a running state (e.g. accumulating log_prob values at
each sample site). This technique is called effect handling and is used by
libraries such as Edward2 and
Pyro.
Usage
The main function is make_effect_handler, which takes a dictionary, handlers,
that maps JAX primitives to handler functions. A handler function takes in the
handler state and the usual arguments to the JAX primitive. It returns a new
value (intercepting the regular call to the JAX primitive) and an updated
state, which is passed into later handlers. When no handler is provided for a
primitive, the effect handler's default behavior is to execute the primitive
normally and leave the state unchanged.
The result of make_effect_handler is a function transformation that, when
applied to a function f, returns a transformed function that takes an initial
state as its first argument and returns the final state as an additional
output.
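The handler contract described above (state in, value and updated state out, with a default for unhandled primitives) can be mimicked in plain Python. This is a sketch under stated assumptions, not Oryx's implementation: the program is a toy list of equations rather than a JAXpr, and toy_make_effect_handler, PROGRAM and DEFAULT_RULES are hypothetical names.

```python
# Toy stand-in for a traced program: (primitive name, input names, output name).
# It computes (x + one) * two + one.
PROGRAM = [
  ("add", ("x", "one"), "t"),
  ("mul", ("t", "two"), "u"),
  ("add", ("u", "one"), "out"),
]

DEFAULT_RULES = {
  "add": lambda a, b: a + b,
  "mul": lambda a, b: a * b,
}

def toy_make_effect_handler(handlers):
  """Returns a transformation that threads a state through `handlers`."""
  def transform(program):
    def wrapped(state, **inputs):
      env = dict(inputs)
      for prim, in_names, out_name in program:
        args = [env[name] for name in in_names]
        if prim in handlers:
          # A handler replaces the primitive and may update the state.
          env[out_name], state = handlers[prim](state, *args)
        else:
          # Default: run the primitive normally and leave the state unchanged.
          env[out_name] = DEFAULT_RULES[prim](*args)
      return env["out"], state
    return wrapped
  return transform

def add_rule(count, x, y):
  # Perform the add as usual and increment the running count.
  return x + y, count + 1

count_adds = toy_make_effect_handler({"add": add_rule})
result, final_count = count_adds(PROGRAM)(0, x=2.0, one=1.0, two=2.0)
```

Here mul has no handler, so it runs normally and the count passes through it untouched; only the two add equations update the state.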
Example: add 1 to exp
This handler adds 1 to the output of each call to exp.
import jax.numpy as jnp
from jax import lax

def exp_rule(state, value):
  # Leave state unchanged.
  return jnp.exp(value) + 1., state

add_one_exp = make_effect_handler({lax.exp_p: exp_rule})

# We can transform functions with `add_one_exp`.
def f(x):
  return jnp.exp(x) * 2.

add_one_exp(f)(None, 2.)  # Executes (exp(x) + 1.) * 2.
Example: count the number of adds
This handler counts the number of times add is called.
from jax import lax

def add_rule(count, x, y):
  # Perform the add as usual and increment the count.
  return x + y, count + 1

count_adds = make_effect_handler({lax.add_p: add_rule})

# We can transform functions with `count_adds`. The first argument to the
# transformed function is the input state, and the final state is returned as
# the second output.
def f(x):
  return x + 1. + 2.

count_adds(f)(0, 2.)  # Returns (5., 2)
Functions
make_effect_handler(...): Returns a function transformation that applies a provided set of handlers.