Constructs a distribution or bijector instance with trainable parameters.
tfp.experimental.util.make_trainable(
*args, seed=None, **kwargs
)
This is a convenience method that instantiates a class with trainable
parameters. Parameters are randomly initialized and then transformed to enforce
any domain constraints. This method assumes that the class exposes a
parameter_properties method annotating its trainable parameters, and that
the caller provides any additional (non-trainable) arguments required by the
class.
Args | |
---|---|
cls | Python class that implements cls.parameter_properties(), e.g., a TFP distribution (tfd.Normal) or bijector (tfb.Scale). |
initial_parameters | A dictionary containing initial values for some or all of the parameters to cls, OR a Python callable with signature value = parameter_init_fn(parameter_name, shape, dtype, seed, constraining_bijector). If a dictionary is provided, any parameters not specified are initialized to a random value in their domain. Default value: None (equivalent to {}; all parameters are initialized randomly). |
batch_and_event_shape | Optional int Tensor giving the desired shape of samples (for distributions) or inputs (for bijectors); used to determine the shape of the trainable parameters. Default value: (). |
parameter_dtype | Optional float dtype for trainable variables. |
**init_kwargs | Additional keyword arguments passed to cls.__init__() to specify any non-trainable parameters. If a value is passed for an otherwise-trainable parameter (for example, make_trainable(tfd.Normal, scale=1.)), it is taken as a fixed value and no variable is constructed for that parameter. |
seed | PRNG seed; see tfp.random.sanitize_seed for details. |
Returns | |
---|---|
instance | An instance of cls parameterized by trainable tf.Variables. |
Example
Suppose we want to fit a normal distribution to observed data. We could, of course, just compute the empirical mean and standard deviation of the data:
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

samples = [4.57, 6.37, 5.93, 7.98, 2.03, 3.59, 8.55, 3.45, 5.06, 6.44]
model = tfd.Normal(
    loc=tf.reduce_mean(samples),  # ==> 5.40
    scale=tf.math.reduce_std(samples))  # ==> 1.95
and this would be a very sensible approach. But that's boring, so instead, let's do way more work to get the same result. We'll build a trainable normal distribution, and explicitly optimize to find the maximum-likelihood estimate for the parameters given our data:
model = tfp.experimental.util.make_trainable(tfd.Normal)
losses = tfp.math.minimize(
    # Total negative log-likelihood of the observed samples.
    lambda: -tf.reduce_sum(model.log_prob(samples)),
    optimizer=tf.optimizers.Adam(0.1),
    num_steps=200)
print('Fit Normal distribution with mean {} and stddev {}'.format(
    model.mean(),
    model.stddev()))
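The optimization should land on the same values as moment matching because the Normal maximum-likelihood estimates have a closed form. Writing the minimized loss as the negative log-likelihood of observations x_1, ..., x_N and setting its partial derivatives to zero gives:

```latex
\mathcal{L}(\mu, \sigma) = -\sum_{i=1}^{N} \log \mathcal{N}(x_i \mid \mu, \sigma)
  = \sum_{i=1}^{N} \left[ \frac{(x_i - \mu)^2}{2\sigma^2} + \log \sigma \right] + \frac{N}{2}\log 2\pi

\frac{\partial \mathcal{L}}{\partial \mu} = -\frac{1}{\sigma^2}\sum_{i=1}^{N} (x_i - \mu) = 0
  \;\Rightarrow\; \hat{\mu} = \frac{1}{N}\sum_{i=1}^{N} x_i

\frac{\partial \mathcal{L}}{\partial \sigma} = -\frac{1}{\sigma^3}\sum_{i=1}^{N} (x_i - \mu)^2 + \frac{N}{\sigma} = 0
  \;\Rightarrow\; \hat{\sigma}^2 = \frac{1}{N}\sum_{i=1}^{N} (x_i - \hat{\mu})^2
```

The 1/N divisor (no Bessel correction) matches the one tf.math.reduce_std uses. This is why the fitted mean and stddev should approach the 5.40 and 1.95 computed above, up to the optimizer's tolerance.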
In this trivial case, doing the explicit optimization has few advantages over the first approach in which we simply matched the empirical moments of the data. However, trainable distributions are useful more generally. For example, they can enable maximum-likelihood estimation of distributions when a moment-matching estimator is not available, and they can also serve as surrogate posteriors in variational inference.