Module: tfp.layers

Probabilistic Layers.


class CategoricalMixtureOfOneHotCategorical: A Keras layer for a mixture of k d-variate OneHotCategorical distributions, built from k * (1 + d) params.
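The k * (1 + d) count splits into k mixing logits (one per component) plus k * d component logits (d per component). Below is a minimal plain-Python sketch of that bookkeeping. It assumes the mixing logits come first, followed by one row of d logits per component. That layout and the helper names are hypothetical. The real layer exposes its own static params_size method.

```python
def params_size(event_size, num_components):
    """Inputs needed by a k-component mixture of d-way one-hot categoricals."""
    # k mixing logits + k * d component logits = k * (1 + d)
    return num_components * (1 + event_size)


def split_params(params, event_size, num_components):
    """Split a flat parameter list into mixing logits and per-component logits.

    The layout (mixing logits first, then one row of d logits per component)
    is an assumption made for illustration.
    """
    expected = params_size(event_size, num_components)
    if len(params) != expected:
        raise ValueError(f"expected {expected} params, got {len(params)}")
    mixing_logits = params[:num_components]
    rest = params[num_components:]
    component_logits = [
        rest[i * event_size:(i + 1) * event_size] for i in range(num_components)
    ]
    return mixing_logits, component_logits


# k = 2 components over d = 3 categories needs 2 * (1 + 3) = 8 params.
mix, comps = split_params(list(range(8)), event_size=3, num_components=2)
print(mix)    # [0, 1]
print(comps)  # [[2, 3, 4], [5, 6, 7]]
```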

class Convolution1DFlipout: 1D convolution layer (e.g. temporal convolution) with Flipout.

class Convolution1DReparameterization: 1D convolution layer (e.g. temporal convolution) with reparameterization estimator.

class Convolution2DFlipout: 2D convolution layer (e.g. spatial convolution over images) with Flipout.

class Convolution2DReparameterization: 2D convolution layer (e.g. spatial convolution over images) with reparameterization estimator.

class Convolution3DFlipout: 3D convolution layer (e.g. spatial convolution over volumes) with Flipout.

class Convolution3DReparameterization: 3D convolution layer (e.g. spatial convolution over volumes) with reparameterization estimator.
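The reparameterization convolution layers differ only in how many spatial dimensions they slide over. In each, the kernel is random, and a reparameterized draw has the form kernel = loc + scale * eps with eps ~ N(0, 1). Below is a plain-Python sketch of the 1D case. It assumes a single channel, stride 1, "valid" padding and no bias. The helper names and parameter values are hypothetical, and this is not TFP's implementation.

```python
import random


def sample_kernel(loc, scale, rng):
    """Reparameterized draw: kernel = loc + scale * eps with eps ~ N(0, 1)."""
    return [m + s * rng.gauss(0.0, 1.0) for m, s in zip(loc, scale)]


def conv1d_valid(signal, kernel):
    """Single-channel 1D cross-correlation, stride 1, 'valid' padding
    (the operation Keras-style convolution layers compute)."""
    k = len(kernel)
    return [
        sum(signal[i + j] * kernel[j] for j in range(k))
        for i in range(len(signal) - k + 1)
    ]


rng = random.Random(0)
loc = [1.0, 0.0, -1.0]    # kernel posterior mean (hypothetical values)
scale = [0.1, 0.1, 0.1]   # kernel posterior standard deviation (hypothetical)
signal = [0.0, 1.0, 2.0, 3.0, 4.0]

kernel = sample_kernel(loc, scale, rng)  # a fresh kernel on every forward pass
out = conv1d_valid(signal, kernel)
print(len(out))  # 3 outputs for a length-5 signal and a width-3 kernel
```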

class DenseFlipout: Densely connected layer with the Flipout estimator.
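Flipout draws one weight perturbation ΔW per batch and decorrelates it across examples by applying random sign vectors. For an input x with input signs s and output signs r, the output is x M + ((x * s) ΔW) * r. This equals x (M + ΔW * s r^T), so each example effectively sees its own perturbed weights. Below is a minimal plain-Python sketch for a bias-free dense layer with factorized Gaussian weights. The function names and the omitted bias are simplifications for illustration, not TFP's implementation.

```python
import random


def matvec(x, w):
    """Row vector x (length n_in) times matrix w (n_in x n_out)."""
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]


def flipout_dense(batch, mean_w, std_w, rng):
    """Flipout forward pass for a dense layer without bias or activation.

    One perturbation delta ~ N(0, std_w^2) is drawn for the whole batch; each
    example multiplies it by its own random sign vectors s (inputs) and
    r (outputs), which decorrelates the perturbations across examples.
    """
    n_in, n_out = len(mean_w), len(mean_w[0])
    delta = [[std_w[i][j] * rng.gauss(0.0, 1.0) for j in range(n_out)]
             for i in range(n_in)]
    outputs = []
    for x in batch:
        s = [rng.choice((-1.0, 1.0)) for _ in range(n_in)]
        r = [rng.choice((-1.0, 1.0)) for _ in range(n_out)]
        mean_part = matvec(x, mean_w)
        pert_part = matvec([xi * si for xi, si in zip(x, s)], delta)
        outputs.append([m + p * rj for m, p, rj in zip(mean_part, pert_part, r)])
    return outputs


rng = random.Random(0)
mean_w = [[1.0, 0.0], [0.0, 1.0]]
std_w = [[0.1, 0.1], [0.1, 0.1]]
print(flipout_dense([[1.0, 2.0], [3.0, 4.0]], mean_w, std_w, rng))
```

With std_w set to zero the perturbation vanishes and the layer reduces to an ordinary deterministic dense layer.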

class DenseLocalReparameterization: Densely connected layer with the local reparameterization estimator.
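With factorized Gaussian weights W[i][j] ~ N(mu[i][j], sigma[i][j]^2), each pre-activation is itself Gaussian: its mean is sum_i x[i] * mu[i][j] and its variance is sum_i x[i]^2 * sigma[i][j]^2. The local reparameterization trick samples that pre-activation directly, independently per example, instead of sampling W. Below is a plain-Python sketch for a bias-free layer. The names are hypothetical and this is not TFP's code.

```python
import math
import random


def local_reparam_dense(batch, mean_w, std_w, rng):
    """Dense layer (no bias) with factorized Gaussian weights, sampled via the
    local reparameterization trick: each pre-activation is drawn from the
    Gaussian it induces, independently for every example."""
    n_in, n_out = len(mean_w), len(mean_w[0])
    outputs = []
    for x in batch:
        row = []
        for j in range(n_out):
            mu = sum(x[i] * mean_w[i][j] for i in range(n_in))
            var = sum((x[i] ** 2) * (std_w[i][j] ** 2) for i in range(n_in))
            row.append(mu + math.sqrt(var) * rng.gauss(0.0, 1.0))
        outputs.append(row)
    return outputs


rng = random.Random(0)
print(local_reparam_dense([[1.0, 2.0]], [[0.5], [0.5]], [[0.1], [0.1]], rng))
```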

class DenseReparameterization: Densely connected layer with the reparameterization estimator.
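The plain reparameterization estimator draws the weights as W = mean + std * eps with eps ~ N(0, 1), so gradients flow through mean and std. One such draw is shared by every example in the batch. That sharing is what distinguishes it from the Flipout and local-reparameterization variants. Below is a plain-Python sketch for a bias-free layer. The names are hypothetical and this is not TFP's code.

```python
import random


def sample_weights(mean_w, std_w, rng):
    """Reparameterized draw W = mean + std * eps, eps ~ N(0, 1) elementwise."""
    return [[m + s * rng.gauss(0.0, 1.0) for m, s in zip(mrow, srow)]
            for mrow, srow in zip(mean_w, std_w)]


def reparam_dense(batch, mean_w, std_w, rng):
    """Dense layer (no bias): one weight sample is shared by the whole batch."""
    w = sample_weights(mean_w, std_w, rng)
    return [[sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]
            for x in batch]


rng = random.Random(0)
out = reparam_dense([[1.0, 2.0], [1.0, 2.0]], [[1.0], [1.0]], [[0.1], [0.1]], rng)
print(out)  # both rows are identical because the weight sample is shared
```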

class DistributionLambda: Keras layer for passing TFP distributions through Keras models.

class IndependentBernoulli: An independent Bernoulli Keras layer from prod(event_shape) params.
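With one logit per binary variable, the layer needs prod(event_shape) inputs. Treating the whole event_shape as a single event means the joint log-probability is the sum of the per-element Bernoulli log-probabilities. Below is a plain-Python sketch that works on flattened logits and outcomes. The helper names are hypothetical and it does not use TFP.

```python
import math


def params_size(event_shape):
    """One logit per Bernoulli variable: prod(event_shape) parameters."""
    return math.prod(event_shape)


def softplus(z):
    """Numerically stable log(1 + exp(z))."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def independent_bernoulli_log_prob(logits, x):
    """Joint log-probability of flattened binary outcomes x under independent
    Bernoullis with the given logits: the per-element terms are summed."""
    # log p(1) = -softplus(-logit), log p(0) = -softplus(logit)
    return sum(-softplus(-l) if xi else -softplus(l) for l, xi in zip(logits, x))


# event_shape (2, 3) needs 6 logits; zero logits give p = 0.5 everywhere.
lp = independent_bernoulli_log_prob([0.0] * 6, [1, 0, 1, 1, 0, 0])
print(params_size((2, 3)), lp)
```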

class IndependentLogistic: An independent logistic Keras layer.

class IndependentNormal: An independent normal Keras layer.

class IndependentPoisson: An independent Poisson Keras layer.
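The independent logistic, normal and Poisson layers follow the same pattern as IndependentBernoulli. Each models an elementwise distribution whose log-probability is summed over event_shape. Below is a plain-Python sketch for the Poisson case. It assumes each parameter is interpreted as a log-rate, which is an assumption of this sketch and not a statement about the layer's exact parameterization.

```python
import math


def independent_poisson_log_prob(log_rates, counts):
    """Sum of Poisson log-pmfs, one per event element, with each parameter
    interpreted as a log-rate (an assumption of this sketch)."""
    # log Poisson(k; rate) = k * log(rate) - rate - log(k!)
    return sum(
        k * lr - math.exp(lr) - math.lgamma(k + 1)
        for lr, k in zip(log_rates, counts)
    )


print(independent_poisson_log_prob([0.0, math.log(2.0)], [0, 3]))
```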

class KLDivergenceAddLoss: Pass-through layer that adds a KL divergence penalty to the model loss.

class KLDivergenceRegularizer: Regularizer that adds a KL divergence penalty to the model loss.
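Both KL layers penalize the divergence between a learned distribution and a reference one. When both are diagonal Gaussians, the KL divergence has the closed form log(sigma_p / sigma_q) + (sigma_q^2 + (mu_q - mu_p)^2) / (2 sigma_p^2) - 1/2, summed over dimensions. Below is a plain-Python sketch of that formula. Whether the layers use an exact KL or a Monte Carlo estimate depends on their configuration. This sketch only shows the analytic case and the function name is hypothetical.

```python
import math


def kl_diag_normal(mu_q, sigma_q, mu_p, sigma_p):
    """KL(q || p) for diagonal Gaussians, summed over dimensions."""
    return sum(
        math.log(sp / sq) + (sq ** 2 + (mq - mp) ** 2) / (2.0 * sp ** 2) - 0.5
        for mq, sq, mp, sp in zip(mu_q, sigma_q, mu_p, sigma_p)
    )


# A unit shift of the mean against a standard normal prior costs 0.5 nats.
print(kl_diag_normal([1.0], [1.0], [0.0], [1.0]))
```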

class MixtureSameFamily: A mixture (same-family) Keras layer.
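A same-family mixture combines a categorical distribution over k components with k distributions of one family. Its density is log p(x) = logsumexp_k(log w_k + log p_k(x)), where w = softmax(mixture logits). Below is a plain-Python sketch with univariate Normal components. The choice of Normals and the helper names are assumptions for illustration.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v)))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def normal_log_prob(x, loc, scale):
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def mixture_log_prob(x, mix_logits, locs, scales):
    """log p(x) for a mixture of univariate Normals sharing one family."""
    log_weights = [l - logsumexp(mix_logits) for l in mix_logits]
    return logsumexp([
        lw + normal_log_prob(x, m, s)
        for lw, m, s in zip(log_weights, locs, scales)
    ])


print(mixture_log_prob(0.3, [0.0, 1.0], [-1.0, 1.0], [0.5, 0.5]))
```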

class MultivariateNormalTriL: A d-variate MultivariateNormalTriL Keras layer from d + d * (d + 1) // 2 params.
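The count d + d * (d + 1) // 2 breaks down into d location entries plus the d * (d + 1) / 2 entries of a lower-triangular scale matrix L. The implied covariance is L L^T. Below is a plain-Python sketch of the unpacking, under two assumptions. First, it fills the lower triangle row by row, and TFP's own fill order may differ. Second, it leaves the diagonal untransformed, whereas a real layer has to keep it positive. The helper names are hypothetical.

```python
def params_size(d):
    """d location entries plus d * (d + 1) // 2 lower-triangular scale entries."""
    return d + d * (d + 1) // 2


def unpack(params, d):
    """Split a flat vector into loc and a lower-triangular scale matrix L.

    Row-major filling of the lower triangle is an assumption of this sketch.
    """
    if len(params) != params_size(d):
        raise ValueError("wrong number of parameters")
    loc, tril = list(params[:d]), params[d:]
    scale = [[0.0] * d for _ in range(d)]
    k = 0
    for i in range(d):
        for j in range(i + 1):
            scale[i][j] = tril[k]
            k += 1
    return loc, scale


def covariance(scale):
    """Covariance implied by the triangular factor: L @ L^T."""
    d = len(scale)
    return [[sum(scale[i][k] * scale[j][k] for k in range(d)) for j in range(d)]
            for i in range(d)]


loc, L = unpack([0.0, 1.0, 2.0, 1.0, 3.0], d=2)  # 2 + 3 = 5 params for d = 2
print(loc, L, covariance(L))
```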

class OneHotCategorical: A d-variate OneHotCategorical Keras layer from d params.
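With d params, each input is one logit per category. Sampling picks a category with probability softmax(logits) and returns it as a length-d one-hot vector. Below is a plain-Python sketch. The helper names are hypothetical and it does not use TFP.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_one_hot(logits, rng):
    """Draw a category from softmax(logits) and return it as a one-hot list."""
    probs = softmax(logits)
    index = rng.choices(range(len(probs)), weights=probs)[0]
    return [1 if i == index else 0 for i in range(len(probs))]


rng = random.Random(0)
print(sample_one_hot([0.5, 1.0, -0.5], rng))  # a random one-hot vector of length 3
```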


default_loc_scale_fn(...): Makes a closure that creates loc and scale params with tf.get_variable.

default_mean_field_normal_fn(...): Creates a function to build Normal distributions with trainable params.
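A mean-field Normal has independent coordinates, each with its own trainable loc and scale. A common way to keep the scale positive during training is to store an unconstrained variable and map it through softplus. Below is a plain-Python sketch of that idea. The exact transform, any small additive constant and the initializers used by these functions are not shown here, so treat the softplus mapping as an assumption. The helper names are hypothetical.

```python
import math
import random


def softplus(z):
    """Numerically stable log(1 + exp(z)); maps any real number to (0, inf)."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def mean_field_normal_sample(loc, untransformed_scale, rng):
    """Sample independent Normals whose scale is softplus(untransformed_scale).

    Keeping the trainable variable unconstrained and passing it through
    softplus guarantees a positive scale for any value of the variable.
    """
    return [m + softplus(r) * rng.gauss(0.0, 1.0)
            for m, r in zip(loc, untransformed_scale)]


rng = random.Random(0)
print(mean_field_normal_sample([1.0, -2.0], [-3.0, -3.0], rng))
```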

default_multivariate_normal_fn(...): Creates a multivariate standard Normal distribution.
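A multivariate standard Normal has zero mean and identity covariance. Over n flattened entries w, its log-density is therefore -0.5 * sum(w^2) - (n / 2) * log(2 * pi), which is the sum of n univariate standard Normal log-densities. Below is a plain-Python sketch of that formula. The function name is hypothetical and it does not call TFP.

```python
import math


def standard_normal_log_prob(weights):
    """log N(w; 0, I) for a flat list of weights, summed over all entries."""
    n = len(weights)
    return -0.5 * sum(w * w for w in weights) - 0.5 * n * math.log(2.0 * math.pi)


print(standard_normal_log_prob([0.1, -0.2, 0.3]))
```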