


A Transformed Distribution.

Inherits From: Distribution


A TransformedDistribution models p(y) given a base distribution p(x), and a deterministic, invertible, differentiable transform, Y = g(X). The transform is typically an instance of the Bijector class and the base distribution is typically an instance of the Distribution class.

A Bijector is expected to implement the following functions:

  • forward,
  • inverse,
  • inverse_log_det_jacobian.

    The semantics of these functions are outlined in the Bijector documentation.

    We now describe how a TransformedDistribution alters the input/outputs of a Distribution associated with a random variable (rv) X.

    Write cdf(Y=y) for an absolutely continuous cumulative distribution function of a random variable Y; write the probability density function pdf(Y=y) := d^k / (dy_1,...,dy_k) cdf(Y=y) for its derivative with respect to y, evaluated at y. Assume that Y = g(X) where g is a deterministic diffeomorphism, i.e., a non-random, continuous, differentiable, and invertible function. Write the inverse of g as X = g^{-1}(Y) and (J o g)(x) for the Jacobian of g evaluated at x.

    A TransformedDistribution implements the following operations:

    • sample
      Mathematically: Y = g(X)
      Programmatically: bijector.forward(distribution.sample(...))

    • log_prob
      Mathematically: (log o pdf)(Y=y) = (log o pdf o g^{-1})(y) + (log o abs o det o J o g^{-1})(y)
      Programmatically: distribution.log_prob(bijector.inverse(y)) + bijector.inverse_log_det_jacobian(y)

    • log_cdf
      Mathematically: (log o cdf)(Y=y) = (log o cdf o g^{-1})(y)
      Programmatically: distribution.log_cdf(bijector.inverse(y))

    • and similarly for: cdf, prob, log_survival_function, survival_function.
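The log_prob rule above can be checked numerically without TFP. A minimal plain-Python sketch for the log-normal case Y = exp(X) with X ~ Normal(0, 1); the helper names here are illustrative, not TFP APIs:

```python
import math

def normal_log_prob(x, loc=0.0, scale=1.0):
    # log pdf of Normal(loc, scale)
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)

def transformed_log_prob(y):
    # (log o pdf o g^{-1})(y) + (log o abs o det o J o g^{-1})(y)
    # with g = exp: g^{-1}(y) = log(y), and d/dy log(y) = 1/y,
    # so the inverse log det Jacobian is -log(y).
    return normal_log_prob(math.log(y)) - math.log(y)

def log_normal_log_prob(y):
    # Closed-form standard log-normal log density, for comparison.
    return (-0.5 * math.log(y) ** 2 - math.log(y)
            - 0.5 * math.log(2.0 * math.pi))
```

For any y > 0, `transformed_log_prob(y)` agrees with the closed-form log-normal density, which is exactly what `TransformedDistribution.log_prob` computes via the bijector.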

    Kullback-Leibler divergence is also well defined for TransformedDistribution instances that have matching bijectors. Bijector matching is performed via the Bijector.__eq__ method, e.g., td1.bijector == td2.bijector, as part of the KL calculation. If the underlying bijectors do not match, a NotImplementedError is raised when calling kl_divergence. This is the same behavior as calling kl_divergence when two distributions do not have a registered KL divergence.
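The dispatch described above can be illustrated with toy stand-ins. These classes are NOT TFP's; they only mimic the behavior this paragraph describes, using the closed-form KL between two Normals:

```python
import math

class Shift:
    """Toy bijector, identified (and compared) by its shift value."""
    def __init__(self, shift):
        self.shift = shift
    def __eq__(self, other):
        return isinstance(other, Shift) and self.shift == other.shift

class Normal:
    """Toy base distribution."""
    def __init__(self, loc, scale):
        self.loc, self.scale = loc, scale

class ToyTransformed:
    """Toy transformed distribution: a base distribution plus a bijector."""
    def __init__(self, distribution, bijector):
        self.distribution, self.bijector = distribution, bijector

def normal_kl(p, q):
    # Closed-form KL(N(p.loc, p.scale) || N(q.loc, q.scale)).
    return (math.log(q.scale / p.scale)
            + (p.scale ** 2 + (p.loc - q.loc) ** 2) / (2.0 * q.scale ** 2)
            - 0.5)

def kl_divergence(td1, td2):
    # KL is only defined when the bijectors match: applying the same
    # invertible transform to both arguments leaves KL unchanged.
    if not td1.bijector == td2.bijector:
        raise NotImplementedError('bijectors do not match')
    return normal_kl(td1.distribution, td2.distribution)
```

With matching `Shift(1.)` bijectors the KL reduces to the KL between the two base Normals; with mismatched bijectors the call raises NotImplementedError, as described above.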

    A simple example constructing a Log-Normal distribution from a Normal distribution:

    tfd = tfp.distributions
    tfb = tfp.bijectors
    log_normal = tfd.TransformedDistribution(
      distribution=tfd.Normal(loc=0., scale=1.),
      bijector=tfb.Exp())

    A LogNormal made from callables:

    tfd = tfp.distributions
    tfb = tfp.bijectors
    log_normal = tfd.TransformedDistribution(
      distribution=tfd.Normal(loc=0., scale=1.),
      bijector=tfb.Inline(
        forward_fn=tf.exp,
        inverse_fn=tf.math.log,
        inverse_log_det_jacobian_fn=(
          lambda y: -tf.reduce_sum(tf.math.log(y), axis=-1)),
        forward_min_event_ndims=0))

    Another example constructing a Normal from a StandardNormal:

    tfd = tfp.distributions
    tfb = tfp.bijectors
    normal = tfd.TransformedDistribution(
      distribution=tfd.Normal(loc=0., scale=1.),
      bijector=tfb.Shift(shift=-1.)(tfb.Scale(scale=2.)))

    A TransformedDistribution's batch_shape is derived by broadcasting the batch shapes of the base distribution and the bijector. The base distribution is then itself implicitly lifted to the broadcast batch shape. For example, in

    tfd = tfp.distributions
    tfb = tfp.bijectors
    batch_normal = tfd.TransformedDistribution(
      distribution=tfd.Normal(loc=0., scale=1.),
      bijector=tfb.Shift(shift=[-1., 0., 1.]))

    the base distribution has batch shape [], and the bijector applied to this distribution contributes a batch shape of [3] (obtained as bijector.experimental_batch_shape(x_event_ndims=tf.rank(distribution.event_shape))), yielding the broadcast shape batch_normal.batch_shape == [3]. Although sampling from the base distribution would ordinarily return just a single value, calling batch_normal.sample() will return a Tensor of 3 independent values, just as if the base distribution had explicitly followed the broadcast batch shape.
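The broadcasting of the two batch shapes follows NumPy-style rules. A hypothetical helper (not TFP's implementation) sketches the rule:

```python
def broadcast_batch_shapes(a, b):
    # NumPy-style broadcast of two shapes, right-aligned:
    # pad the shorter shape with leading 1s, then combine dimensionwise,
    # where a size-1 dimension broadcasts against any size.
    padded_a = [1] * (len(b) - len(a)) + list(a)
    padded_b = [1] * (len(a) - len(b)) + list(b)
    out = []
    for da, db in zip(padded_a, padded_b):
        if da != 1 and db != 1 and da != db:
            raise ValueError(f'incompatible shapes: {a} vs {b}')
        out.append(max(da, db))
    return out
```

In the example above, the base Normal contributes batch shape [] and the Shift bijector contributes [3], so `broadcast_batch_shapes([], [3])` yields [3], matching `batch_normal.batch_shape`.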

    The event_shape of a TransformedDistribution is the forward_event_shape of the bijector applied to the event_shape of the base distribution.

    tfd.Sample or tfd.Independent may be used to add extra IID dimensions to the event_shape of the base distribution before the bijector operates on it. The following example demonstrates how to construct a multivariate Normal as a TransformedDistribution, by adding a rank-1 IID dimension to the event_shape of a standard Normal and applying tfb.ScaleMatvecTriL.

    tfd = tfp.distributions
    tfb = tfp.bijectors
    # We will create two MVNs with batch_shape = event_shape = 2.
    mean = [[-1., 0],      # batch:0
            [0., 1]]       # batch:1
    chol_cov = [[[1., 0],
                 [0, 1]],  # batch:0
                [[1, 0],
                 [2, 2]]]  # batch:1
    mvn1 = tfd.TransformedDistribution(
        distribution=tfd.Sample(
            tfd.Normal(loc=[0., 0], scale=1.),  # base_dist.batch_shape == [2]
            sample_shape=[2]),                  # base_dist.event_shape == [2]
        bijector=tfb.Chain([tfb.Shift(shift=mean),
                            tfb.ScaleMatvecTriL(scale_tril=chol_cov)]))
    mvn2 = tfd.MultivariateNormalTriL(loc=mean, scale_tril=chol_cov)
    # mvn1.log_prob(x) == mvn2.log_prob(x)
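The equality claimed in the final comment can be spot-checked for one batch member without TFP. A plain-Python sketch for batch index 1 (the evaluation point y and all helper names are illustrative):

```python
import math

# batch:1 parameters from the example above
mean = [0., 1.]
scale_tril = [[1., 0.], [2., 2.]]
y = [0.5, -0.3]

# Transformed-distribution route:
# x = g^{-1}(y) = L^{-1} @ (y - mean), via forward substitution.
d = [y[0] - mean[0], y[1] - mean[1]]
x0 = d[0] / scale_tril[0][0]
x1 = (d[1] - scale_tril[1][0] * x0) / scale_tril[1][1]
# Inverse log det Jacobian of the affine map is -log|det L|.
ildj = -(math.log(scale_tril[0][0]) + math.log(scale_tril[1][1]))

def std_normal_log_prob(x):
    return -0.5 * x * x - 0.5 * math.log(2.0 * math.pi)

transformed_lp = std_normal_log_prob(x0) + std_normal_log_prob(x1) + ildj

# Closed-form MVN route, computed independently: Sigma = L @ L.T,
# log p(y) = -0.5 * d.T @ inv(Sigma) @ d - 0.5 * log det Sigma - log(2*pi).
s00 = scale_tril[0][0] ** 2
s01 = scale_tril[0][0] * scale_tril[1][0]
s11 = scale_tril[1][0] ** 2 + scale_tril[1][1] ** 2
det_sigma = s00 * s11 - s01 * s01
quad = (s11 * d[0] * d[0] - 2 * s01 * d[0] * d[1] + s00 * d[1] * d[1]) / det_sigma
mvn_lp = -0.5 * quad - 0.5 * math.log(det_sigma) - math.log(2.0 * math.pi)
```

The two routes agree, mirroring `mvn1.log_prob(x) == mvn2.log_prob(x)` for this batch member.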

If both distribution and bijector are CompositeTensors, then the resulting TransformedDistribution instance is a CompositeTensor as well. Otherwise, a non-CompositeTensor _TransformedDistribution instance is created instead. Distribution subclasses that inherit from TransformedDistribution will also inherit from CompositeTensor.

distribution The base distribution instance to transform. Typically an instance of Distribution.
bijector The object responsible for calculating the transformation. Typically an instance of Bijector.
kwargs_split_fn Python callable which takes a kwargs dict and returns a tuple of kwargs dicts for each of the distribution and bijector parameters respectively. Default value: _default_kwargs_split_fn (i.e., lambda kwargs: (kwargs.get('distribution_kwargs', {}), kwargs.get('bijector_kwargs', {})))
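The documented default split function is equivalent to this small helper, which routes a kwargs dict to the distribution and bijector respectively, falling back to empty dicts when a key is absent:

```python
def default_kwargs_split_fn(kwargs):
    # Mirrors the documented default _default_kwargs_split_fn:
    # (kwargs.get('distribution_kwargs', {}), kwargs.get('bijector_kwargs', {}))
    return (kwargs.get('distribution_kwargs', {}),
            kwargs.get('bijector_kwargs', {}))
```

For example, `default_kwargs_split_fn({'distribution_kwargs': {'loc': 0.}})` returns `({'loc': 0.}, {})`.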
validate_args Python bool, default False. When True distribution parameters are checked for validity despite possibly degrading runtime performance. When False invalid inputs may silently render incorrect outputs.
parameters Locals dict captured by subclass constructor, to be used for copy/slice re-instantiation operations.
name Python str name prefixed to Ops created by this class. Default: bijector.name + distribution.name.

allow_nan_stats Python bool describing behavior when a stat is undefined.

Stats return +/- infinity when it makes sense. E.g., the variance of a Cauchy distribution is infinity. However, sometimes the statistic is undefined, e.g., if a distribution's pdf does not achieve a maximum within the support of the distribution, the mode is undefined. If the mean is undefined, then by definition the variance is undefined. E.g., the mean of Student's t with df = 1 is undefined (there is no clear way to say it is either + or - infinity), so the variance = E[(X - mean)**2] is also undefined.

batch_shape Shape of a single sample from a single event index as a TensorShape.

May be partially defined or unknown.

The batch dimensions are indexes into independent, non-identical parameterizations of this distribution.

bijector Function transforming x => y.
distribution Base distribution, p(x).
dtype The DType of Tensors handled by this Distribution.
event_shape Shape of a single sample from a single batch as a TensorShape.

May be partially defined or unknown.


experimental_shard_axis_names The list or structure of lists of active shard axis names.
name Name prepended to all ops created by this Distribution.
name_scope Returns a tf.name_scope instance for this class.
non_trainable_variables Sequence of non-trainable variables owned by this module and its submodules.

parameters Dictionary of parameters used to instantiate this Distribution.
reparameterization_type Describes how samples from the distribution are reparameterized.

Currently this is one of the static instances tfd.FULLY_REPARAMETERIZED or tfd.NOT_REPARAMETERIZED.

submodules Sequence of all sub-modules.

Submodules are modules which are properties of this module, or found as properties of modules which are properties of this module (and so on).

a = tf.Module()
b = tf.Module()
c = tf.Module()
a.b = b
b.c = c
list(a.submodules) == [b, c]
list(b.submodules) == [c]
list(c.submodules) == []

trainable_variables Sequence of trainable variables owned by this module and its submodules.

validate_args Python bool indicating possibly expensive checks are enabled.
variables Sequence of variables owned by this module and its submodules.