Constructs a distribution or bijector instance with trainable parameters.
tfp.experimental.util.make_trainable_stateless(
    cls,
    initial_parameters=None,
    batch_and_event_shape=(),
    parameter_dtype=tf.float32,
    **init_kwargs
)
This is a convenience method that instantiates a class with trainable parameters. Parameters are randomly initialized and transformed to enforce any domain constraints. This method assumes that the class exposes a parameter_properties method annotating its trainable parameters, and that the caller provides any additional (non-trainable) arguments required by the class.
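Concretely, the parameter_properties annotations are what tell make_trainable_stateless which parameters to create and how to constrain them. As a brief illustration, here is what those annotations look like for tfd.Normal, the class used in the example below (printed values are indicative):

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

# Normal annotates loc and scale as trainable parameters; scale carries a
# default constraining bijector (a Softplus-based transform) that keeps it
# positive.
props = tfd.Normal.parameter_properties(dtype=tf.float32)
print(list(props.keys()))                                  # ['loc', 'scale']
print(props['scale'].default_constraining_bijector_fn())   # a Softplus bijector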
Returns | |
---|---|
init_fn | Python callable with signature initial_parameters = init_fn(seed). |
apply_fn | Python callable with signature instance = apply_fn(*parameters). |
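Together these support a purely functional workflow: draw randomly initialized parameters with init_fn, then rebuild the distribution from any parameter values with apply_fn. A minimal sketch under the JAX substrate, following the signatures above (the seed format and the batch_and_event_shape value are illustrative assumptions):

import jax
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions

init_fn, apply_fn = tfp.experimental.util.make_trainable_stateless(
    tfd.Normal, batch_and_event_shape=[3])
params = init_fn(seed=jax.random.PRNGKey(0))  # randomly initialized parameter values
dist = apply_fn(*params)                      # a batch of three Normal distributions
print(dist.batch_shape)                       # [3]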
Example
Suppose we want to fit a normal distribution to observed data. We could of course just examine the empirical mean and standard deviation of the data:
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

samples = [4.57, 6.37, 5.93, 7.98, 2.03, 3.59, 8.55, 3.45, 5.06, 6.44]
model = tfd.Normal(
    loc=tf.reduce_mean(samples),        # ==> 5.40
    scale=tf.math.reduce_std(samples))  # ==> 1.95
and this would be a very sensible approach. But that's boring, so instead, let's do way more work to get the same result. We'll build a trainable normal distribution, and explicitly optimize to find the maximum-likelihood estimate for the parameters given our data:
# Requires the JAX substrate: from tensorflow_probability.substrates import jax as tfp
tfe_util = tfp.experimental.util

init_fn, apply_fn = tfe_util.make_trainable_stateless(tfd.Normal)

import optax  # JAX only.
mle_params, losses = tfp.math.minimize_stateless(
    # Negative log-likelihood of the data, as a function of the parameters.
    lambda *params: -apply_fn(params).log_prob(samples),
    init=init_fn(),
    optimizer=optax.adam(0.1),
    num_steps=200)
model = apply_fn(mle_params)
print('Fit Normal distribution with mean {} and stddev {}'.format(
    model.mean(),
    model.stddev()))
In this trivial case, doing the explicit optimization has few advantages over the first approach in which we simply matched the empirical moments of the data. However, trainable distributions are useful more generally. For example, they can enable maximum-likelihood estimation of distributions when a moment-matching estimator is not available, and they can also serve as surrogate posteriors in variational inference.
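As a sketch of the surrogate-posterior use mentioned above, the same init_fn / apply_fn pair can parameterize a variational approximation whose negative ELBO is minimized with tfp.math.minimize_stateless. This is only an illustrative sketch under the JAX substrate: the target density, sample size, seeds, and learning rate are arbitrary choices, and the call pattern follows the signatures documented in the Returns table above.

import jax
import jax.numpy as jnp
import optax
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions

def target_log_prob(z):
  # Stand-in unnormalized posterior log density; in practice this would come
  # from a model conditioned on observed data.
  return tfd.Normal(loc=3., scale=0.5).log_prob(z)

init_fn, apply_fn = tfp.experimental.util.make_trainable_stateless(tfd.Normal)

def neg_elbo(*params, seed=jax.random.PRNGKey(0)):
  q = apply_fn(*params)          # surrogate posterior q(z)
  z = q.sample(100, seed=seed)   # reparameterized Monte Carlo samples
  return -jnp.mean(target_log_prob(z) - q.log_prob(z))

vi_params, losses = tfp.math.minimize_stateless(
    neg_elbo,
    init=init_fn(seed=jax.random.PRNGKey(1)),
    optimizer=optax.adam(0.1),
    num_steps=200)
surrogate = apply_fn(*vi_params)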