tfp.sts.build_factored_surrogate_posterior

Build a variational posterior that factors over model parameters.

tfp.sts.build_factored_surrogate_posterior(
    model,
    batch_shape=(),
    seed=None,
    name=None
)

The surrogate posterior consists of an independent Normal distribution for each model parameter, with trainable loc and scale, transformed by the parameter's bijector into that parameter's support space.
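
As a loose sketch of that construction for a single, hypothetical positive-valued parameter (the actual implementation handles all of the model's parameters jointly, with batching), one could write something like the following, using tfp.util.TransformedVariable from recent TFP releases:

  import tensorflow as tf
  import tensorflow_probability as tfp
  tfd = tfp.distributions
  tfb = tfp.bijectors

  # Trainable Normal in unconstrained space. The scale variable is kept
  # positive by parameterizing it through a Softplus bijector.
  unconstrained_dist = tfd.Normal(
      loc=tf.Variable(0., name='loc'),
      scale=tfp.util.TransformedVariable(
          1., bijector=tfb.Softplus(), name='scale'))
  # Transform samples into the parameter's support (here the positive
  # reals, as for a scale parameter).
  surrogate_dist = tfd.TransformedDistribution(
      distribution=unconstrained_dist, bijector=tfb.Softplus())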

Args:

  • model: An instance of StructuralTimeSeries representing a time-series model. This represents a joint distribution over time-series and their parameters with batch shape [b1, ..., bN].
  • batch_shape: Batch shape (Python tuple, list, or int) of initial states to optimize in parallel. Default value: () (i.e., run a single optimization); see the sketch following this list.
  • seed: Python integer to seed the random number generator.
  • name: Python str name prefixed to ops created by this function. Default value: None (i.e., 'build_factored_surrogate_posterior').
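
For example (an illustrative snippet, with model built as in the Examples below), passing a nontrivial batch shape optimizes several independently initialized surrogates in parallel:

  # Illustrative: 5 independently initialized copies of each variational
  # parameter, optimized side by side.
  surrogate_posterior = tfp.sts.build_factored_surrogate_posterior(
      model=model, batch_shape=[5])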

Returns:

  • variational_posterior: tfd.JointDistributionNamed defining a trainable surrogate posterior over model parameters. Samples from this distribution are Python dicts with Python str parameter names as keys.
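
For instance (a hypothetical snippet; the exact keys depend on the model's parameters), a single sample is a dict of parameter values:

  # Illustrative only; key names are determined by the model.
  parameter_sample = surrogate_posterior.sample()
  print(list(parameter_sample.keys()))
  # ==> e.g. ['observation_noise_scale', 'day_of_week/drift_scale', ...]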

Examples

Assume we've built a structural time-series model:

  import tensorflow as tf
  import tensorflow_probability as tfp

  # `observed_time_series` is assumed to be a float Tensor (or array) of
  # observed values.
  day_of_week = tfp.sts.Seasonal(
      num_seasons=7,
      observed_time_series=observed_time_series,
      name='day_of_week')
  local_linear_trend = tfp.sts.LocalLinearTrend(
      observed_time_series=observed_time_series,
      name='local_linear_trend')
  model = tfp.sts.Sum(components=[day_of_week, local_linear_trend],
                      observed_time_series=observed_time_series)

To fit the model to data, we build a surrogate posterior and train it by optimizing a variational bound:

  surrogate_posterior = tfp.sts.build_factored_surrogate_posterior(
    model=model)
  loss_curve = tfp.vi.fit_surrogate_posterior(
    target_log_prob_fn=model.joint_log_prob(observed_time_series),
    surrogate_posterior=surrogate_posterior,
    optimizer=tf.optimizers.Adam(learning_rate=0.1),
    num_steps=200)
  posterior_samples = surrogate_posterior.sample(50)

  # In graph mode, we would need to write:
  # with tf.control_dependencies([loss_curve]):
  #   posterior_samples = surrogate_posterior.sample(50)
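
With posterior samples in hand, one possible next step (illustrative, and not part of this function) is to forecast future observations with tfp.sts.forecast:

  # Illustrative follow-up: forecast 30 steps ahead from the fitted
  # parameter samples.
  forecast_dist = tfp.sts.forecast(
      model,
      observed_time_series=observed_time_series,
      parameter_samples=posterior_samples,
      num_steps_forecast=30)
  forecast_mean = forecast_dist.mean()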

For more control, we can also build and optimize a variational loss manually:

  @tf.function(autograph=False)  # Ensure the loss is computed efficiently
  def loss_fn():
    return tfp.vi.monte_carlo_variational_loss(
      model.joint_log_prob(observed_time_series),
      surrogate_posterior,
      sample_size=10)

  optimizer = tf.optimizers.Adam(learning_rate=0.1)
  for step in range(200):
    with tf.GradientTape() as tape:
      loss = loss_fn()
    grads = tape.gradient(loss, surrogate_posterior.trainable_variables)
    optimizer.apply_gradients(
      zip(grads, surrogate_posterior.trainable_variables))
    if step % 20 == 0:
      print('step {} loss {}'.format(step, loss))

  posterior_samples = surrogate_posterior.sample(50)
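
Since each sample is a dict of parameter draws, a quick illustrative summary of the fitted posterior is:

  # Illustrative: report the posterior mean of each parameter.
  for name, draws in posterior_samples.items():
    print('{}: posterior mean {}'.format(
        name, tf.reduce_mean(draws, axis=0)))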