Module: oryx.core.ppl.plate_util
Contains utilities for the `plate` transformation.
A plate is a term in graphical models used to designate independent random variables. In Oryx, `plate` is a transformation that converts a program into one that produces independent samples. Ordinarily, this could be done with `jax.vmap`: split a random key into several keys and map the program over them. Unlike `jax.vmap`, `plate` operates using named axes. A `plate`-ed program specializes its random seed to the particular index of the axis being mapped over. Taking the `log_prob` of a `plate`-ed program reduces over the named axis. In design, `plate` resembles the `Sharded` meta-distribution from TensorFlow Probability.
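The contrast between the two approaches can be sketched in plain JAX. This is a minimal illustration, not Oryx's implementation: `program`, `vmap_with_split_keys`, `specialize_seed`, and `named_axis_samples` are hypothetical names, and specializing the seed with `jax.random.fold_in` on `jax.lax.axis_index` is an assumed way of tying one shared key to each position along a named axis.

```python
import jax


def program(key):
    # Hypothetical toy probabilistic program: a single standard-normal draw.
    return jax.random.normal(key)


def vmap_with_split_keys(key, n):
    # Baseline: split the key into n keys and map the program over them.
    keys = jax.random.split(key, n)
    return jax.vmap(program)(keys)


def specialize_seed(key, axis_name):
    # Hypothetical helper (assumption, not Oryx's code): every instance gets
    # the same key and folds in its own index along the named axis.
    index = jax.lax.axis_index(axis_name)
    return program(jax.random.fold_in(key, index))


def named_axis_samples(key, n, axis_name="plate"):
    # The key is broadcast (in_axes=(None,)), so vmap needs axis_size to know
    # how many instances to create along the named axis.
    return jax.vmap(
        lambda k: specialize_seed(k, axis_name),
        in_axes=(None,),
        axis_size=n,
        axis_name=axis_name,
    )(key)
```

Because each instance derives its key from the shared key and its own index, sample `i` can be reproduced outside the map as `program(jax.random.fold_in(key, i))`.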
In implementation, `plate` is an Oryx `HigherOrderPrimitive` (i.e., a JAX `CallPrimitive`) with a `log_prob` rule that reduces over a named axis at the end.
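A reduction over a named axis can be sketched with `jax.lax.psum`, assuming (as the independence of plated variables implies) that the reduction sums per-index log densities. The names `local_log_prob`, `plate_log_prob`, and `joint_log_prob` are hypothetical, and a standard-normal density stands in for an arbitrary program's `log_prob`; this is not Oryx's actual `log_prob` rule.

```python
import jax
from jax.scipy.stats import norm


def local_log_prob(x):
    # Hypothetical stand-in: log density of one draw under a standard normal.
    return norm.logpdf(x)


def plate_log_prob(x, axis_name):
    # Independent draws => joint log density is the sum of per-index terms.
    # psum performs that sum across the named axis and hands the same total
    # back to every mapped instance.
    return jax.lax.psum(local_log_prob(x), axis_name)


def joint_log_prob(xs, axis_name="plate"):
    # Hypothetical driver: map over the draws along a named axis.
    per_instance = jax.vmap(
        lambda x: plate_log_prob(x, axis_name), axis_name=axis_name
    )(xs)
    # Every entry holds the same reduced value; return the first.
    return per_instance[0]
```

Reducing at the end, rather than inside each instance, keeps the per-index program unchanged and lets the named axis carry the independence structure.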
Functions
`make_plate(...)`: Wraps a probabilistic program in a plate with a named axis.
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2024-05-23 UTC.