Bayes Anahtar Noktası Analizi

TensorFlow.org'da görüntüleyin Google Colab'da çalıştırın Kaynağı GitHub'da görüntüleyin Not defterini indir

Bu defter reimplements ve gelen Bayes “Change noktası analizi” örneğini uzanır pymc3 belgelerinde .

Önkoşullar

import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()
import tensorflow_probability as tfp
tfd = tfp.distributions
tfb = tfp.bijectors
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = (15,8)
%config InlineBackend.figure_format = 'retina'
import numpy as np
import pandas as pd

veri kümesi

Veri kümesi dan burada . Not Bu örneğin başka bir versiyonu vardır yüzen , ancak veri “eksik” olan - bu durumda eksik değerler gerekiyordu. (Aksi takdirde, olabilirlik fonksiyonu tanımsız olacağından modeliniz ilk parametrelerini asla terk etmeyecektir.)

disaster_data = np.array([ 4, 5, 4, 0, 1, 4, 3, 4, 0, 6, 3, 3, 4, 0, 2, 6,
                           3, 3, 5, 4, 5, 3, 1, 4, 4, 1, 5, 5, 3, 4, 2, 5,
                           2, 2, 3, 4, 2, 1, 3, 2, 2, 1, 1, 1, 1, 3, 0, 0,
                           1, 0, 1, 1, 0, 0, 3, 1, 0, 3, 2, 2, 0, 1, 1, 1,
                           0, 1, 0, 1, 0, 0, 0, 2, 1, 0, 0, 0, 1, 1, 0, 2,
                           3, 3, 1, 1, 2, 1, 1, 1, 1, 2, 4, 2, 0, 0, 1, 4,
                           0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1])
years = np.arange(1851, 1962)
plt.plot(years, disaster_data, 'o', markersize=8);
plt.ylabel('Disaster count')
plt.xlabel('Year')
plt.title('Mining disaster data set')
plt.show()

png

Olasılık Modeli

Model, bir “geçiş noktası” (örneğin, güvenlik düzenlemelerinin değiştiği bir yıl) ve bu geçiş noktasından önce ve sonra sabit (ancak potansiyel olarak farklı) oranlarla Poisson tarafından dağıtılan afet oranını varsayar.

Gerçek afet sayısı sabittir (gözlemlenir); bu modelin herhangi bir örneğinin hem geçiş noktasını hem de “erken” ve “geç” felaket oranlarını belirtmesi gerekecektir.

Orijinal modeli pymc3 dokümantasyon Örneğin :

\[ \begin{align*} (D_t|s,e,l)&\sim \text{Poisson}(r_t), \\ & \,\quad\text{with}\; r_t = \begin{cases}e & \text{if}\; t < s\\l &\text{if}\; t \ge s\end{cases} \\ s&\sim\text{Discrete Uniform}(t_l,\,t_h) \\ e&\sim\text{Exponential}(r_e)\\ l&\sim\text{Exponential}(r_l) \end{align*} \]

Ancak, ortalama afet oranı \(r_t\) switchpoint bir devamsızlığı vardır \(s\)değil türevlenebilir yapar. Bu nedenle, Hamilton Monte Carlo (HMC) algoritması için bir gradyan sinyali sağlar - fakat \(s\) önce sürekli olan, rasgele bir yürüyüş HMC en yedek bu örnekte, yüksek olasılık alanlarını bulmak için iyi bir yeterlidir.

İkinci bir model olarak, bir kullanarak, orijinal modeli değiştirme sigmoid “anahtar” geçiş türevlenebilir yapmak ve switchpoint için sürekli eşit bir dağılım kullanımı e ve l arasında \(s\). (Ortalama oranda bir “değişim” muhtemelen birkaç yıla yayılacağından, bu modelin gerçeğe daha yakın olduğu iddia edilebilir.) Yeni model şu şekildedir:

\[ \begin{align*} (D_t|s,e,l)&\sim\text{Poisson}(r_t), \\ & \,\quad \text{with}\; r_t = e + \frac{1}{1+\exp(s-t)}(l-e) \\ s&\sim\text{Uniform}(t_l,\,t_h) \\ e&\sim\text{Exponential}(r_e)\\ l&\sim\text{Exponential}(r_l) \end{align*} \]

Daha fazla bilgi yokluğunda biz varsayalım \(r_e = r_l = 1\) priors parametreler olarak. Her iki modeli de çalıştıracağız ve çıkarım sonuçlarını karşılaştıracağız.

def disaster_count_model(disaster_rate_fn):
  disaster_count = tfd.JointDistributionNamed(dict(
    e=tfd.Exponential(rate=1.),
    l=tfd.Exponential(rate=1.),
    s=tfd.Uniform(0., high=len(years)),
    d_t=lambda s, l, e: tfd.Independent(
        tfd.Poisson(rate=disaster_rate_fn(np.arange(len(years)), s, l, e)),
        reinterpreted_batch_ndims=1)
  ))
  return disaster_count

def disaster_rate_switch(ys, s, l, e):
  return tf.where(ys < s, e, l)

def disaster_rate_sigmoid(ys, s, l, e):
  return e + tf.sigmoid(ys - s) * (l - e)

model_switch = disaster_count_model(disaster_rate_switch)
model_sigmoid = disaster_count_model(disaster_rate_sigmoid)

Yukarıdaki kod, modeli JointDistributionSequential dağıtımları aracılığıyla tanımlar. disaster_rate işlevler bir dizi olarak adlandırılır [0, ..., len(years)-1] bir vektör üretmek için len(years) rastgele değişkenler - yıllar önce switchpoint olan early_disaster_rate sonra olanlar late_disaster_rate (modulo sigmoid geçiş).

Hedef log prob fonksiyonunun aklı başında olup olmadığının bir akıl sağlığı kontrolü:

def target_log_prob_fn(model, s, e, l):
  return model.log_prob(s=s, e=e, l=l, d_t=disaster_data)

models = [model_switch, model_sigmoid]
print([target_log_prob_fn(m, 40., 3., .9).numpy() for m in models])  # Somewhat likely result
print([target_log_prob_fn(m, 60., 1., 5.).numpy() for m in models])  # Rather unlikely result
print([target_log_prob_fn(m, -10., 1., 1.).numpy() for m in models]) # Impossible result
[-176.94559, -176.28717]
[-371.3125, -366.8816]
[-inf, -inf]

HMC, Bayes çıkarımı yapacak

Gerekli sonuçların ve devreye alma adımlarının sayısını tanımlarız; Kod çok örnek alındı tfp.mcmc.HamiltonianMonteCarlo belgelenmesi . Uyarlanabilir bir adım boyutu kullanır (aksi halde sonuç, seçilen adım boyutu değerine çok duyarlıdır). Zincirin ilk durumu olarak bir değerini kullanırız.

Yine de bu tam hikaye değil. Yukarıdaki model tanımına geri dönerseniz, bazı olasılık dağılımlarının tüm gerçek sayı doğrusunda iyi tanımlanmadığını fark edeceksiniz. Dolayısıyla HMC bir ile HMC çekirdeği sararak inceleyeceğiz bu alanı sınırlamak TransformedTransitionKernel belirtir ileri bijectors olasılık dağılımı (aşağıdaki kodda yorumlara bakınız) tanımlanır o etki üzerine gerçek sayılar dönüştürmek için söyledi.

num_results = 10000
num_burnin_steps = 3000

@tf.function(autograph=False, jit_compile=True)
def make_chain(target_log_prob_fn):
   kernel = tfp.mcmc.TransformedTransitionKernel(
       inner_kernel=tfp.mcmc.HamiltonianMonteCarlo(
          target_log_prob_fn=target_log_prob_fn,
          step_size=0.05,
          num_leapfrog_steps=3),
       bijector=[
          # The switchpoint is constrained between zero and len(years).
          # Hence we supply a bijector that maps the real numbers (in a
          # differentiable way) to the interval (0;len(yers))
          tfb.Sigmoid(low=0., high=tf.cast(len(years), dtype=tf.float32)),
          # Early and late disaster rate: The exponential distribution is
          # defined on the positive real numbers
          tfb.Softplus(),
          tfb.Softplus(),
      ])
   kernel = tfp.mcmc.SimpleStepSizeAdaptation(
        inner_kernel=kernel,
        num_adaptation_steps=int(0.8*num_burnin_steps))

   states = tfp.mcmc.sample_chain(
      num_results=num_results,
      num_burnin_steps=num_burnin_steps,
      current_state=[
          # The three latent variables
          tf.ones([], name='init_switchpoint'),
          tf.ones([], name='init_early_disaster_rate'),
          tf.ones([], name='init_late_disaster_rate'),
      ],
      trace_fn=None,
      kernel=kernel)
   return states

switch_samples = [s.numpy() for s in make_chain(
    lambda *args: target_log_prob_fn(model_switch, *args))]
sigmoid_samples = [s.numpy() for s in make_chain(
    lambda *args: target_log_prob_fn(model_sigmoid, *args))]

switchpoint, early_disaster_rate, late_disaster_rate = zip(
    switch_samples, sigmoid_samples)

Her iki modeli de paralel olarak çalıştırın:

Sonucu görselleştirin

Sonucu, erken ve geç afet oranı ve geçiş noktası için arka dağılım örneklerinin histogramları olarak görselleştiriyoruz. Histogramlar, kesikli çizgiler olarak %95 ile güvenilir aralık sınırlarının yanı sıra numune medyanını temsil eden düz bir çizgiyle kaplanmıştır.

def _desc(v):
  return '(median: {}; 95%ile CI: $[{}, {}]$)'.format(
      *np.round(np.percentile(v, [50, 2.5, 97.5]), 2))

for t, v in [
    ('Early disaster rate ($e$) posterior samples', early_disaster_rate),
    ('Late disaster rate ($l$) posterior samples', late_disaster_rate),
    ('Switch point ($s$) posterior samples', years[0] + switchpoint),
]:
  fig, ax = plt.subplots(nrows=1, ncols=2, sharex=True)
  for (m, i) in (('Switch', 0), ('Sigmoid', 1)):
    a = ax[i]
    a.hist(v[i], bins=50)
    a.axvline(x=np.percentile(v[i], 50), color='k')
    a.axvline(x=np.percentile(v[i], 2.5), color='k', ls='dashed', alpha=.5)
    a.axvline(x=np.percentile(v[i], 97.5), color='k', ls='dashed', alpha=.5)
    a.set_title(m + ' model ' + _desc(v[i]))
  fig.suptitle(t)
  plt.show()

png

png

png