tf.compat.v1.distributions.Dirichlet

Dirichlet distribution.

Inherits From: Distribution

The Dirichlet distribution is defined over the (k-1)-simplex using a positive, length-k vector concentration (k > 1). When k = 2, the Dirichlet is identical to the Beta distribution.

Mathematical Details

The Dirichlet is a distribution over the open (k-1)-simplex, i.e.,

S^{k-1} = { (x_0, ..., x_{k-1}) in R^k : sum_j x_j = 1 and all_j x_j > 0 }.
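The membership condition can be checked directly. The helper below, `in_open_simplex`, is a hypothetical plain-Python sketch, not part of TensorFlow. The floating-point tolerance on the sum is an assumption of the sketch, because the exact condition sum_j x_j = 1 rarely holds bit-for-bit in floating point.

```python
import math
from typing import Sequence


# Hypothetical helper for illustration; not a TensorFlow API.
def in_open_simplex(x: Sequence[float], tol: float = 1e-9) -> bool:
    """Return True if x lies in the open (k-1)-simplex S^{k-1}.

    Every coordinate must be strictly positive, and the coordinates must sum
    to 1 up to the (assumed) floating-point tolerance `tol`.
    """
    return all(xj > 0 for xj in x) and math.isclose(sum(x), 1.0, abs_tol=tol)


# An interior point of the 2-simplex (k = 3).
print(in_open_simplex([0.2, 0.3, 0.5]))  # True
# A boundary point (one coordinate is 0) is excluded from the *open* simplex.
print(in_open_simplex([0.0, 0.5, 0.5]))  # False
# Coordinates that do not sum to 1 are not on the simplex at all.
print(in_open_simplex([0.2, 0.3, 0.4]))  # False
```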

The probability density function (pdf) is,

pdf(x; alpha) = prod_j x_j**(alpha_j - 1) / Z
Z = prod_j Gamma(alpha_j) / Gamma(sum_j alpha_j)

where:

  • x in S^{k-1}, i.e., the (k-1)-simplex,
  • concentration = alpha = [alpha_0, ..., alpha_{k-1}], alpha_j > 0,
  • Z is the normalization constant, also known as the multivariate beta function, and
  • Gamma is the gamma function.
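The density formula translates almost line for line into code. The sketch below is plain Python (the names `dirichlet_log_pdf`, `dirichlet_pdf`, and `beta_pdf` are hypothetical, not TensorFlow APIs). It works in log space with `math.lgamma` so that Z does not overflow for large alpha. With k = 2 it also checks the claim that the Dirichlet coincides with the Beta distribution.

```python
import math
from typing import Sequence


# Hypothetical helpers for illustration; not TensorFlow APIs.
def dirichlet_log_pdf(x: Sequence[float], alpha: Sequence[float]) -> float:
    """log pdf(x; alpha) = sum_j (alpha_j - 1) * log(x_j) - log Z,
    where log Z = sum_j lgamma(alpha_j) - lgamma(sum_j alpha_j)."""
    if len(x) != len(alpha):
        raise ValueError("x and alpha must both have length k")
    if any(a <= 0 for a in alpha):
        raise ValueError("every alpha_j must be > 0")
    if any(xj <= 0 for xj in x) or not math.isclose(sum(x), 1.0, abs_tol=1e-9):
        raise ValueError("x must lie in the open (k-1)-simplex")
    log_z = sum(math.lgamma(a) for a in alpha) - math.lgamma(sum(alpha))
    return sum((a - 1.0) * math.log(xj) for xj, a in zip(x, alpha)) - log_z


def dirichlet_pdf(x: Sequence[float], alpha: Sequence[float]) -> float:
    return math.exp(dirichlet_log_pdf(x, alpha))


def beta_pdf(t: float, a: float, b: float) -> float:
    """Beta(a, b) density at t in (0, 1), used to check the k = 2 case."""
    log_b = math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)
    return math.exp((a - 1.0) * math.log(t) + (b - 1.0) * math.log1p(-t) - log_b)


alpha = [1.0, 2.0, 3.0]
# Here Z = Gamma(1) Gamma(2) Gamma(3) / Gamma(6) = 2 / 120 = 1/60,
# so the density is 60 * 0.3 * 0.5**2, i.e. approximately 4.5.
print(dirichlet_pdf([0.2, 0.3, 0.5], alpha))
# k = 2: Dirichlet([a, b]) at (t, 1 - t) matches Beta(a, b) at t.
print(math.isclose(dirichlet_pdf([0.3, 0.7], [2.5, 4.0]), beta_pdf(0.3, 2.5, 4.0)))  # True
```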

The concentration represents mean total counts of class occurrence, i.e.,

concentration = alpha = mean * total_concentration

where mean is in S^{k-1} and total_concentration is a positive real number representing a mean total count.
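The decomposition is a sum and a division. In the plain-Python sketch below (hypothetical helper names, not TensorFlow APIs), total_concentration is sum_j alpha_j and mean is alpha / total_concentration. This mean is also the Dirichlet's expected value E[x]. With the concentration [1, 2, 3] from the example further down, the third class has mean 1/2, three times the first class's 1/6. The standard variance formula Var[x_j] = mean_j (1 - mean_j) / (total_concentration + 1) shows why total_concentration behaves like a count: scaling it up, with mean held fixed, tightens the distribution around the mean.

```python
from typing import List, Sequence, Tuple


# Hypothetical helpers for illustration; not TensorFlow APIs.
def split_concentration(alpha: Sequence[float]) -> Tuple[List[float], float]:
    """Write alpha = mean * total_concentration.

    total_concentration = sum_j alpha_j and mean_j = alpha_j / total_concentration,
    so mean lies in S^{k-1}; it is also the expected value E[x_j].
    """
    total = float(sum(alpha))
    return [a / total for a in alpha], total


def join_concentration(mean: Sequence[float], total: float) -> List[float]:
    """Inverse map: alpha_j = mean_j * total_concentration."""
    return [m * total for m in mean]


def dirichlet_variance(alpha: Sequence[float]) -> List[float]:
    """Var[x_j] = mean_j * (1 - mean_j) / (total_concentration + 1)."""
    mean, total = split_concentration(alpha)
    return [m * (1.0 - m) / (total + 1.0) for m in mean]


mean, total = split_concentration([1.0, 2.0, 3.0])
print(total)                        # 6.0
print([round(m, 4) for m in mean])  # [0.1667, 0.3333, 0.5]
# Doubling the total count keeps the mean but concentrates mass around it.
print(dirichlet_variance([2.0, 4.0, 6.0])[2] < dirichlet_variance([1.0, 2.0, 3.0])[2])  # True
```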

Distribution parameters are automatically broadcast in all functions; see examples for details.
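The batch/event broadcasting semantics can be imitated with ordinary NumPy broadcasting. The sketch below is an illustration of the shape rules, not TensorFlow's implementation. It assumes, as in the event_shape=[3] example below, that the last axis holds the k event coordinates and every leading axis is a batch axis. The names `dirichlet_log_prob` and `_lgamma` are hypothetical, and `np.vectorize(math.lgamma)` stands in for a vectorized log-gamma to avoid a SciPy dependency.

```python
import math

import numpy as np

# Elementwise log-gamma (hypothetical helper); avoids depending on SciPy.
_lgamma = np.vectorize(math.lgamma, otypes=[float])


def dirichlet_log_prob(x, concentration):
    """Hypothetical broadcasting log-density sketch (not a TensorFlow API).

    The last axis of both inputs is the event axis (size k); all leading axes
    are batch axes and are broadcast using NumPy's rules.
    """
    x = np.asarray(x, dtype=float)
    alpha = np.asarray(concentration, dtype=float)
    log_norm = _lgamma(alpha).sum(axis=-1) - _lgamma(alpha.sum(axis=-1))
    return ((alpha - 1.0) * np.log(x)).sum(axis=-1) - log_norm


# Two Dirichlets at once: batch_shape=[2], event_shape=[3].
concentration = np.array([[1.0, 2.0, 3.0],
                          [4.0, 5.0, 6.0]])
# One point is broadcast against both batch members.
print(dirichlet_log_prob([0.2, 0.3, 0.5], concentration).shape)  # (2,)
# Four points with a singleton batch axis broadcast to a (4, 2) grid.
points = np.array([[0.2, 0.3, 0.5], [0.1, 0.1, 0.8],
                   [1 / 3, 1 / 3, 1 / 3], [0.6, 0.3, 0.1]])[:, None, :]
print(dirichlet_log_prob(points, concentration).shape)  # (4, 2)
```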

Samples of this distribution are reparameterized (pathwise differentiable). The derivatives are computed using the approach described by Figurnov et al. (2018).
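A standard way to sample a Dirichlet makes the pathwise-differentiability statement easy to see. Draw independent g_j ~ Gamma(alpha_j, 1) and normalize: x = g / sum_j g_j is then Dirichlet(alpha). Each sample is a smooth function of the Gamma draws, so a pathwise gradient for the Gamma sampler carries over to the Dirichlet by the chain rule. The NumPy sketch below shows only the sampling construction. The name `sample_dirichlet` is hypothetical, and it is not claimed to be TensorFlow's code. The gradient technique of Figurnov et al. (2018) for the Gamma draws themselves is not reproduced here.

```python
import numpy as np


# Hypothetical helper for illustration; not a TensorFlow API.
def sample_dirichlet(concentration, n, seed=None):
    """Draw n Dirichlet samples by normalizing independent Gamma(alpha_j, 1) draws."""
    rng = np.random.default_rng(seed)
    alpha = np.asarray(concentration, dtype=float)
    # One Gamma draw per (sample, class); alpha broadcasts across the n rows.
    g = rng.gamma(shape=alpha, scale=1.0, size=(n,) + alpha.shape)
    # Normalizing onto the simplex is differentiable in g, which is what lets
    # gradients with respect to the Gamma draws flow through to x.
    return g / g.sum(axis=-1, keepdims=True)


samples = sample_dirichlet([1.0, 2.0, 3.0], n=100_000, seed=0)
print(samples.shape)         # (100000, 3)
print(samples.mean(axis=0))  # close to the mean [1/6, 1/3, 1/2]
```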

Examples

import tensorflow_probability as tfp
tfd = tfp.distributions

# Create a single trivariate Dirichlet, with the 3rd class being three times
# more frequent than the first. I.e., batch_shape=[], event_shape=[3].
alpha = [1., 2, 3]
dist = tfd.Dirichlet(alpha)