
tfp.substrates.jax.distributions.JointDistributionNamed


Joint distribution parameterized by named distribution-making functions.

Inherits From: JointDistribution, Distribution

This distribution enables both sampling and joint probability computation from a single model specification.

A joint distribution is a collection of possibly interdependent distributions. Like JointDistributionSequential, JointDistributionNamed is parameterized by several distribution-making functions. Unlike JointDistributionSequential, each distribution-making function is bound to its own key, and every distribution-making function's arguments may refer only to other specified keys.
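
Since this page documents the JAX substrate, here is a minimal sketch of building and using a small (illustrative, two-variable) model; the import path follows the `tfp.substrates.jax` convention this page documents:

    import jax
    from tensorflow_probability.substrates import jax as tfp

    tfd = tfp.distributions

    # `x` is conditioned on `scale` via the callable's required argument name.
    joint = tfd.JointDistributionNamed(dict(
        scale=tfd.HalfNormal(1.),
        x=lambda scale: tfd.Normal(loc=0., scale=scale)))

    # In the JAX substrate, sampling takes an explicit PRNG key.
    draw = joint.sample(seed=jax.random.PRNGKey(0))
    lp = joint.log_prob(draw)  # scalar log-density of the draw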

#### Mathematical Details

Internally JointDistributionNamed implements the chain rule of probability. That is, the probability function of a length-d vector x is,

  p(x) = prod{ p(x[i] | x[:i]) : i = 0, ..., (d - 1) }

The JointDistributionNamed is parameterized by a dict (or namedtuple or collections.OrderedDict) composed of either:

  1. tfp.distributions.Distribution-like instances or,
  2. callables which return a tfp.distributions.Distribution-like instance.

    The "conditioned on" elements are represented by the callable's required arguments; every argument must correspond to a key in the named distribution-making functions. Distribution-makers which are directly a Distribution-like instance are allowed for convenience and semantically identical a zero argument callable. When the maker takes no arguments it is preferable to directly provide the distribution instance.

Name resolution: the names of JointDistributionNamed components are simply the keys specified explicitly in the model definition.

#### Examples

Consider the following generative model:

    e ~ Exponential(rate=[100,120])
    g ~ Gamma(concentration=e[0], rate=e[1])
    n ~ Normal(loc=0, scale=2.)
    m ~ Normal(loc=n, scale=g)
    for i = 1, ..., 12:
      x[i] ~ Bernoulli(logits=m)
    

We can code this as:

    tfd = tfp.distributions
    joint = tfd.JointDistributionNamed(dict(
        e=             tfd.Exponential(rate=[100, 120]),
        g=lambda    e: tfd.Gamma(concentration=e[0], rate=e[1]),
        n=             tfd.Normal(loc=0, scale=2.),
        m=lambda n, g: tfd.Normal(loc=n, scale=g),
        x=lambda    m: tfd.Sample(tfd.Bernoulli(logits=m), 12)
      ),
      batch_ndims=0,
      use_vectorized_map=True)
    

Notice the 1:1 correspondence between "math" and "code". Further, notice that unlike JointDistributionSequential, there is no need to put the distribution-making functions in topologically sorted order nor is it ever necessary to use dummy arguments to skip dependencies.

    x = joint.sample()
    # ==> A 5-element `dict` of Tensors representing a draw/realization from each
    #     distribution.
    joint.log_prob(x)
    # ==> A scalar `Tensor` representing the total log prob under all five
    #     distributions.
    
    joint.resolve_graph()
    # ==> (('e', ()),
    #      ('g', ('e',)),
    #      ('n', ()),
    #      ('m', ('n', 'g')),
    #      ('x', ('m',)))
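
The chain-rule factorization above can also be inspected numerically. As a sketch (using the joint distribution's log_prob_parts method and the JAX substrate's explicit seed):

    x = joint.sample(seed=jax.random.PRNGKey(0))
    parts = joint.log_prob_parts(x)
    # ==> a `dict` with keys 'e', 'g', 'n', 'm', 'x', one conditional
    #     log-prob term per component.
    total = sum(parts.values())
    # `total` reproduces `joint.log_prob(x)` up to floating-point error.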
    

#### Discussion

JointDistributionNamed topologically sorts the distribution-making functions and calls each by feeding in all previously created dependencies. A distribution-maker must be either:

1. a tfd.Distribution-like instance (e.g., e and n in the above example), or

2. a Python callable (e.g., g, m, x in the above example).

Regarding #1, an object is deemed "tfd.Distribution-like" if it has sample and log_prob methods and distribution properties such as batch_shape, event_shape, and dtype.

Regarding #2, in addition to a function (or lambda), supplying a TFD "class" is also permissible, since a class is itself a Python callable. For example, instead of writing lambda loc, scale: tfd.Normal(loc=loc, scale=scale), one could simply write tfd.Normal.
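
As a sketch of this, the class's constructor arguments resolve against the model's keys just as a lambda's would; the keys loc and scale here are chosen to match tfd.Normal's constructor:

    joint = tfd.JointDistributionNamed(dict(
        loc=tfd.Normal(0., 1.),
        scale=tfd.HalfNormal(1.),
        obs=tfd.Normal))  # same as `lambda loc, scale: tfd.Normal(loc, scale)`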

Notice that directly providing a tfd.Distribution-like instance means there cannot exist a (dynamic) dependency on other distributions; it is "independent" both "computationally" and "statistically." The same is self-evidently true of zero-argument callables.

A distribution instance depends on other distribution instances through the distribution-making function's required arguments. The distribution makers' arguments are parameterized by samples from the corresponding previously constructed distributions. ("Previous" in the sense of a topological sorting of dependencies.)

#### Vectorized sampling and model evaluation

When a joint distribution's sample method is called with a sample_shape (or the log_prob method is called on an input with multiple sample dimensions) the model must be equipped to handle additional batch dimensions. This may be done manually, or automatically by passing use_vectorized_map=True. Manual vectorization has historically been the default, but we now recommend that most users enable automatic vectorization unless they are affected by a specific issue; some known issues are listed below.

When using manually-vectorized joint distributions, each operation in the model must account for the possibility of batch dimensions in Distributions and their samples. By contrast, auto-vectorized models need only describe a single sample from the joint distribution; any batch evaluation is automated as required using tf.vectorized_map (vmap in JAX). In many cases this allows for significant simplifications. For example, the following manually-vectorized tfd.JointDistributionSequential model:

    model = tfd.JointDistributionSequential([
        tfd.Normal(0., tf.ones([3])),
        tfd.Normal(0., 1.),
        lambda y, x: tfd.Normal(x[..., :2] + y[..., tf.newaxis], 1.)
      ])
    

can be written in auto-vectorized form as

    model = tfd.JointDistributionSequential([
        tfd.Normal(0., tf.ones([3])),
        tfd.Normal(0., 1.),
        lambda y, x: tfd.Normal(x[:2] + y, 1.)
      ],
      use_vectorized_map=True)
    

in which we were able to avoid explicitly accounting for batch dimensions when indexing and slicing computed quantities in the third line.
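
As a sketch (again using the JAX substrate's explicit seed), sampling the auto-vectorized model with a sample_shape writes the per-sample logic once while the vectorized map supplies the leading dimension:

    draws = model.sample(7, seed=jax.random.PRNGKey(0))
    # ==> a list of arrays with shapes [7, 3], [7], and [7, 2]; the
    #     per-sample model code never mentions the extra leading dimension.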

Known limitations of automatic vectorization:

  • A small fraction of TensorFlow ops are unsupported; models that use an unsupported op will raise an error and must be manually vectorized.
  • Sampling large batches may be slow under automatic vectorization because TensorFlow's stateless samplers are currently converted using a non-vectorized while_loop. This limitation applies only in TensorFlow; vectorized samplers in JAX should be approximately as fast as manually vectorized code.
  • Calling sample_distributions with nontrivial sample_shape will raise an error if the model contains any distributions that are not registered as CompositeTensors (TFP's basic distributions are usually fine, but support for wrapper distributions like tfd.Sample is a work in progress).

#### Batch semantics and (log-)densities

tl;dr: pass batch_ndims=0 unless you have a good reason not to.

Joint distributions now support 'auto-batching' semantics, in which the distribution's batch shape is derived by broadcasting the leftmost batch_ndims dimensions of its components' batch shapes. All remaining dimensions are considered to form a single 'event' of the joint distribution. If batch_ndims==0, then the joint distribution has batch shape [], and all component dimensions are treated as event shape. For example, the model

    jd = tfd.JointDistributionSequential([
        tfd.Normal(0., tf.ones([3])),
        lambda x: tfd.Normal(x[..., tf.newaxis], tf.ones([3, 2]))
      ],
      batch_ndims=0)
    

creates a joint distribution with batch shape [] and event shape ([3], [3, 2]). The log-density of a sample always has shape batch_shape, so this guarantees that jd.log_prob(jd.sample()) will evaluate to a scalar value. We could alternately construct a joint distribution with batch shape [3] and event shape ([], [2]) by setting batch_ndims=1, in which case jd.log_prob(jd.sample()) would evaluate to a value of shape [3], as sketched below.
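
A sketch of that batch_ndims=1 variant, mirroring the model above:

    jd = tfd.JointDistributionSequential([
        tfd.Normal(0., tf.ones([3])),
        lambda x: tfd.Normal(x[..., tf.newaxis], tf.ones([3, 2]))
      ],
      batch_ndims=1)
    # jd.batch_shape == [3]; event shapes ([], [2]).
    # jd.log_prob(jd.sample()) has shape [3].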

Setting batch_ndims=None recovers the 'classic' batch semantics (currently still the default for backwards-compatibility reasons), in which the joint distribution's log_prob is computed by naively summing log densities from the component distributions. Since these component densities have shapes equal to the batch shapes of the individual components, to avoid broadcasting errors it is usually necessary to construct the components with identical batch shapes. For example, the component distributions in the model above have batch shapes of [3] and [3, 2] respectively, which would raise an error if summed directly, but can be aligned by wrapping with tfd.Independent, as in this model:

    jd = tfd.JointDistributionSequential([
        tfd.Normal(0., tf.ones([3])),
        lambda x: tfd.Independent(tfd.Normal(x[..., tf.newaxis], tf.ones([3, 2])),
                                  reinterpreted_batch_ndims=1)
      ],
      batch_ndims=None)
    

Here the components both have batch shape [3], so jd.log_prob(jd.sample()) returns a value of shape [3], just as in the batch_ndims=1 case above. In fact, auto-batching semantics are equivalent to implicitly wrapping each component dist as tfd.Independent(dist, reinterpreted_batch_ndims=(dist.batch_shape.ndims - jd.batch_ndims)); the only vestigial difference is that under auto-batching semantics, the joint distribution has a single batch shape [3], while under the classic semantics the value of jd.batch_shape is a structure of the component batch shapes ([3], [3]). Such structured batch shapes will be deprecated in the future, since they are inconsistent with the definition of batch shapes used elsewhere in TFP.
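
A minimal sketch of that implicit wrapping for a single hypothetical component with batch shape [3, 2], assuming jd.batch_ndims == 1:

    dist = tfd.Normal(0., tf.ones([3, 2]))  # component batch shape [3, 2]
    wrapped = tfd.Independent(
        dist,
        reinterpreted_batch_ndims=dist.batch_shape.ndims - 1)  # 2 - 1 = 1
    # wrapped.batch_shape == [3]; wrapped.event_shape == [2].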

#### References

[1] Dan Piponi, Dave Moore, and Joshua V. Dillon. Joint distributions for TensorFlow Probability. arXiv preprint arXiv:2001.11819, 2020. https://arxiv.org/abs/2001.11819

If every element of model is a CompositeTensor or a callable, the resulting JointDistributionNamed is a CompositeTensor. Otherwise, a non-CompositeTensor _JointDistributionNamed instance is created.

Args

• model: Python dict, collections.OrderedDict, or namedtuple of distribution-making functions, each with required args corresponding only to other keys.
• batch_ndims: int Tensor number of batch dimensions. The batch_shapes of all component distributions must be such that the prefixes of length batch_ndims broadcast to a consistent joint batch shape. Default value: None.
• use_vectorized_map: Python bool. Whether to use tf.vectorized_map to automatically vectorize evaluation of the model. This allows the model specification to focus on drawing a single sample, which is often simpler, but some ops may not be supported. Default value: False.
• validate_args: Python bool. Whether to validate input with asserts. If validate_args is False, and the inputs are invalid, correct behavior is not guaranteed. Default value: False.
• experimental_use_kahan_sum: Python