oryx.core.ppl.plate

Transforms a program into one that draws samples on a named axis.

In graphical model parlance, a plate designates a set of independent random variables. The plate transformation follows this idea: a plate-ed program draws independent samples. Like jax.vmap-ing a program, plate produces independent samples. The difference is where they go: vmap stacks them along a positional batch dimension, while plate places them along an implicit named axis. Named axis support is useful for other JAX transformations like pmap and xmap.
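The positional-versus-named distinction can be seen in a small pure-JAX sketch that does not use Oryx at all. The function `centered` is hypothetical. It refers to an axis called `'foo'` through the collective `jax.lax.pmean`, so it only runs under a transformation that binds that name. Here that transformation is `jax.vmap(..., axis_name='foo')`, and `pmap` binds axis names the same way.

```python
import jax
from jax import random

keys = random.split(random.PRNGKey(0), 3)

# Positional batching: vmap stacks the three draws along a leading array axis.
positional = jax.vmap(random.normal)(keys)

def centered(k):
  # Hypothetical helper: subtract the mean taken over every index of the
  # axis named 'foo'. The name must be bound by an enclosing transformation.
  x = random.normal(k)
  return x - jax.lax.pmean(x, 'foo')

# Binding the name 'foo' with vmap lets the collective see all three draws.
named = jax.vmap(centered, axis_name='foo')(keys)
```

Called on its own, `centered(keys[0])` fails because no transformation has bound `'foo'`. This is the same situation as calling a plate-ed model directly.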

Specifically, a plate-ed program creates a different key for each index along the named axis. Its log_prob reduces over the named axis to produce a single value.
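Both behaviours can be sketched in plain JAX under two assumptions that are not taken from Oryx's source. First, the per-index key comes from folding the axis index into the incoming key. Second, the reduction sums the per-index log-densities with `psum`, which gives the joint log-density of independent variables. The names `plated_sample`, `plated_log_prob` and `AXIS` are hypothetical.

```python
import jax
from jax import random
from jax.scipy.stats import norm

AXIS = 'foo'  # hypothetical axis name

def plated_sample(key):
  # Assumption: a distinct key per index, derived from the shared key.
  key = random.fold_in(key, jax.lax.axis_index(AXIS))
  return random.normal(key)

def plated_log_prob(x):
  # Independent variables: the joint log-density is the sum over the axis.
  return jax.lax.psum(norm.logpdf(x), AXIS)

key = random.PRNGKey(0)
# The same key goes to every index (in_axes=None); axis_size sets the length.
samples = jax.vmap(plated_sample, in_axes=None, axis_name=AXIS, axis_size=3)(key)
# Every index receives the same reduced value.
lp = jax.vmap(plated_log_prob, axis_name=AXIS)(samples)
```

Folding in the axis index means that one key handed to the whole plate still yields different draws at each index.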

Example usage:

```python
from jax import random, vmap
from oryx.core import ppl
from oryx.core.ppl import random_variable

@ppl.plate(name='foo')
def model(key):
  return random_variable(random.normal)(key)

# We can't call model directly because it uses an implicit named axis
# that no transformation has bound yet.
try:
  model(random.PRNGKey(0))
except NameError:
  print('No named axis present!')

# If we vmap with a named axis, we produce independent samples.
vmap(model, axis_name='foo')(random.split(random.PRNGKey(0), 3))
# ==> [0.58776844, -0.4009751, 0.01193586]
```

Args:
  f: a Program to transform. If f is None, plate returns a decorator.
  name: a str name for the plate, which can be used as a named axis in JAX functions and transformations.

Returns:
  A decorator if f is None, or a transformed program if f is provided. The transformed program produces independent samples across a named axis called name.
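"Return a decorator when f is None" is a common Python convention. It lets the same function be used both as `@plate(name=...)` and as `plate(f, name=...)`. The sketch below illustrates it with a hypothetical `toy_plate` that reuses the per-index key idea. It is not Oryx's implementation and it handles sampling only, with no log_prob support.

```python
import functools

import jax
from jax import random

def toy_plate(f=None, name=None):
  # Used as @toy_plate(name=...): no function yet, so return a decorator.
  if f is None:
    return functools.partial(toy_plate, name=name)

  @functools.wraps(f)
  def wrapped(key, *args, **kwargs):
    # Give each index along the named axis its own key, then run f.
    return f(random.fold_in(key, jax.lax.axis_index(name)), *args, **kwargs)

  return wrapped

@toy_plate(name='foo')
def model(key):
  return random.normal(key)

def base(key):
  return random.normal(key)

direct = toy_plate(base, name='foo')  # the non-decorator form

key = random.PRNGKey(0)
a = jax.vmap(model, in_axes=None, axis_name='foo', axis_size=3)(key)
b = jax.vmap(direct, in_axes=None, axis_name='foo', axis_size=3)(key)
```

Both call forms produce the same transformed program, so `a` and `b` hold identical samples.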