Transforms a program into one that draws samples on a named axis.
`oryx.core.ppl.plate(f: Optional[oryx.core.ppl.LogProbFunction] = None, name: Optional[str] = None)`
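The `f` argument defaults to `None`, and `plate` is used as `@ppl.plate(name='foo')` in the example. That suggests the common Python pattern where a transformation called without `f` returns a decorator. The sketch below shows only that calling convention. `plate_like` is a hypothetical stand-in, not Oryx code, and it does no named-axis work at all.

```python
import functools
from typing import Callable, Optional


def plate_like(f: Optional[Callable] = None, name: Optional[str] = None):
    """Hypothetical stand-in that mimics only the calling convention of plate."""
    if f is None:
        # Called as @plate_like(name=...): return a decorator that waits for f.
        return functools.partial(plate_like, name=name)

    @functools.wraps(f)
    def wrapped(*args, **kwargs):
        # Named-axis handling would go here in a real transformation;
        # this sketch just forwards the call unchanged.
        return f(*args, **kwargs)

    wrapped.plate_name = name  # record the name so the dispatch is observable
    return wrapped


@plate_like(name="foo")
def model(x):
    return x + 1


doubled = plate_like(lambda x: x * 2, name="bar")
print(model(1), model.plate_name)      # 2 foo
print(doubled(3), doubled.plate_name)  # 6 bar
```

Both forms end up in the same code path. The decorator form simply defers the call until `f` is available.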
In graphical model parlance, a plate designates independent random variables.
The `plate` transformation follows this idea: a `plate`-ed program draws independent samples. Unlike `jax.vmap`-ing a program, which also produces independent samples but with positional batch dimensions, `plate` produces samples with implicit named axes. Named axis support is useful for other JAX transformations that take an axis name, such as `jax.pmap`.
Specifically, a `plate`-ed program creates a different key for each index along the named axis.
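One way to give each index its own key is to fold the position along the named axis, obtained with `jax.lax.axis_index`, into the incoming key with `jax.random.fold_in`. The sketch below does this in plain JAX. It assumes that key-derivation scheme, which is not necessarily how Oryx implements `plate`. The function `plated_normal` and the axis name `"foo"` are hypothetical. The function only runs when an enclosing transformation such as `jax.vmap(..., axis_name="foo")` binds the axis.

```python
import jax
import jax.numpy as jnp
from jax import lax, random

AXIS = "foo"  # hypothetical axis name


def plated_normal(key):
    # Position of this call along AXIS; only valid when an enclosing
    # transformation (here vmap) binds AXIS.
    idx = lax.axis_index(AXIS)
    # Fold the position into the shared key so each index gets its own key.
    per_index_key = random.fold_in(key, idx)
    return random.normal(per_index_key)


key = random.PRNGKey(0)
# Give every index the *same* key, so any difference comes from the fold-in.
keys = jnp.broadcast_to(key, (3,) + key.shape)
samples = jax.vmap(plated_normal, axis_name=AXIS)(keys)

# Each entry matches a direct draw with the index folded in by hand.
expected = jnp.stack([random.normal(random.fold_in(key, i)) for i in range(3)])
print(samples.shape, bool(jnp.allclose(samples, expected)))
```

Even though all three indices receive an identical input key, the samples differ because each index folds in a different position.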
`log_prob` reduces over the named axis, summing the per-index log-probabilities of the independent samples, to produce a single value.
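For independent variables the joint log-density is the sum of the per-index log-densities, so a reduction over a named axis can be written with `jax.lax.psum`. The sketch below assumes a standard normal at every index and uses a hypothetical `plated_log_prob`. It illustrates the reduction itself, not Oryx's `log_prob` machinery.

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax.scipy.stats import norm

AXIS = "foo"  # hypothetical axis name


def plated_log_prob(x):
    # Log-density of this index's value under a standard normal.
    lp = norm.logpdf(x)
    # Independent terms add in log space: sum them across the named axis.
    return lax.psum(lp, AXIS)


xs = jnp.array([0.5, -0.4, 0.0])
totals = jax.vmap(plated_log_prob, axis_name=AXIS)(xs)
print(totals)  # the same summed log-density at every index
```

`psum` leaves the same total at every index of the axis, which is why every entry of `totals` equals the sum of the three individual log-densities.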
Example usage:

```python
from jax import random, vmap
from oryx.core import ppl
from oryx.core.ppl import random_variable

@ppl.plate(name='foo')
def model(key):
  return random_variable(random.normal)(key)

# We can't call model directly because there are implicit named axes present.
try:
  model(random.PRNGKey(0))
except NameError:
  print('No named axis present!')

# If we vmap with a named axis, we produce independent samples.
vmap(model, axis_name='foo')(random.split(random.PRNGKey(0), 3))
# ==> [0.58776844, -0.4009751, 0.01193586]
```
A decorator if