Returns stateless functions for building a variational posterior.
```python
tfp.sts.build_factored_surrogate_posterior_stateless(
    model, batch_shape=(), name=None
)
```
The surrogate posterior consists of independent Normal distributions for each parameter, with trainable `loc` and `scale`, transformed using the parameter's `bijector` to the appropriate support space for that parameter.
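The factoring described above can be illustrated without TFP. The following is a minimal pure-Python sketch, not TFP's implementation: each parameter gets its own Normal(`loc`, `scale`) in unconstrained space, a draw is pushed through that parameter's bijector, and the density of the transformed value follows from the change-of-variables formula. `FactoredSurrogate`, `IDENTITY`, `SOFTPLUS`, and the parameter names are hypothetical, and representing a bijector as a (forward, inverse, inverse log-det-Jacobian) tuple is an assumption made only for illustration.

```python
import math
import random


def normal_log_prob(x, loc, scale):
    # Log density of Normal(loc, scale) at x.
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


# Hypothetical bijectors as (forward, inverse, inverse_log_det_jacobian).
IDENTITY = (lambda x: x, lambda y: y, lambda y: 0.0)
SOFTPLUS = (
    lambda x: max(x, 0.0) + math.log1p(math.exp(-abs(x))),  # log(1 + e^x)
    lambda y: y + math.log(-math.expm1(-y)),                # log(e^y - 1)
    lambda y: -math.log(-math.expm1(-y)),                   # log|d inverse/dy|
)


class FactoredSurrogate:
    """Independent Normal(loc, scale) per parameter, each pushed through
    that parameter's bijector to its support (hypothetical helper)."""

    def __init__(self, params):
        # params: {name: (loc, scale, bijector)}; in a real fit, loc and
        # scale are the trainable quantities.
        self.params = params

    def sample(self, rng):
        # Draw in unconstrained space, then map to the parameter's support.
        return {name: bijector[0](rng.gauss(loc, scale))
                for name, (loc, scale, bijector) in self.params.items()}

    def log_prob(self, values):
        # Independence: the joint log density is a sum of per-parameter terms.
        total = 0.0
        for name, (loc, scale, (_, inverse, inverse_ldj)) in self.params.items():
            y = values[name]
            # Change of variables: log q(y) = log N(inverse(y)) + log|d inverse/dy|.
            total += normal_log_prob(inverse(y), loc, scale) + inverse_ldj(y)
        return total


surrogate = FactoredSurrogate({
    "observation_noise_scale": (0.0, 0.5, SOFTPLUS),  # must be positive
    "weight": (0.0, 1.0, IDENTITY),                   # unconstrained
})
draw = surrogate.sample(random.Random(0))
print(draw, surrogate.log_prob(draw))
```

Because the factors are independent, sampling and density evaluation reduce to per-parameter work, which is what keeps this mean-field family cheap; the price is that it cannot represent correlations between parameters.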
Examples
Assume we've built a structural time-series model:
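```python
from tensorflow_probability.substrates import jax as tfp

# `observed_time_series` is assumed to be an array of observations
# with shape [num_timesteps], defined elsewhere.
day_of_week = tfp.sts.Seasonal(
    num_seasons=7,
    observed_time_series=observed_time_series,
    name='day_of_week')
local_linear_trend = tfp.sts.LocalLinearTrend(
    observed_time_series=observed_time_series,
    name='local_linear_trend')
model = tfp.sts.Sum(components=[day_of_week, local_linear_trend],
                    observed_time_series=observed_time_series)
```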
To fit the model to data statelessly, we construct `init_fn` and `build_surrogate_fn`. `init_fn` constructs an initial set of parameters from a seed, and `build_surrogate_fn`, which maps a set of parameters to a surrogate posterior distribution, is passed to `tfp.vi.fit_surrogate_posterior_stateless` to optimize a variational bound.
```python
# This example only works in the JAX backend because it uses
# `optax` for stateless optimizers.
import jax
import optax
from tensorflow_probability.substrates import jax as tfp

# Derive independent seeds for initialization, fitting, and sampling.
seed = tfp.random.sanitize_seed(jax.random.PRNGKey(0), salt='fit_stateless')
init_seed, fit_seed, sample_seed = tfp.random.split_seed(seed, n=3)

init_fn, build_surrogate_fn = (
    tfp.sts.build_factored_surrogate_posterior_stateless(model=model))
initial_parameters = init_fn(init_seed)

# Unnormalized posterior over the model parameters, conditioned on the data.
jd = model.joint_distribution(observed_time_series)

final_parameters, loss_curve = tfp.vi.fit_surrogate_posterior_stateless(
    target_log_prob_fn=jd.log_prob,
    initial_parameters=initial_parameters,
    build_surrogate_posterior_fn=build_surrogate_fn,
    optimizer=optax.adam(1e-4),
    num_steps=200,
    seed=fit_seed)

# Rebuild the surrogate from the optimized parameters and draw samples.
surrogate_posterior = build_surrogate_fn(final_parameters)
posterior_samples = surrogate_posterior.sample(50, seed=sample_seed)
```
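The stateless pattern used in the example above can be sketched in pure Python on a toy problem. Everything here is hypothetical and is not how TFP implements fitting: a one-dimensional Normal(3, 1) target, a Normal surrogate whose scale is the softplus of an unconstrained parameter, hand-derived reparameterization gradients of the evidence lower bound (ELBO), and plain gradient ascent standing in for `optax.adam`. It mirrors only the shape of the API: `init_fn` maps a seed to parameters, `build_surrogate_fn` is a pure function of those parameters, and every random draw is driven by an explicit seed rather than hidden global state.

```python
import math
import random


def softplus(x):
    # Numerically stable log(1 + exp(x)); maps the real line to (0, inf).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def sigmoid(x):
    # Derivative of softplus, written via tanh to avoid overflow.
    return 0.5 * (1.0 + math.tanh(0.5 * x))


def target_grad_log_prob(x):
    # Gradient of a toy unnormalized target, log p(x) = -(x - 3)^2 / 2.
    return -(x - 3.0)


def init_fn(seed):
    # Hypothetical: initial parameters depend only on the explicit seed.
    rng = random.Random(seed)
    return {"loc": rng.uniform(-2.0, 2.0), "raw_scale": rng.uniform(-2.0, 2.0)}


def build_surrogate_fn(params):
    # Pure function: parameters in, (loc, scale) of a Normal surrogate out.
    return params["loc"], softplus(params["raw_scale"])


def elbo_grad(params, seed, num_samples=16):
    # Monte Carlo ELBO gradient using reparameterized samples x = loc + scale * eps.
    loc, scale = build_surrogate_fn(params)
    rng = random.Random(seed)
    g_loc = g_scale = 0.0
    for _ in range(num_samples):
        eps = rng.gauss(0.0, 1.0)
        dlogp = target_grad_log_prob(loc + scale * eps)
        g_loc += dlogp / num_samples
        g_scale += dlogp * eps / num_samples
    g_scale += 1.0 / scale  # d/dscale of the Normal entropy, log(scale) + const
    # Chain rule through scale = softplus(raw_scale).
    return {"loc": g_loc, "raw_scale": g_scale * sigmoid(params["raw_scale"])}


def fit_stateless(init_seed, fit_seed, num_steps=2000, learning_rate=0.05):
    params = init_fn(init_seed)
    seed_stream = random.Random(fit_seed)
    for _ in range(num_steps):
        grads = elbo_grad(params, seed_stream.getrandbits(32))
        # Gradient ascent on the ELBO; builds a new dict instead of mutating.
        params = {k: params[k] + learning_rate * grads[k] for k in params}
    return params


final_parameters = fit_stateless(init_seed=0, fit_seed=1)
loc, scale = build_surrogate_fn(final_parameters)
# The exact optimum for this toy target is loc=3, scale=1.
print(f"loc={loc:.2f}, scale={scale:.2f}")
```

Calling `fit_stateless` twice with the same seeds returns identical parameters, which is the practical payoff of threading seeds and parameters explicitly instead of relying on mutable optimizer and RNG state.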