Module: oryx.core.ppl.plate_util

Contains utilities for the plate transformation.

In graphical models, a plate designates a set of independent random variables. In Oryx, plate is a transformation that converts a program into one that produces independent samples. Ordinarily, this could be done with jax.vmap, by splitting a random key into several keys and mapping the program over them. Unlike jax.vmap, plate operates using named axes. A plated program specializes the random seed to the particular index of the axis being mapped over, and taking the log_prob of a plated program reduces over the named axis. In design, plate resembles the Sharded meta-distribution from TensorFlow Probability.
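The two approaches described above can be sketched in plain JAX. This is only an illustration of the idea, not the Oryx implementation: the helper names (sample_normal, normal_log_prob, plated_sample, plated_log_prob) and the axis name "N" are hypothetical, and the sketch assumes a JAX version whose jax.vmap accepts the axis_name and axis_size arguments.

```python
import jax
import jax.numpy as jnp

def sample_normal(key):
    # A tiny "program": a single standard normal draw.
    return jax.random.normal(key)

def normal_log_prob(x):
    # Log density of a standard normal at x.
    return -0.5 * x**2 - 0.5 * jnp.log(2.0 * jnp.pi)

key = jax.random.PRNGKey(0)

# Approach 1: split the key and map the program over the resulting keys.
keys = jax.random.split(key, 4)
samples = jax.vmap(sample_normal)(keys)
joint_log_prob = jnp.sum(jax.vmap(normal_log_prob)(samples))

# Approach 2 (plate-like): a single key plus a named axis.
def plated_sample(key):
    # Specialize the seed to this position along the named axis.
    index = jax.lax.axis_index("N")
    return sample_normal(jax.random.fold_in(key, index))

def plated_log_prob(x):
    # Reduce the per-sample log density over the named axis.
    return jax.lax.psum(normal_log_prob(x), "N")

named_samples = jax.vmap(
    plated_sample, in_axes=None, axis_name="N", axis_size=4)(key)
# Every entry of the output holds the same reduced value.
reduced = jax.vmap(plated_log_prob, axis_name="N")(named_samples)
named_joint_log_prob = reduced[0]
```

In the second approach the caller passes one key rather than a batch of keys, and the log density comes back already summed over the axis. That is the behavior the plate transformation is described as providing.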

In implementation, plate is an Oryx HigherOrderPrimitive (i.e. a JAX CallPrimitive) with a log_prob rule that reduces over a named axis at the end.

Functions

make_plate(...): Wraps a probabilistic program in a plate with a named axis.