非ガウス観測によるSTSモデルの近似推論

TensorFlow.orgで表示 GoogleColabで実行 GitHubでソースを表示 ノートブックをダウンロードする

このノートブックは、TFP近似推論ツールを使用して、構造時系列(STS)モデルをフィッティングおよび予測するときに(非ガウス)観測モデルを組み込む方法を示しています。この例では、ポアソン観測モデルを使用して離散カウントデータを処理します。

import time
import matplotlib.pyplot as plt
import numpy as np

import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp

from tensorflow_probability import bijectors as tfb
from tensorflow_probability import distributions as tfd

tf.enable_v2_behavior()

合成データ

まず、いくつかの合成カウントデータを生成します。

num_timesteps = 30
observed_counts = np.round(3 + np.random.lognormal(np.log(np.linspace(
    num_timesteps, 5, num=num_timesteps)), 0.20, size=num_timesteps)) 
observed_counts = observed_counts.astype(np.float32)
plt.plot(observed_counts)
[<matplotlib.lines.Line2D at 0x7f940ae958d0>]

png

モデル

ランダムウォーク線形トレンドを持つ単純なモデルを指定します。

def build_model(approximate_unconstrained_rates):
  trend = tfp.sts.LocalLinearTrend(
      observed_time_series=approximate_unconstrained_rates)
  return tfp.sts.Sum([trend],
                     observed_time_series=approximate_unconstrained_rates)

このモデルは、観測された時系列を操作する代わりに、観測を管理する一連のポアソン率パラメーターを操作します。

ポアソン率は正でなければならないため、バイジェクターを使用して、実数値のSTSモデルを正の値の分布に変換します。 Softplus変換する$ Y = \ログ(1 + \ EXP(X))$それはほぼ正の値に対して線形であるので、自然な選択であるが、のような他の選択肢Exp対数正規ランダムウォークに通常のランダムウォークを変換( )も可能です。

positive_bijector = tfb.Softplus()  # Or tfb.Exp()

# Approximate the unconstrained Poisson rate just to set heuristic priors.
# We could avoid this by passing explicit priors on all model params.
approximate_unconstrained_rates = positive_bijector.inverse(
    tf.convert_to_tensor(observed_counts) + 0.01)
sts_model = build_model(approximate_unconstrained_rates)

非ガウス観測モデルに近似推論を使用するために、STSモデルをTFPJointDistributionとしてエンコードします。この同時分布の確率変数は、STSモデルのパラメーター、潜在ポアソン率の時系列、および観測されたカウントです。

def sts_with_poisson_likelihood_model():
  # Encode the parameters of the STS model as random variables.
  param_vals = []
  for param in sts_model.parameters:
    param_val = yield param.prior
    param_vals.append(param_val)

  # Use the STS model to encode the log- (or inverse-softplus)
  # rate of a Poisson.
  unconstrained_rate = yield sts_model.make_state_space_model(
      num_timesteps, param_vals)
  rate = positive_bijector.forward(unconstrained_rate[..., 0])
  observed_counts = yield tfd.Poisson(rate, name='observed_counts')

model = tfd.JointDistributionCoroutineAutoBatched(sts_with_poisson_likelihood_model)

推論の準備

観測されたカウントを前提として、モデル内の観測されていない量を推測します。まず、観測されたカウントに基づいて結合ログ密度を調整します。

pinned_model = model.experimental_pin(observed_counts=observed_counts)

また、推論がSTSモデルのパラメーターの制約を尊重するようにするための制約バイジェクターも必要です(たとえば、スケールは正である必要があります)。

constraining_bijector = pinned_model.experimental_default_event_space_bijector()

HMCによる推論

HMC(具体的にはNUTS)を使用して、モデルパラメータと潜在率を介して関節後方からサンプリングします。

これは、標準のSTSモデルをHMCに適合させるよりも大幅に遅くなります。これは、モデルの(比較的少数の)パラメーターに加えて、一連のポアソン率全体を推測する必要があるためです。したがって、比較的少数のステップで実行します。推論の品質が重要なアプリケーションでは、これらの値を増やすか、複数のチェーンを実行することが理にかなっている場合があります。

サンプラー構成

まず、サンプラーを指定し、[使用sample_chain農産物サンプルにカーネルをサンプリングすることを実行します。

sampler = tfp.mcmc.TransformedTransitionKernel(
    tfp.mcmc.NoUTurnSampler(
        target_log_prob_fn=pinned_model.unnormalized_log_prob,
        step_size=0.1),
    bijector=constraining_bijector)

adaptive_sampler = tfp.mcmc.DualAveragingStepSizeAdaptation(
    inner_kernel=sampler,
    num_adaptation_steps=int(0.8 * num_burnin_steps),
    target_accept_prob=0.75)

initial_state = constraining_bijector.forward(
    type(pinned_model.event_shape)(
        *(tf.random.normal(part_shape)
          for part_shape in constraining_bijector.inverse_event_shape(
              pinned_model.event_shape))))
# Speed up sampling by tracing with `tf.function`.
@tf.function(autograph=False, jit_compile=True)
def do_sampling():
  return tfp.mcmc.sample_chain(
      kernel=adaptive_sampler,
      current_state=initial_state,
      num_results=num_results,
      num_burnin_steps=num_burnin_steps,
      trace_fn=None)

t0 = time.time()
samples = do_sampling()
t1 = time.time()
print("Inference ran in {:.2f}s.".format(t1-t0))
Inference ran in 24.83s.

パラメータトレースを調べることで、推論を健全性チェックできます。この場合、彼らはデータの複数の説明を調査したように見えます。これは良いことですが、チェーンがどの程度うまく混合されているかを判断するには、より多くのサンプルが役立ちます。

f = plt.figure(figsize=(12, 4))
for i, param in enumerate(sts_model.parameters):
  ax = f.add_subplot(1, len(sts_model.parameters), i + 1)
  ax.plot(samples[i])
  ax.set_title("{} samples".format(param.name))

png

さて、見返りとして、ポアソン率の事後分布を見てみましょう。また、観測されたカウントに対して80%の予測区間をプロットし、この区間に実際に観測されたカウントの約80%が含まれているように見えることを確認できます。

param_samples = samples[:-1]
unconstrained_rate_samples = samples[-1][..., 0]
rate_samples = positive_bijector.forward(unconstrained_rate_samples)

plt.figure(figsize=(10, 4))
mean_lower, mean_upper = np.percentile(rate_samples, [10, 90], axis=0)
pred_lower, pred_upper = np.percentile(np.random.poisson(rate_samples), 
                                       [10, 90], axis=0)

_ = plt.plot(observed_counts, color="blue", ls='--', marker='o', label='observed', alpha=0.7)
_ = plt.plot(np.mean(rate_samples, axis=0), label='rate', color="green", ls='dashed', lw=2, alpha=0.7)
_ = plt.fill_between(np.arange(0, 30), mean_lower, mean_upper, color='green', alpha=0.2)
_ = plt.fill_between(np.arange(0, 30), pred_lower, pred_upper, color='grey', label='counts', alpha=0.2)
plt.xlabel("Day")
plt.ylabel("Daily Sample Size")
plt.title("Posterior Mean")
plt.legend()
<matplotlib.legend.Legend at 0x7f93ffd35550>

png

予測

観測されたカウントを予測するには、標準のSTSツールを使用して潜在率の予測分布を作成し(STSは実数値データをモデル化するように設計されているため、制約のない空間で)、サンプリングされた予測をポアソン観測に渡します。モデル:

def sample_forecasted_counts(sts_model, posterior_latent_rates,
                             posterior_params, num_steps_forecast,
                             num_sampled_forecasts):

  # Forecast the future latent unconstrained rates, given the inferred latent
  # unconstrained rates and parameters.
  unconstrained_rates_forecast_dist = tfp.sts.forecast(sts_model,
    observed_time_series=unconstrained_rate_samples,
    parameter_samples=posterior_params,
    num_steps_forecast=num_steps_forecast)

  # Transform the forecast to positive-valued Poisson rates.
  rates_forecast_dist = tfd.TransformedDistribution(
      unconstrained_rates_forecast_dist,
      positive_bijector)

  # Sample from the forecast model following the chain rule:
  # P(counts) = P(counts | latent_rates)P(latent_rates)
  sampled_latent_rates = rates_forecast_dist.sample(num_sampled_forecasts)
  sampled_forecast_counts = tfd.Poisson(rate=sampled_latent_rates).sample()

  return sampled_forecast_counts, sampled_latent_rates

forecast_samples, rate_samples = sample_forecasted_counts(
   sts_model,
   posterior_latent_rates=unconstrained_rate_samples,
   posterior_params=param_samples,
   # Days to forecast:
   num_steps_forecast=30,
   num_sampled_forecasts=100)
forecast_samples = np.squeeze(forecast_samples)
def plot_forecast_helper(data, forecast_samples, CI=90):
  """Plot the observed time series alongside the forecast."""
  plt.figure(figsize=(10, 4))
  forecast_median = np.median(forecast_samples, axis=0)

  num_steps = len(data)
  num_steps_forecast = forecast_median.shape[-1]

  plt.plot(np.arange(num_steps), data, lw=2, color='blue', linestyle='--', marker='o',
           label='Observed Data', alpha=0.7)

  forecast_steps = np.arange(num_steps, num_steps+num_steps_forecast)

  CI_interval = [(100 - CI)/2, 100 - (100 - CI)/2]
  lower, upper = np.percentile(forecast_samples, CI_interval, axis=0)

  plt.plot(forecast_steps, forecast_median, lw=2, ls='--', marker='o', color='orange',
           label=str(CI) + '% Forecast Interval', alpha=0.7)
  plt.fill_between(forecast_steps,
                   lower,
                   upper, color='orange', alpha=0.2)

  plt.xlim([0, num_steps+num_steps_forecast])
  ymin, ymax = min(np.min(forecast_samples), np.min(data)), max(np.max(forecast_samples), np.max(data))
  yrange = ymax-ymin
  plt.title("{}".format('Observed time series with ' + str(num_steps_forecast) + ' Day Forecast'))
  plt.xlabel('Day')
  plt.ylabel('Daily Sample Size')
  plt.legend()
plot_forecast_helper(observed_counts, forecast_samples, CI=80)

png

VI推論

フルタイムのシリーズを推測する際に変分推論は、私たちの概算数量(標準STSモデルのように、時系列の単なるパラメータではなく)のように、問題となることがあります。各タイムステップは隣接するタイムステップと相関しており、不確実性を過小評価する可能性があるため、変数に独立した事後確率があるという標準的な仮定はまったく間違っています。このため、HMCは、完全な時系列にわたるおおよその推論に適している場合があります。ただし、VIはかなり高速である可能性があり、モデルのプロトタイピングや、そのパフォーマンスが「十分に良好」であることが経験的に示される場合に役立つ可能性があります。

モデルをVIに適合させるために、サロゲート後方を構築して最適化するだけです。

surrogate_posterior = tfp.experimental.vi.build_factored_surrogate_posterior(
    event_shape=pinned_model.event_shape,
    bijector=constraining_bijector)
# Allow external control of optimization to reduce test runtimes.
num_variational_steps = 1000 # @param { isTemplate: true}
num_variational_steps = int(num_variational_steps)

t0 = time.time()
losses = tfp.vi.fit_surrogate_posterior(pinned_model.unnormalized_log_prob,
                                        surrogate_posterior,
                                        optimizer=tf.optimizers.Adam(0.1),
                                        num_steps=num_variational_steps)
t1 = time.time()
print("Inference ran in {:.2f}s.".format(t1-t0))
Inference ran in 11.37s.
plt.plot(losses)
plt.title("Variational loss")
_ = plt.xlabel("Steps")

png

posterior_samples = surrogate_posterior.sample(50)
param_samples = posterior_samples[:-1]
unconstrained_rate_samples = posterior_samples[-1][..., 0]
rate_samples = positive_bijector.forward(unconstrained_rate_samples)

plt.figure(figsize=(10, 4))
mean_lower, mean_upper = np.percentile(rate_samples, [10, 90], axis=0)
pred_lower, pred_upper = np.percentile(
    np.random.poisson(rate_samples), [10, 90], axis=0)

_ = plt.plot(observed_counts, color='blue', ls='--', marker='o',
             label='observed', alpha=0.7)
_ = plt.plot(np.mean(rate_samples, axis=0), label='rate', color='green',
             ls='dashed', lw=2, alpha=0.7)
_ = plt.fill_between(
    np.arange(0, 30), mean_lower, mean_upper, color='green', alpha=0.2)
_ = plt.fill_between(np.arange(0, 30), pred_lower, pred_upper, color='grey',
    label='counts', alpha=0.2)
plt.xlabel('Day')
plt.ylabel('Daily Sample Size')
plt.title('Posterior Mean')
plt.legend()
<matplotlib.legend.Legend at 0x7f93ff4735c0>

png

forecast_samples, rate_samples = sample_forecasted_counts(
   sts_model,
   posterior_latent_rates=unconstrained_rate_samples,
   posterior_params=param_samples,
   # Days to forecast:
   num_steps_forecast=30,
   num_sampled_forecasts=100)
forecast_samples = np.squeeze(forecast_samples)
plot_forecast_helper(observed_counts, forecast_samples, CI=80)

png