Contains utilities for the plate transformation.
A plate is a term in graphical models that is used to designate independent
random variables. In Oryx, plate is a transformation that converts a program
into one that produces independent samples. Ordinarily, this can be done with
jax.vmap by splitting a random key into several keys and mapping the program
over them. Unlike jax.vmap, plate operates using named axes. A plated
program specializes the random seed to the particular index of the axis
being mapped over, and taking the log_prob of a plated program reduces over
the named axis. In design, plate resembles the Sharded meta-distribution
from TensorFlow Probability.
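The contrast between jax.vmap over split keys and a named-axis mapping can be sketched in plain JAX. This is only an illustration of the idea, not Oryx's implementation: the axis name `"plate_axis"` and the helpers `sample` and `plated_sample` are hypothetical, and seed specialization is modeled here with `jax.random.fold_in` on the index returned by `jax.lax.axis_index`.

```python
import jax

def sample(key):
    # A toy probabilistic program: one standard-normal draw.
    return jax.random.normal(key)

key = jax.random.PRNGKey(0)

# Baseline with jax.vmap: split the key explicitly and map over the keys.
vmap_samples = jax.vmap(sample)(jax.random.split(key, 4))

# Named-axis version (hypothetical axis "plate_axis"): every index receives
# the same key and specializes it with its own position along the axis.
def plated_sample(key):
    idx = jax.lax.axis_index("plate_axis")
    return sample(jax.random.fold_in(key, idx))

# in_axes=None broadcasts the single key; axis_size sets the number of samples.
plated_samples = jax.vmap(
    plated_sample, in_axes=None, axis_size=4, axis_name="plate_axis"
)(key)
```

In the named-axis version the caller passes one key, and each index derives its own seed from its position, so the samples are independent without the caller splitting keys by hand.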
In implementation, plate is an Oryx HigherOrderPrimitive (i.e. a JAX
CallPrimitive) with a log_prob rule that reduces over a named axis at the
end.
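The effect of a log_prob rule that reduces over a named axis can be imitated in plain JAX with `jax.lax.psum`. This is a minimal sketch under stated assumptions, not Oryx's actual rule: the axis name `"plate_axis"` and the function `log_prob` are hypothetical, and the per-index density is assumed to be a standard normal.

```python
import jax
from jax.scipy.stats import norm

AXIS = "plate_axis"  # hypothetical axis name

def log_prob(x):
    # Per-index standard-normal log density, summed across the named axis.
    return jax.lax.psum(norm.logpdf(x), AXIS)

xs = jax.numpy.array([0.0, 1.0, -1.0])
# Under vmap, psum gives every index the same total over the axis.
lp = jax.vmap(log_prob, axis_name=AXIS)(xs)
```

Each entry of `lp` holds the joint log density of all three independent draws, which is what reducing over the plate's axis amounts to.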
Functions
make_plate(...): Wraps a probabilistic program in a plate with a named axis.