# tfp.substrates.jax.distributions.JointDistributionNamed

Joint distribution parameterized by named distribution-making functions.

Inherits From: `JointDistribution`, `Distribution`

This distribution enables both sampling and joint probability computation from a single model specification.

A joint distribution is a collection of possibly interdependent distributions. Like `JointDistributionSequential`, `JointDistributionNamed` is parameterized by several distribution-making functions. Unlike `JointDistributionSequential`, each distribution-making function must have its own key. Additionally, every distribution-making function's arguments must refer only to specified keys.

#### Mathematical Details

Internally `JointDistributionNamed` implements the chain rule of probability. That is, the probability function of a length-`d` vector `x` is,

```
p(x) = prod{ p(x[i] | x[:i]) : i = 0, ..., (d - 1) }
```
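As a concrete check of the chain rule, the following sketch evaluates the joint log-density of a two-variable model, `n ~ Normal(0, 2)` and `m ~ Normal(n, 1)`, as a sum of conditional log-densities. It uses only the Python standard library (no TFP); the helpers `normal_log_pdf` and `joint_log_prob` are hypothetical and exist only for illustration.

```python
import math

def normal_log_pdf(x, loc, scale):
    """Log-density of a univariate Normal(loc, scale) evaluated at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)

def joint_log_prob(n, m):
    """Chain rule in log space: log p(n, m) = log p(n) + log p(m | n)."""
    log_p_n = normal_log_pdf(n, loc=0.0, scale=2.0)        # p(x[0])
    log_p_m_given_n = normal_log_pdf(m, loc=n, scale=1.0)  # p(x[1] | x[:1])
    return log_p_n + log_p_m_given_n

print(joint_log_prob(0.5, 1.0))
```

The product over `i` in the formula becomes a sum of log terms, which is also how a joint `log_prob` is naturally accumulated.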

The `JointDistributionNamed` is parameterized by a `dict` (or `namedtuple` or `collections.OrderedDict`) composed of either:

1. `tfp.distributions.Distribution`-like instances or,
2. `callable`s which return a `tfp.distributions.Distribution`-like instance.

The "conditioned on" elements are represented by the `callable`'s required arguments; every argument must correspond to a key in the named distribution-making functions. Distribution-makers that are directly a `Distribution`-like instance are allowed for convenience and are semantically identical to a zero-argument `callable`. When the maker takes no arguments, it is preferable to provide the distribution instance directly.
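To make the role of argument names concrete, here is a minimal sketch, assuming "required arguments" means parameters without a default value, that reads each maker's required argument names with the standard-library `inspect` module and checks that they all name model keys. The helper `required_arg_names` is hypothetical, and strings stand in for real distribution instances; this is not TFP's implementation.

```python
import inspect

def required_arg_names(maker):
    """Hypothetical helper: names of a maker's required (no-default) arguments.

    A non-callable maker (a distribution instance) has no dependencies,
    which is why it behaves like a zero-argument callable.
    """
    if not callable(maker):
        return ()
    required_kinds = (inspect.Parameter.POSITIONAL_ONLY,
                      inspect.Parameter.POSITIONAL_OR_KEYWORD,
                      inspect.Parameter.KEYWORD_ONLY)
    return tuple(p.name for p in inspect.signature(maker).parameters.values()
                 if p.kind in required_kinds and p.default is p.empty)

# Strings stand in for distribution instances; lambdas are distribution-makers.
model = dict(
    e="Exponential instance",
    g=lambda e: ("Gamma", e),
    n="Normal instance",
    m=lambda n, g: ("Normal", n, g),
)
deps = {key: required_arg_names(maker) for key, maker in model.items()}

# Every required argument must name another key in the model.
unknown = {a for args in deps.values() for a in args if a not in model}
assert not unknown, f"arguments that are not model keys: {unknown}"
print(deps)
```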

Name resolution: The names of `JointDistributionNamed` components are simply the keys specified explicitly in the model definition.

#### Examples

Consider the following generative model:

```
e ~ Exponential(rate=[100,120])
g ~ Gamma(concentration=e[0], rate=e[1])
n ~ Normal(loc=0, scale=2.)
m ~ Normal(loc=n, scale=g)
for i = 1, ..., 12:
  x[i] ~ Bernoulli(logits=m)
```

We can code this as:

```python
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions
joint = tfd.JointDistributionNamed(
    dict(
        e=             tfd.Exponential(rate=[100, 120]),
        g=lambda    e: tfd.Gamma(concentration=e[0], rate=e[1]),
        n=             tfd.Normal(loc=0, scale=2.),
        m=lambda n, g: tfd.Normal(loc=n, scale=g),
        x=lambda    m: tfd.Sample(tfd.Bernoulli(logits=m), 12),
    ),
    batch_ndims=0,
    use_vectorized_map=True)
```

Notice the 1:1 correspondence between "math" and "code". Further, notice that unlike `JointDistributionSequential`, there is no need to put the distribution-making functions in topologically sorted order nor is it ever necessary to use dummy arguments to skip dependencies.
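Because every dependency is named, an evaluation order can be recovered with any topological sort. The sketch below uses the standard-library `graphlib` module (Python 3.9+) on the dependency lists of the example model, written out by hand from the lambdas' argument names. It illustrates the idea only and is not TFP's internal sorting routine; ties between independent nodes may be ordered differently.

```python
from graphlib import TopologicalSorter

# Parents of each node, read off the argument names in the example model.
# The keys are deliberately listed out of dependency order.
deps = {
    "x": ("m",),
    "m": ("n", "g"),
    "g": ("e",),
    "e": (),
    "n": (),
}

# static_order() yields every node after all of its parents.
order = list(TopologicalSorter(deps).static_order())
print(order)
```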

```python
import jax

# The JAX substrate requires an explicit PRNG key for sampling.
x = joint.sample(seed=jax.random.PRNGKey(0))
# ==> A 5-element `dict` of arrays representing a draw/realization from each
#     distribution.
joint.log_prob(x)
# ==> A scalar array representing the total log prob under all five
#     distributions.

joint.resolve_graph()
# ==> (('e', ()),
#      ('g', ('e',)),
#      ('n', ()),
#      ('m', ('n', 'g')),
#      ('x', ('m',)))
```

#### Discussion

`JointDistributionNamed` topologically sorts the distribution-making functions and calls each by feeding in all previously created dependencies. A distribution-maker must be either a:

1. `tfd.Distribution`-like instance (e.g., `e` and `n` in the above example), or a

2. Python `callable` (e.g., `g`, `m`, `x` in the above example).

Regarding #1, an object is deemed "`tfd.Distribution`-like" if it has `sample` and `log_prob` methods and distribution properties, e.g., `batch_shape`, `event_shape`, `dtype`.
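The duck-typing criterion can be sketched as a small predicate. This is a hypothetical check written from the description above, not the test TFP actually runs; `FakeScalarDist` is likewise a made-up class used only to exercise it.

```python
def is_distribution_like(obj):
    """Hypothetical check: callable `sample`/`log_prob` plus shape/dtype properties."""
    has_methods = all(callable(getattr(obj, name, None))
                      for name in ("sample", "log_prob"))
    has_props = all(hasattr(obj, name)
                    for name in ("batch_shape", "event_shape", "dtype"))
    return has_methods and has_props

class FakeScalarDist:
    """Made-up minimal object that satisfies the criterion."""
    batch_shape = ()
    event_shape = ()
    dtype = float

    def sample(self, seed=None):
        return 0.0

    def log_prob(self, x):
        return 0.0

print(is_distribution_like(FakeScalarDist()))  # an instance: treated as a distribution
print(is_distribution_like(lambda e: None))    # a maker: treated as a callable
```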

Regarding #2, in addition to using a function (or `lambda`), supplying a TFD "`class`" is also permissible, since a class is also a "Python `callable`." For example, instead of writing `lambda loc, scale: tfd.Normal(loc=loc, scale=scale)`, one could simply write `tfd.Normal`.
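The equivalence rests on the fact that a class's constructor arguments are visible in the same way as a lambda's. The sketch below uses `inspect.signature` on a hypothetical `ToyNormal` class (a stand-in for a class such as `tfd.Normal`, not the real one) to show that the explicit lambda and the bare class expose the same argument names.

```python
import inspect

class ToyNormal:
    """Hypothetical stand-in for a distribution class such as `tfd.Normal`."""
    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale

# The explicit lambda and the bare class are both callables whose
# argument names are `loc` and `scale`.
as_lambda = lambda loc, scale: ToyNormal(loc=loc, scale=scale)
as_class = ToyNormal

lambda_args = list(inspect.signature(as_lambda).parameters)
class_args = list(inspect.signature(as_class).parameters)  # `self` is excluded

# Calling either with the same keyword values builds an equivalent object.
a = as_lambda(loc=1.0, scale=2.0)
b = as_class(loc=1.0, scale=2.0)
print(lambda_args, class_args)
```

Since argument names are matched against keys, a model that passes a class directly needs components whose keys match that class's required constructor arguments (here, `loc` and `scale`).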

Notice that directly providing a `tfd.Distribution`-like instance means there cannot exist a (dynamic) dependency on other distributions; it is "independent" both "computationally" and "statistically." The same is self-evidently true of zero-argument `callable`s.

A distribution instance depends on other distribution instances through the distribution-making function's required arguments. The distribution makers' arguments are parameterized by samples from the corresponding previously constructed distributions ("previous" in the sense of a topological sorting of dependencies).
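The "feed in previously drawn samples" step can be sketched as ancestral sampling in plain Python. In this minimal illustration, `ToyNormal` is a hypothetical distribution-like class (not a TFP class), the topological `order` is given by hand, and each callable maker receives the earlier samples whose keys match its argument names. Each component's log-density is accumulated as the chain rule prescribes.

```python
import inspect
import math
import random

class ToyNormal:
    """Hypothetical minimal distribution-like object (not a TFP class)."""
    def __init__(self, loc, scale):
        self.loc, self.scale = loc, scale

    def sample(self, rng):
        return rng.gauss(self.loc, self.scale)

    def log_prob(self, x):
        z = (x - self.loc) / self.scale
        return -0.5 * z * z - math.log(self.scale) - 0.5 * math.log(2 * math.pi)

# Keys need not appear in dependency order; `order` is a topological sort.
model = {
    "m": lambda n: ToyNormal(loc=n, scale=1.0),  # depends on n
    "n": ToyNormal(loc=0.0, scale=2.0),          # no dependencies
}
order = ["n", "m"]

def sample_and_log_prob(model, order, rng):
    values, total = {}, 0.0
    for key in order:
        maker = model[key]
        if callable(maker):
            # Feed in previously drawn values, matched by argument name.
            names = inspect.signature(maker).parameters
            dist = maker(**{name: values[name] for name in names})
        else:
            dist = maker  # a distribution instance has no dependencies
        values[key] = dist.sample(rng)
        total += dist.log_prob(values[key])
    return values, total

values, lp = sample_and_log_prob(model, order, random.Random(0))
print(values, lp)
```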

#### Vectorized sampling and model evaluation

When a joint distribution's `sample` method is called with a `sample_shape` (or the `log_prob` method is called on an input with multiple sample dimensions) the model must be equipped to handle additional batch dimensions. This may be done manually, or automatically by passing `use_vectorized_map=True`. Manual vectorization has historically been the default, but we now recommend that most users enable automatic vectorization unless they are affected by a specific issue; some known issues are listed below.

When using manually vectorized joint distributions, each operation in the model must account for the possibility of batch dimensions in distributions and their samples. By contrast, auto-vectorized models need only describe a single sample from the joint distribution; any batch evaluation is automated as required using `tf.vectorized_map` (`vmap` in JAX). In many cases this allows for significant simplifications. For example, the following manually vectorized `tfd.JointDistributionSequential` model:

```python
import jax.numpy as jnp

model = tfd.JointDistributionSequential([
    tfd.Normal(0., jnp.ones([3])),
    tfd.Normal(0., 1.),
    lambda y, x: tfd.Normal(x[..., :2] + y[..., jnp.newaxis], 1.)
])
```

can be written in auto-vectorized form as

```python
import jax.numpy as jnp

model = tfd.JointDistributionSequential([
    tfd.Normal(0., jnp.ones([3])),
    tfd.Normal(0., 1.),
    lambda y, x: tfd.Normal(x[:2] + y, 1.)
  ],
  use_vectorized_map=True)
```

in which we were able to avoid explicitly accounting for batch dimensions when indexing and slicing computed quantities in the third line.
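The difference can be illustrated without any array library. In this sketch, `loc_single` mirrors the auto-vectorized lambda (one sample, no batch dimension), `loc_manual` mirrors the manually vectorized one, and `vectorize` is a loop-based toy stand-in for `vmap`. It is an assumption made for illustration only: the real `vmap` transforms the computation rather than looping over the batch.

```python
def loc_single(y, x):
    """Single-sample body: x has length 3, y is a scalar (like x[:2] + y)."""
    return [xi + y for xi in x[:2]]

def vectorize(fn):
    """Toy stand-in for vmap: apply fn independently to each batch member."""
    def batched(ys, xs):
        return [fn(y, x) for y, x in zip(ys, xs)]
    return batched

def loc_manual(ys, xs):
    """Manually batched body, analogous to x[..., :2] + y[..., newaxis]."""
    return [[xi + y for xi in x[:2]] for y, x in zip(ys, xs)]

ys = [0.0, 10.0]                          # batch of 2 scalar samples of y
xs = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]  # batch of 2 length-3 samples of x

print(vectorize(loc_single)(ys, xs))  # → [[1.0, 2.0], [14.0, 15.0]]
print(loc_manual(ys, xs))             # → [[1.0, 2.0], [14.0, 15.0]]
```

Only the manual version has to know where the batch dimension lives; the single-sample version is written once and lifted to batches by the wrapper.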

Known limitations of automatic vectorization:

• A small fraction of TensorFlow ops are unsupported; models that use an unsupported op will raise an error and must be manually vectorized.
• Sampling large batches may be slow under automatic vectorization because TensorFlow's stateless samplers are currently converted using a non-vectorized `while_loop`. This limitation applies only in TensorFlow; vectorized samplers in JAX should be approximately as fast as manually vectorized code.
• Calling `sample_distributions` with nontrivial `sample_shape` will raise an error if the model contains any distributions that are not registered as CompositeTensors (TFP's basic distributions are usually fine, but support for wrapper distributions like `tfd.Sample` is a work in progress).

#### Batch semantics and (log-)densities

tl;dr: pass `batch_ndims=0` unless you have a good reason not to.

Joint distributions now support 'auto-batching' semantics, in which the distribution's batch shape is derived by broadcasting the leftmost `batch_ndims` dimensions of its components' batch shapes. All remaining dimensions are considered to form a single 'event' of the joint distribution. If `batch_ndims==0`, then the joint distribution has batch shape `[]`, and all component dimensions are treated as event shape. For example, the model

```python
import jax.numpy as jnp

jd = tfd.JointDistributionSequential([
    tfd.Normal(0., jnp.ones([3])),
    lambda x: tfd.Normal(x[..., jnp.newaxis], jnp.ones([3, 2]))
  ],
  batch_ndims=0)
```

creates a joint distribution with batch shape `[]` and event shape `([3], [3, 2])`. The log-density of a sample always has shape `batch_shape`, so this guarantees that `jd.log_prob(jd.sample())` will evaluate to a scalar value. We could alternately construct a joint distribution with batch shape `[3]` and event shape `([], [2])` by setting `batch_ndims=1`, in which case `jd.log_prob(jd.sample())` would evaluate to a value of shape `[3]`.
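The shape bookkeeping behind auto-batching can be sketched with plain tuples. The helpers `broadcast_shapes` and `joint_shapes` are hypothetical: the first applies NumPy-style broadcasting to the length-`batch_ndims` prefixes, and the second moves the remaining batch dimensions into each component's event shape. The sketch assumes every component has at least `batch_ndims` batch dimensions and ignores TFP's other edge cases.

```python
def broadcast_shapes(*shapes):
    """NumPy-style broadcasting of shape tuples (right-aligned)."""
    ndim = max((len(s) for s in shapes), default=0)
    result = []
    for i in range(ndim):
        dims = {s[len(s) - ndim + i] for s in shapes if len(s) - ndim + i >= 0}
        dims.discard(1)
        if len(dims) > 1:
            raise ValueError(f"incompatible shapes: {shapes}")
        result.append(dims.pop() if dims else 1)
    return tuple(result)

def joint_shapes(component_batch_shapes, component_event_shapes, batch_ndims):
    """Batch shape from the leftmost `batch_ndims` dims; the rest join the event."""
    prefixes = [tuple(b[:batch_ndims]) for b in component_batch_shapes]
    batch = broadcast_shapes(*prefixes)
    events = tuple(tuple(b[batch_ndims:]) + tuple(e)
                   for b, e in zip(component_batch_shapes, component_event_shapes))
    return batch, events

# The example model: two Normal components (scalar events) with batch
# shapes [3] and [3, 2].
print(joint_shapes([(3,), (3, 2)], [(), ()], batch_ndims=0))  # → ((), ((3,), (3, 2)))
print(joint_shapes([(3,), (3, 2)], [(), ()], batch_ndims=1))  # → ((3,), ((), (2,)))
```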

Setting `batch_ndims=None` recovers the 'classic' batch semantics (currently still the default for backwards-compatibility reasons), in which the joint distribution's `log_prob` is computed by naively summing log densities from the component distributions. Since these component densities have shapes equal to the batch shapes of the individual components, to avoid broadcasting errors it is usually necessary to construct the components with identical batch shapes. For example, the component distributions in the model above have batch shapes of `[3]` and `[3, 2]` respectively, which would raise an error if summed directly, but can be aligned by wrapping with `tfd.Independent`, as in this model:

```python
import jax.numpy as jnp

jd = tfd.JointDistributionSequential([
    tfd.Normal(0., jnp.ones([3])),
    lambda x: tfd.Independent(tfd.Normal(x[..., jnp.newaxis], jnp.ones([3, 2])),
                              reinterpreted_batch_ndims=1)
  ],
  batch_ndims=None)
```

Here the components both have batch shape `[3]`, so `jd.log_prob(jd.sample())` returns a value of shape `[3]`, just as in the `batch_ndims=1` case above. In fact, auto-batching semantics are equivalent to implicitly wrapping each component `dist` as `tfd.Independent(dist, reinterpreted_batch_ndims=(dist.batch_shape.ndims - jd.batch_ndims))`; the only vestigial difference is that under auto-batching semantics, the joint distribution has a single batch shape `[3]`, while under the classic semantics the value of `jd.batch_shape` is a structure of the component batch shapes `([3], [3])`. Such structured batch shapes will be deprecated in the future, since they are inconsistent with the definition of batch shapes used elsewhere in TFP.
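The implicit `tfd.Independent` wrapping amounts to summing each component's log-density over its rightmost `dist.batch_shape.ndims - jd.batch_ndims` dimensions before adding the components together. The sketch below shows this arithmetic on nested lists; the helpers `total` and `reduce_sum_rightmost` are hypothetical, and the log-density values are made-up placeholders shaped like the `[3]` and `[3, 2]` batch shapes of the example model.

```python
def total(values):
    """Sum every entry of a nested-list 'tensor'."""
    if isinstance(values, list):
        return sum(total(v) for v in values)
    return values

def reduce_sum_rightmost(values, ndims, k):
    """Sum the rightmost `k` of the `ndims` dimensions of a nested list."""
    if k == 0:
        return values
    if ndims == k:
        return total(values)
    return [reduce_sum_rightmost(v, ndims - 1, k) for v in values]

batch_ndims = 1
component_batch_shapes = [(3,), (3, 2)]
# Placeholder per-component log-densities (made-up numbers).
component_log_probs = [
    [1.0, 2.0, 3.0],
    [[0.5, 0.5], [1.0, 1.0], [0.0, 2.0]],
]

# reinterpreted_batch_ndims = dist.batch_shape.ndims - jd.batch_ndims
reinterpreted = [len(s) - batch_ndims for s in component_batch_shapes]
reduced = [reduce_sum_rightmost(lp, len(s), k)
           for lp, s, k in zip(component_log_probs, component_batch_shapes, reinterpreted)]

# Each reduced term now has shape [3]; the joint log-density is their sum.
joint_log_prob = [sum(terms) for terms in zip(*reduced)]
print(reinterpreted, joint_log_prob)  # → [0, 1] [2.0, 4.0, 5.0]
```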


If every element of `model` is a `CompositeTensor` or a callable, the resulting `JointDistributionNamed` is a `CompositeTensor`. Otherwise, a non-`CompositeTensor` `_JointDistributionNamed` instance is created.

#### Args

| Argument | Description |
| --- | --- |
| `model` | Python `dict`, `collections.OrderedDict`, or `namedtuple` of distribution-making functions, each with required args corresponding only to other keys. |
| `batch_ndims` | `int` `Tensor` number of batch dimensions. The `batch_shape`s of all component distributions must be such that the prefixes of length `batch_ndims` broadcast to a consistent joint batch shape. Default value: `None`. |
| `use_vectorized_map` | Python `bool`. Whether to use `tf.vectorized_map` to automatically vectorize evaluation of the model. This allows the model specification to focus on drawing a single sample, which is often simpler, but some ops may not be supported. Default value: `False`. |
| `validate_args` | Python `bool`. Whether to validate input with asserts. If `validate_args` is `False` and the inputs are invalid, correct behavior is not guaranteed. Default value: `False`. |
| `experimental_use_kahan_sum` | Python |