Contains utilities for the plate transformation.
A plate is a term in graphical models that designates independent random variables. In Oryx, plate is a transformation that converts a program into one that produces independent samples. Ordinarily, this can be done with jax.vmap, by splitting several random keys and mapping a program over them. Unlike jax.vmap, plate operates using named axes. A plated program specializes the random seed to the particular index of the axis being mapped over, and taking the log_prob of a plated program reduces over the named axis. In design, plate resembles the Sharded meta-distribution from TensorFlow Probability.
In implementation, plate is an Oryx HigherOrderPrimitive (i.e., a JAX CallPrimitive) with a log_prob rule that reduces over a named axis at the end.
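As a rough illustration of the difference described above, the following pure-JAX sketch contrasts the two approaches. It does not use Oryx and is not Oryx's implementation: the axis name "plate", the helper functions, the use of jax.random.fold_in with jax.lax.axis_index to specialize the key, and jax.lax.psum to reduce the log density are all assumptions chosen to mimic the behavior in plain JAX.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

AXIS = "plate"  # hypothetical axis name chosen for this sketch


def sample_one(key):
    # A tiny "probabilistic program": one standard-normal draw.
    return jax.random.normal(key)


def vmap_style(key, n):
    # Plain jax.vmap: split n keys up front and map the program over them.
    keys = jax.random.split(key, n)
    return jax.vmap(sample_one)(keys)


def plated_sample(key):
    # Every mapped instance sees the SAME key; it is specialized to this
    # instance's index on the named axis via fold_in (assumed mechanism).
    idx = jax.lax.axis_index(AXIS)
    return sample_one(jax.random.fold_in(key, idx))


def plated_log_prob(x):
    # Per-instance log density, then a sum (reduction) over the named axis.
    return jax.lax.psum(norm.logpdf(x), AXIS)


def plate_style(key, n):
    # in_axes=None broadcasts the single key; axis_size sets the instance count.
    samples = jax.vmap(plated_sample, in_axes=None, axis_size=n,
                       axis_name=AXIS)(key)
    # psum leaves the same total in every slot, so take the first one.
    total_lp = jax.vmap(plated_log_prob, axis_name=AXIS)(samples)[0]
    return samples, total_lp


key = jax.random.PRNGKey(0)
samples, lp = plate_style(key, 4)
```

With named axes, the mapped function itself can ask which instance it is (axis_index) and aggregate across instances (psum), so neither the key splitting nor the final sum has to be written outside the map.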
Functions
make_plate(...): Wraps a probabilistic program in a plate with a named axis.