A generic probability distribution base class.
```python
tf.compat.v1.distributions.Distribution(
    dtype,
    reparameterization_type,
    validate_args,
    allow_nan_stats,
    parameters=None,
    graph_parents=None,
    name=None
)
```
`Distribution` is a base class for constructing and organizing properties
(e.g., mean, variance) of random variables (e.g., Bernoulli, Gaussian).
Subclasses are expected to implement a leading-underscore version of the
same-named function. The argument signature should be identical except for
the omission of `name="..."`. For example, to enable
`log_prob(value, name="log_prob")` a subclass should implement
`_log_prob(value)`.
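The public/private split can be sketched in plain Python. `ToyDistribution` and `ToyNormal` below are hypothetical stand-ins, not TensorFlow classes: the public `log_prob` accepts `name` and forwards to `_log_prob`, which takes the same arguments minus `name`.

```python
import math


class ToyDistribution:
    """Hypothetical minimal base class; not TensorFlow's implementation."""

    def log_prob(self, value, name="log_prob"):
        # Public API: accepts `name` (TensorFlow uses it to scope ops; this
        # sketch ignores it) and delegates to the subclass hook.
        return self._log_prob(value)

    def prob(self, value, name="prob"):
        # Public API built on the same hook: exponentiate the log-density.
        return math.exp(self._log_prob(value))

    def _log_prob(self, value):
        # Subclasses override this; the base class has no density.
        raise NotImplementedError(
            f"{type(self).__name__} does not implement _log_prob")


class ToyNormal(ToyDistribution):
    """Hypothetical scalar Gaussian used only to illustrate the pattern."""

    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale

    def _log_prob(self, value):
        # Same signature as `log_prob`, minus the `name` argument.
        z = (value - self.loc) / self.scale
        return -0.5 * z * z - math.log(self.scale) - 0.5 * math.log(2.0 * math.pi)
```

In this sketch `prob` is also routed through `_log_prob`, so a subclass that implements only the log-density gets both public methods; that fallback is a design choice of the sketch.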
Subclasses can append to public-level docstrings by providing docstrings for their method specializations. For example:
```python
@util.AppendDocstring("Some other details.")
def _log_prob(self, value):
    ...
```
would add the string "Some other details." to the `log_prob` function
docstring. This is implemented as a simple decorator to avoid the Python
linter complaining about missing Args/Returns/Raises sections in the
partial docstrings.
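A rough sketch of how a base class could merge a subclass's private docstring into the public one. `append_docstring` is a hypothetical stand-in for `util.AppendDocstring` (whose real signature and internals may differ), and doing the merge in `__init_subclass__` is an assumption of this sketch, not TensorFlow's mechanism.

```python
import functools
import math


def append_docstring(note):
    """Hypothetical stand-in for `util.AppendDocstring`.

    Stores `note` as the private method's docstring so the base class can
    append it to the matching public method.
    """
    def decorator(fn):
        fn.__doc__ = note
        return fn
    return decorator


class ToyDistribution:
    """Hypothetical base class; not TensorFlow's implementation."""

    # Public methods whose docstrings subclasses may extend.
    _DOCUMENTED_METHODS = ("log_prob",)

    def log_prob(self, value, name="log_prob"):
        """Log probability density/mass function."""
        return self._log_prob(value)

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        for public in cls._DOCUMENTED_METHODS:
            private = cls.__dict__.get("_" + public)
            if private is None or not private.__doc__:
                continue
            base_fn = getattr(ToyDistribution, public)

            # Give the subclass its own copy of the public method so that
            # editing its docstring leaves the base class untouched.
            @functools.wraps(base_fn)
            def wrapper(self, *args, _base_fn=base_fn, **kw):
                return _base_fn(self, *args, **kw)

            wrapper.__doc__ = base_fn.__doc__ + "\n\n" + private.__doc__
            setattr(cls, public, wrapper)


class ToyExponential(ToyDistribution):
    def __init__(self, rate):
        self.rate = rate

    @append_docstring("Some other details.")
    def _log_prob(self, value):
        return math.log(self.rate) - self.rate * value
```

After the class definition, `ToyExponential.log_prob.__doc__` ends with "Some other details.", while `ToyDistribution.log_prob.__doc__` is unchanged.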
Broadcasting, batching, and shapes
All distributions support batches of independent distributions of that type. The batch shape is determined by broadcasting together the parameters.
The shapes of the arguments to `__init__`, `cdf`, `log_cdf`, `prob`, and
`log_prob` reflect this broadcasting, as does the return value of `sample`.

`sample_n_shape = [n] + batch_shape + event_shape`, where `sample_n_shape` is
the shape of the `Tensor` returned by drawing `n` samples (e.g. `sample(n)`),
`n` is the number of samples, `batch_shape` defines how many independent
distributions there are, and `event_shape` defines the shape of samples from
each of those independent distributions. Samples are independent along the
`batch_shape` dimensions, but not necessarily so along the `event_shape`
dimensions (depending on the particulars of the underlying distribution).
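The shape arithmetic above can be checked with a small, dependency-free sketch. `broadcast_shapes` and `sample_n_shape` are hypothetical helpers; `broadcast_shapes` applies NumPy-style right-aligned broadcasting to the parameter shapes to obtain `batch_shape`.

```python
def broadcast_shapes(*shapes):
    """Broadcast shapes right-aligned: each axis must match or be 1."""
    ndim = max((len(s) for s in shapes), default=0)
    result = [1] * ndim
    for shape in shapes:
        # Left-pad shorter shapes with 1s so all shapes are right-aligned.
        padded = [1] * (ndim - len(shape)) + list(shape)
        for axis, dim in enumerate(padded):
            if dim == 1 or dim == result[axis]:
                continue
            if result[axis] == 1:
                result[axis] = dim
            else:
                raise ValueError(f"shapes {shapes} are not broadcastable")
    return result


def sample_n_shape(n, batch_shape, event_shape):
    """[n] + batch_shape + event_shape, as in the formula above."""
    return [n] + list(batch_shape) + list(event_shape)
```

For the Uniform example, a scalar `minval` and a `[2, 2]` `maxval` give `batch_shape = [2, 2]`, and `sample_n_shape(5, [2, 2], [])` gives `[5, 2, 2]`.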
Using the `Uniform` distribution as an example:
```python
import tensorflow.compat.v1 as tf

minval = 3.0
maxval = [[4.0, 6.0],
          [10.0, 12.0]]

# Broadcasting:
# This instance represents 4 Uniform distributions. Each has a lower bound at
# 3.0 as the `minval` parameter was broadcasted to match `maxval`'s shape.
u = tf.distributions.Uniform(minval, maxval)

# `event_shape` is `TensorShape([])`.
event_shape = u.event_shape
# `event_shape_t` is a `Tensor` which will evaluate to [].
event_shape_t = u.event_shape_tensor()

# Sampling returns a sample per distribution. `samples` has shape
# [5, 2, 2], which is [n] + batch_shape + event_shape, where n=5,
# batch_shape=[2, 2], and event_shape=[].
samples = u.sample(5)

# The broadcasting holds across methods. Here we use `cdf` as an example. The
# same holds for `log_cdf` and the likelihood functions.

# `cum_prob_broadcast` has shape [2, 2] as the `value` argument was
# broadcasted to the shape of the `Uniform` instance.
cum_prob_broadcast = u.cdf(4.0)

# `cum_prob_per_dist`'s shape is [2, 2], one per distribution. No
# broadcasting occurred.
cum_prob_per_dist = u.cdf([[4.0, 5.0], [6.0, 7.0]])

# INVALID as the `value` argument is not broadcastable to the distribution's
# shape; uncommenting the next line raises an error.
# cum_prob_invalid = u.cdf([4.0, 5.0, 6.0])
```
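To make the `cdf` broadcasting in the Uniform example concrete, here is a plain-Python sketch that evaluates the same `[2, 2]` batch of Uniform CDFs. `batch_uniform_cdf` is hypothetical and accepts only a scalar or an exact `[rows, cols]` match, a subset of full broadcasting (for instance, it rejects a length-2 row that TensorFlow would broadcast).

```python
def uniform_cdf(value, low, high):
    """CDF of Uniform(low, high) at `value`, clipped to [0, 1]."""
    return min(max((value - low) / (high - low), 0.0), 1.0)


def batch_uniform_cdf(value, low, high):
    """Evaluate a [rows, cols] batch of Uniform CDFs.

    `high` is a nested [rows, cols] list; `low` and `value` may each be a
    scalar (broadcast to every distribution) or a matching [rows, cols] list.
    Anything else is rejected, which is stricter than full broadcasting.
    """
    rows, cols = len(high), len(high[0])

    def expand(x):
        if isinstance(x, (int, float)):
            return [[float(x)] * cols for _ in range(rows)]
        if len(x) != rows or any(
                not isinstance(r, (list, tuple)) or len(r) != cols for r in x):
            raise ValueError("argument is not broadcastable to the batch shape")
        return [[float(v) for v in r] for r in x]

    v, lo = expand(value), expand(low)
    return [[uniform_cdf(v[i][j], lo[i][j], high[i][j]) for j in range(cols)]
            for i in range(rows)]
```

With `minval = 3.0`, `batch_uniform_cdf(4.0, 3.0, maxval)` yields `[[1.0, 1/3], [1/7, 1/9]]`, since each entry is `(4.0 - 3.0) / (maxval - 3.0)` clipped to [0, 1]; passing `[4.0, 5.0, 6.0]` raises `ValueError`, matching the INVALID case.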
There are three important concepts associated with TensorFlow Distributions shapes:
- Event shape describes the shape of a single draw from the distribution;
it may be dependent across dimensions. For scalar distributions, the event
shape is `[]`. For a 5-dimensional MultivariateNormal, the event shape is `[5]`.
- Batch shape describes independent, not identically distributed draws, aka a "collection" or "bunch" of distributions.
- Sample shape describes independent, identically distributed draws of batches from the distribution family.
The event shape and the batch shape are properties of a Distribution object,
whereas the sample shape is associated with a specific call to `sample` or
`log_prob`.
For detailed usage examples of TensorFlow Distributions shapes, see this tutorial.
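The split between object-level and call-level shapes can be summarized in a short sketch. The hypothetical helpers compute the shape returned by `sample(sample_shape)` (sample + batch + event) and by `log_prob` on such a value (sample + batch, since the event dimensions are reduced to one log-density per draw). Modeling the value as exactly sample + batch + event, with no further broadcasting, is a simplifying assumption.

```python
def sample_shape_result(sample_shape, batch_shape, event_shape):
    """Shape of `sample(sample_shape)`: sample + batch + event."""
    return list(sample_shape) + list(batch_shape) + list(event_shape)


def log_prob_result(value_shape, batch_shape, event_shape):
    """Shape of `log_prob(value)` for a value shaped sample + batch + event.

    The event dimensions are reduced away, leaving one log-density per
    (sample, batch) index.
    """
    value_shape = list(value_shape)
    cut = len(value_shape) - len(event_shape)
    batch_start = cut - len(batch_shape)
    if (batch_start < 0
            or value_shape[cut:] != list(event_shape)
            or value_shape[batch_start:cut] != list(batch_shape)):
        raise ValueError("expected value shape sample + batch + event")
    return value_shape[:cut]
```

For an illustrative batch of 3 five-dimensional MultivariateNormals and `sample_shape = [4]`, `sample_shape_result([4], [3], [5])` is `[4, 3, 5]` and `log_prob_result([4, 3, 5], [3], [5])` is `[4, 3]`.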
Parameter values leading to undefined statistics or distributions.
Some distributions do not have well-defined statistics for all initialization
parameter values. For example, the beta distribution is parameterized by
positive real numbers `concentration1` and `concentration0`, and does not have
a well-defined mode if `concentration1 < 1` or `concentration0 < 1`.
The user is given the option of raising an exception or returning `NaN`.
```python
import tensorflow.compat.v1 as tf

# `logits`, `weights_a`, and `weights_b` are assumed to be existing tensors;
# `.eval()` assumes graph mode with a default session.
a = tf.exp(tf.matmul(logits, weights_a))
b = tf.exp(tf.matmul(logits, weights_b))

# Will raise an exception if ANY batch member has a < 1 or b < 1.
dist = tf.distributions.Beta(a, b, allow_nan_stats=False)
mode = dist.mode().eval()

# Will return NaN for batch members with either a < 1 or b < 1.
dist = tf.distributions.Beta(a, b, allow_nan_stats=True)  # Default behavior
mode = dist.mode().eval()
```
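A pure-Python sketch of the two `allow_nan_stats` behaviors for the Beta mode, applied to a small batch held in lists. `beta_mode` is hypothetical; it uses the standard mode `(a - 1) / (a + b - 2)` and, as an assumption, treats the mode as defined only when both concentrations exceed 1, which also avoids the division by zero at `a = b = 1`.

```python
import math


def beta_mode(a, b, allow_nan_stats=True):
    """Mode of a batch of Beta(a, b) distributions given as parallel lists.

    Treats the mode as defined only when both concentrations exceed 1.
    """
    undefined = [not (ai > 1.0 and bi > 1.0) for ai, bi in zip(a, b)]
    if any(undefined) and not allow_nan_stats:
        # allow_nan_stats=False: fail if ANY batch member is undefined.
        raise ValueError("mode is undefined for some batch members")
    # allow_nan_stats=True: NaN only for the undefined batch members.
    return [math.nan if bad else (ai - 1.0) / (ai + bi - 2.0)
            for ai, bi, bad in zip(a, b, undefined)]
```

For `a = [3.0, 0.5]` and `b = [2.0, 2.0]`, the default call returns `2/3` for the first member and `NaN` for the second, while `allow_nan_stats=False` raises for the whole batch.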
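Independently of `allow_nan_stats`, an exception is raised if invalid
parameters are passed and the distribution was constructed with
`validate_args=True`, e.g.:

```python
# Continues the example above (`a` and `b` as defined there).
# Will raise an exception if any Op is run.
negative_a = -1.0 * a  # beta distribution by definition has a > 0.
dist = tf.distributions.Beta(negative_a, b, validate_args=True,
                             allow_nan_stats=True)
dist.mean().eval()
```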
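In TensorFlow, the positivity checks on a distribution's parameters are opt-in through the `validate_args` constructor argument from the signature above. The hypothetical `ToyBeta` below mimics that choice in plain Python; unlike TensorFlow's graph mode, where the failure surfaces when an op is run, the sketch fails at construction.

```python
class ToyBeta:
    """Hypothetical Beta container illustrating opt-in parameter validation."""

    def __init__(self, concentration1, concentration0, validate_args=False):
        # Only check parameters when the caller asks for validation.
        if validate_args and not (concentration1 > 0 and concentration0 > 0):
            raise ValueError("concentrations must be positive")
        self.concentration1 = concentration1
        self.concentration0 = concentration0

    def mean(self):
        # Beta mean: concentration1 / (concentration1 + concentration0).
        return self.concentration1 / (self.concentration1 + self.concentration0)
```

Without `validate_args=True`, `ToyBeta(-1.0, 2.0).mean()` quietly returns `-1.0`, showing how an invalid parameter can propagate silently when validation is off.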
`dtype`: The type of the event samples.