Constructs a distribution or bijector instance with trainable parameters.

This is a convenience method that instantiates a class using tf.Variables for its underlying trainable parameters. Parameters are randomly initialized, and transformed to enforce any domain constraints. This method assumes that the class exposes a parameter_properties method annotating its trainable parameters, and that the caller provides any additional (non-trainable) arguments required by the class.

Args:
cls: Python class that implements cls.parameter_properties(), e.g., a TFP distribution (tfd.Normal) or bijector (tfb.Scale).
initial_parameters: Optional dict mapping str parameter names to Tensor values. It specifies initial values for some or all of the trainable parameters. These values are used directly and must lie in the parameter domain; e.g., the initial value for a scale parameter must be positive. If no initial value is provided for a parameter, it is initialized randomly as determined by unconstrained_init_fn. Default value: None.
unconstrained_init_fn: Python callable that takes shape, seed, and dtype arguments and returns a random real-valued Tensor of the specified shape and dtype. Domain constraints are applied by passing the sampled values through the default constraining bijectors specified in cls.parameter_properties(). An example is the requirement that a parameter be positive. Default value: tf.random.stateless_normal.
batch_and_event_shape: Optional int Tensor giving the desired shape of samples (for distributions) or inputs (for bijectors). It is used to determine the shape of the trainable parameters. Default value: ().
parameter_dtype: Optional float dtype for the trainable variables.
seed: Optional random seed used to determine initial values. Default value: None.
**init_kwargs: Additional keyword arguments passed to cls.__init__() to specify any non-trainable parameters. A value passed for an otherwise-trainable parameter, as in tfp.util.make_trainable(tfd.Normal, scale=1.), is treated as fixed, and no variable is constructed for that parameter.

Returns:
trainable_instance: An instance of cls parameterized by trainable variables.

Suppose we want to fit a normal distribution to observed data. We could of course just examine the empirical mean and standard deviation of the data:

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

samples = [4.57, 6.37, 5.93, 7.98, 2.03, 3.59, 8.55, 3.45, 5.06, 6.44]
model = tfd.Normal(
  loc=tf.reduce_mean(samples),  # ==> 5.40
  scale=tf.math.reduce_std(samples))  # ==> 1.95

and this would be a very sensible approach. But that's boring, so instead, let's do way more work to get the same result. We'll build a trainable normal distribution, and explicitly optimize to find the maximum-likelihood estimate for the parameters given our data:

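model = tfp.util.make_trainable(tfd.Normal)
losses = tfp.math.minimize(
  lambda: -tf.reduce_sum(model.log_prob(samples)),  # negative log-likelihood
  optimizer=tf.optimizers.Adam(0.1),
  num_steps=200)
print('Fit Normal distribution with mean {} and stddev {}'.format(
  model.mean(),
  model.stddev()))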

In this trivial case, doing the explicit optimization has few advantages over the first approach in which we simply matched the empirical moments of the data. However, trainable distributions are useful more generally. For example, they can enable maximum-likelihood estimation of distributions when a moment-matching estimator is not available, and they can also serve as surrogate posteriors in variational inference.