Inferência aproximada para modelos STS com observações não gaussianas

Ver no TensorFlow.org Executar no Google Colab Ver fonte no GitHub Baixar caderno

Este caderno demonstra o uso de ferramentas de inferência aproximada TFP para incorporar um modelo de observação (não gaussiano) ao ajustar e prever modelos de séries temporais estruturais (STS). Neste exemplo, usaremos um modelo de observação de Poisson para trabalhar com dados de contagem discretos.

import time
import matplotlib.pyplot as plt
import numpy as np

import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp

from tensorflow_probability import bijectors as tfb
from tensorflow_probability import distributions as tfd

tf.enable_v2_behavior()

Dados Sintéticos

Primeiro, geraremos alguns dados de contagem sintéticos:

num_timesteps = 30
observed_counts = np.round(3 + np.random.lognormal(np.log(np.linspace(
    num_timesteps, 5, num=num_timesteps)), 0.20, size=num_timesteps)) 
observed_counts = observed_counts.astype(np.float32)
plt.plot(observed_counts)
[<matplotlib.lines.Line2D at 0x7f940ae958d0>]

png

Modelo

Especificaremos um modelo simples com uma tendência linear aleatória:

def build_model(approximate_unconstrained_rates):
  trend = tfp.sts.LocalLinearTrend(
      observed_time_series=approximate_unconstrained_rates)
  return tfp.sts.Sum([trend],
                     observed_time_series=approximate_unconstrained_rates)

Em vez de operar na série de tempo observada, este modelo irá operar na série de parâmetros da taxa de Poisson que governam as observações.

Como as taxas de Poisson devem ser positivas, usaremos um bijetor para transformar o modelo STS de valor real em uma distribuição sobre valores positivos. O Softplus transformação \(y = \log(1 + \exp(x))\) é uma escolha natural, uma vez que é quase linear para valores positivos, mas outras opções, tais como Exp (que transforma o passeio aleatório normal em um passeio aleatório lognormal) também são possíveis.

positive_bijector = tfb.Softplus()  # Or tfb.Exp()

# Approximate the unconstrained Poisson rate just to set heuristic priors.
# We could avoid this by passing explicit priors on all model params.
approximate_unconstrained_rates = positive_bijector.inverse(
    tf.convert_to_tensor(observed_counts) + 0.01)
sts_model = build_model(approximate_unconstrained_rates)

Para usar a inferência aproximada para um modelo de observação não gaussiano, codificaremos o modelo STS como uma TFP JointDistribution. As variáveis ​​aleatórias nesta distribuição conjunta são os parâmetros do modelo STS, a série temporal das taxas de Poisson latentes e as contagens observadas.

def sts_with_poisson_likelihood_model():
  # Encode the parameters of the STS model as random variables.
  param_vals = []
  for param in sts_model.parameters:
    param_val = yield param.prior
    param_vals.append(param_val)

  # Use the STS model to encode the log- (or inverse-softplus)
  # rate of a Poisson.
  unconstrained_rate = yield sts_model.make_state_space_model(
      num_timesteps, param_vals)
  rate = positive_bijector.forward(unconstrained_rate[..., 0])
  observed_counts = yield tfd.Poisson(rate, name='observed_counts')

model = tfd.JointDistributionCoroutineAutoBatched(sts_with_poisson_likelihood_model)

Preparação para inferência

Queremos inferir as quantidades não observadas no modelo, dadas as contagens observadas. Primeiro, condicionamos a densidade do log da junta nas contagens observadas.

pinned_model = model.experimental_pin(observed_counts=observed_counts)

Também precisaremos de um bijetor de restrição para garantir que a inferência respeite as restrições nos parâmetros do modelo STS (por exemplo, as escalas devem ser positivas).

constraining_bijector = pinned_model.experimental_default_event_space_bijector()

Inferência com HMC

Usaremos HMC (especificamente, NUTS) para amostrar a partir da articulação posterior sobre os parâmetros do modelo e taxas latentes.

Isso será significativamente mais lento do que ajustar um modelo STS padrão com HMC, pois além dos parâmetros do modelo (número relativamente pequeno de), também temos que inferir toda a série de taxas de Poisson. Portanto, executaremos um número relativamente pequeno de etapas; para aplicativos em que a qualidade da inferência é crítica, pode fazer sentido aumentar esses valores ou executar várias cadeias.

Configuração de amostrador

Primeiro vamos especificar um sampler, e depois usar sample_chain para executar esse núcleo de amostragem para amostras produzir.

sampler = tfp.mcmc.TransformedTransitionKernel(
    tfp.mcmc.NoUTurnSampler(
        target_log_prob_fn=pinned_model.unnormalized_log_prob,
        step_size=0.1),
    bijector=constraining_bijector)

adaptive_sampler = tfp.mcmc.DualAveragingStepSizeAdaptation(
    inner_kernel=sampler,
    num_adaptation_steps=int(0.8 * num_burnin_steps),
    target_accept_prob=0.75)

initial_state = constraining_bijector.forward(
    type(pinned_model.event_shape)(
        *(tf.random.normal(part_shape)
          for part_shape in constraining_bijector.inverse_event_shape(
              pinned_model.event_shape))))
# Speed up sampling by tracing with `tf.function`.
@tf.function(autograph=False, jit_compile=True)
def do_sampling():
  return tfp.mcmc.sample_chain(
      kernel=adaptive_sampler,
      current_state=initial_state,
      num_results=num_results,
      num_burnin_steps=num_burnin_steps,
      trace_fn=None)

t0 = time.time()
samples = do_sampling()
t1 = time.time()
print("Inference ran in {:.2f}s.".format(t1-t0))
Inference ran in 24.83s.

Podemos verificar a integridade da inferência examinando os traços dos parâmetros. Nesse caso, eles parecem ter explorado várias explicações para os dados, o que é bom, embora mais amostras sejam úteis para avaliar o quão bem a cadeia está se misturando.

f = plt.figure(figsize=(12, 4))
for i, param in enumerate(sts_model.parameters):
  ax = f.add_subplot(1, len(sts_model.parameters), i + 1)
  ax.plot(samples[i])
  ax.set_title("{} samples".format(param.name))

png

Agora, para a recompensa: vamos ver a posterior sobre as taxas de Poisson! Também traçaremos o intervalo preditivo de 80% sobre as contagens observadas e podemos verificar se esse intervalo parece conter cerca de 80% das contagens que realmente observamos.

param_samples = samples[:-1]
unconstrained_rate_samples = samples[-1][..., 0]
rate_samples = positive_bijector.forward(unconstrained_rate_samples)

plt.figure(figsize=(10, 4))
mean_lower, mean_upper = np.percentile(rate_samples, [10, 90], axis=0)
pred_lower, pred_upper = np.percentile(np.random.poisson(rate_samples), 
                                       [10, 90], axis=0)

_ = plt.plot(observed_counts, color="blue", ls='--', marker='o', label='observed', alpha=0.7)
_ = plt.plot(np.mean(rate_samples, axis=0), label='rate', color="green", ls='dashed', lw=2, alpha=0.7)
_ = plt.fill_between(np.arange(0, 30), mean_lower, mean_upper, color='green', alpha=0.2)
_ = plt.fill_between(np.arange(0, 30), pred_lower, pred_upper, color='grey', label='counts', alpha=0.2)
plt.xlabel("Day")
plt.ylabel("Daily Sample Size")
plt.title("Posterior Mean")
plt.legend()
<matplotlib.legend.Legend at 0x7f93ffd35550>

png

Previsão

Para prever as contagens observadas, usaremos as ferramentas STS padrão para construir uma distribuição de previsão sobre as taxas latentes (no espaço irrestrito, mais uma vez, uma vez que o STS é projetado para modelar dados de valor real) e, em seguida, passaremos as previsões amostradas por meio de uma observação de Poisson modelo:

def sample_forecasted_counts(sts_model, posterior_latent_rates,
                             posterior_params, num_steps_forecast,
                             num_sampled_forecasts):

  # Forecast the future latent unconstrained rates, given the inferred latent
  # unconstrained rates and parameters.
  unconstrained_rates_forecast_dist = tfp.sts.forecast(sts_model,
    observed_time_series=unconstrained_rate_samples,
    parameter_samples=posterior_params,
    num_steps_forecast=num_steps_forecast)

  # Transform the forecast to positive-valued Poisson rates.
  rates_forecast_dist = tfd.TransformedDistribution(
      unconstrained_rates_forecast_dist,
      positive_bijector)

  # Sample from the forecast model following the chain rule:
  # P(counts) = P(counts | latent_rates)P(latent_rates)
  sampled_latent_rates = rates_forecast_dist.sample(num_sampled_forecasts)
  sampled_forecast_counts = tfd.Poisson(rate=sampled_latent_rates).sample()

  return sampled_forecast_counts, sampled_latent_rates

forecast_samples, rate_samples = sample_forecasted_counts(
   sts_model,
   posterior_latent_rates=unconstrained_rate_samples,
   posterior_params=param_samples,
   # Days to forecast:
   num_steps_forecast=30,
   num_sampled_forecasts=100)
forecast_samples = np.squeeze(forecast_samples)
def plot_forecast_helper(data, forecast_samples, CI=90):
  """Plot the observed time series alongside the forecast."""
  plt.figure(figsize=(10, 4))
  forecast_median = np.median(forecast_samples, axis=0)

  num_steps = len(data)
  num_steps_forecast = forecast_median.shape[-1]

  plt.plot(np.arange(num_steps), data, lw=2, color='blue', linestyle='--', marker='o',
           label='Observed Data', alpha=0.7)

  forecast_steps = np.arange(num_steps, num_steps+num_steps_forecast)

  CI_interval = [(100 - CI)/2, 100 - (100 - CI)/2]
  lower, upper = np.percentile(forecast_samples, CI_interval, axis=0)

  plt.plot(forecast_steps, forecast_median, lw=2, ls='--', marker='o', color='orange',
           label=str(CI) + '% Forecast Interval', alpha=0.7)
  plt.fill_between(forecast_steps,
                   lower,
                   upper, color='orange', alpha=0.2)

  plt.xlim([0, num_steps+num_steps_forecast])
  ymin, ymax = min(np.min(forecast_samples), np.min(data)), max(np.max(forecast_samples), np.max(data))
  yrange = ymax-ymin
  plt.title("{}".format('Observed time series with ' + str(num_steps_forecast) + ' Day Forecast'))
  plt.xlabel('Day')
  plt.ylabel('Daily Sample Size')
  plt.legend()
plot_forecast_helper(observed_counts, forecast_samples, CI=80)

png

Inferência VI

Variational inferência pode ser problemático quando inferir uma série de tempo integral, como nossas contagens aproximadas (em oposição a apenas os parâmetros de uma série de tempo, como nos modelos STS-padrão). A suposição padrão de que as variáveis ​​têm posteriores independentes é bastante errada, uma vez que cada passo de tempo é correlacionado com seus vizinhos, o que pode levar a subestimar a incerteza. Por esse motivo, o HMC pode ser uma escolha melhor para inferência aproximada em séries de tempo completo. No entanto, VI pode ser um pouco mais rápido e pode ser útil para prototipagem de modelo ou nos casos em que seu desempenho pode ser empiricamente mostrado como "bom o suficiente".

Para ajustar nosso modelo com VI, simplesmente construímos e otimizamos um posterior substituto:

surrogate_posterior = tfp.experimental.vi.build_factored_surrogate_posterior(
    event_shape=pinned_model.event_shape,
    bijector=constraining_bijector)
# Allow external control of optimization to reduce test runtimes.
num_variational_steps = 1000 # @param { isTemplate: true}
num_variational_steps = int(num_variational_steps)

t0 = time.time()
losses = tfp.vi.fit_surrogate_posterior(pinned_model.unnormalized_log_prob,
                                        surrogate_posterior,
                                        optimizer=tf.optimizers.Adam(0.1),
                                        num_steps=num_variational_steps)
t1 = time.time()
print("Inference ran in {:.2f}s.".format(t1-t0))
Inference ran in 11.37s.
plt.plot(losses)
plt.title("Variational loss")
_ = plt.xlabel("Steps")

png

posterior_samples = surrogate_posterior.sample(50)
param_samples = posterior_samples[:-1]
unconstrained_rate_samples = posterior_samples[-1][..., 0]
rate_samples = positive_bijector.forward(unconstrained_rate_samples)

plt.figure(figsize=(10, 4))
mean_lower, mean_upper = np.percentile(rate_samples, [10, 90], axis=0)
pred_lower, pred_upper = np.percentile(
    np.random.poisson(rate_samples), [10, 90], axis=0)

_ = plt.plot(observed_counts, color='blue', ls='--', marker='o',
             label='observed', alpha=0.7)
_ = plt.plot(np.mean(rate_samples, axis=0), label='rate', color='green',
             ls='dashed', lw=2, alpha=0.7)
_ = plt.fill_between(
    np.arange(0, 30), mean_lower, mean_upper, color='green', alpha=0.2)
_ = plt.fill_between(np.arange(0, 30), pred_lower, pred_upper, color='grey',
    label='counts', alpha=0.2)
plt.xlabel('Day')
plt.ylabel('Daily Sample Size')
plt.title('Posterior Mean')
plt.legend()
<matplotlib.legend.Legend at 0x7f93ff4735c0>

png

forecast_samples, rate_samples = sample_forecasted_counts(
   sts_model,
   posterior_latent_rates=unconstrained_rate_samples,
   posterior_params=param_samples,
   # Days to forecast:
   num_steps_forecast=30,
   num_sampled_forecasts=100)
forecast_samples = np.squeeze(forecast_samples)
plot_forecast_helper(observed_counts, forecast_samples, CI=80)

png