今日のローカルTensorFlowEverywhereイベントの出欠確認!
このページは Cloud Translation API によって翻訳されました。
Switch to English

文書化されていない実質的な感染は、新規コロナウイルス(SARS-CoV2)の迅速な普及を促進します

TensorFlow.orgで表示 GoogleColabで実行 GitHubでソースを表示ノートブックをダウンロード

これは、Liらによる2020年3月16日の同名の論文のTensorFlowProbabilityポートです。 TensorFlow Probabilityプラットフォームで元の作成者の方法と結果を忠実に再現し、現代の疫学モデリングの設定におけるTFPの機能の一部を紹介します。 TensorFlowに移植すると、元のMatlabコードに比べて約10倍のスピードアップが得られます。また、TensorFlow Probabilityはベクトル化されたバッチ計算を広くサポートしているため、数百の独立したレプリケーションに拡張できます。

原紙

Ruiyun Li、Sen Pei、Bin Chen、Yimeng Song、Tao Zhang、Wan Yang、JeffreyShaman。文書化されていない実質的な感染は、新規コロナウイルス(SARS-CoV2)の迅速な拡散を促進します。 (2020)、doi: https//doi.org/10.1126/science.abb3221

要約: 「文書化されていない新規コロナウイルス(SARS-CoV2)感染の有病率と伝染性の推定は、この病気の全体的な有病率とパンデミックの可能性を理解するために重要です。ここでは、モビリティデータと併せて、中国内で報告された感染の観察結果を使用します。ネットワーク化された動的メタポピュレーションモデルとベイジアン推論により、SARS-CoV2に関連する重大な疫学的特徴を推測します。これには、文書化されていない感染の割合とその伝染性が含まれます。すべての感染の86%が文書化されていないと推定されます(95%CI:[82%–90%] )2020年1月23日以前の旅行制限。1人あたり、文書化されていない感染症の感染率は文書化された感染症の55%([46%–62%])でしたが、その数が多いため、文書化されていない感染症が79の感染源でした。文書化された症例の割合。これらの調査結果は、SARS-CoV2の急速な地理的広がりを説明し、このウイルスの封じ込めが特に困難であることを示しています。」

コードとデータへのGithubリンク

概要概要

このモデルは、「感受性」、「曝露」(感染しているがまだ感染していない)、「文書化されていない感染性」、および「最終的に文書化された感染性」のコンパートメントを備えコンパートメント疾患モデルです。 2つの注目すべき機能があります。375の中国の都市ごとに別々のコンパートメントがあり、人々が1つの都市から別の都市にどのように移動するかを想定しています。感染の報告が遅れるため、$ t $日に「最終的に文書化された感染性」になったケースは、確率的な後日まで、観察されたケースに表示されません。

このモデルは、文書化されていないケースが軽度で文書化されていないため、他のケースへの感染率が低いことを前提としています。元の論文で関心のある主なパラメータは、既存の感染の程度と、文書化されていない感染が病気の蔓延に及ぼす影響の両方を推定するために、文書化されていないケースの割合です。

このコラボは、ボトムアップスタイルのコードウォークスルーとして構成されています。順番に、

  • データを取り込んで簡単に調べ、
  • モデルの状態空間とダイナミクスを定義し、
  • Li et alに従って、モデルで推論を行うための一連の関数を構築し、
  • それらを呼び出して、結果を調べます。ネタバレ:紙と同じように出てきます。

インストールとPythonのインポート

pip3 install -q tf-nightly tfp-nightly
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import io
import requests
import time
import zipfile

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.internal import samplers

tfd = tfp.distributions
tfes = tfp.experimental.sequential

データのインポート

githubからデータをインポートして、その一部を調べてみましょう。

r = requests.get('https://raw.githubusercontent.com/SenPei-CU/COVID-19/master/Data.zip')
z = zipfile.ZipFile(io.BytesIO(r.content))
z.extractall('/tmp/')
raw_incidence = pd.read_csv('/tmp/data/Incidence.csv')
raw_mobility = pd.read_csv('/tmp/data/Mobility.csv')
raw_population = pd.read_csv('/tmp/data/pop.csv')

以下に、1日あたりの生の発生数を示します。最初の14日間(1月10日から1月23日)は、23日に旅行制限が設けられたため、最も関心があります。このホワイトペーパーでは、1月10〜23日と1月23日以降を別々に、さまざまなパラメータでモデル化することでこれを扱います。複製を以前の期間に制限します。

raw_incidence.drop('Date', axis=1)  # The 'Date' column is all 1/18/21
# Luckily the days are in order, starting on January 10th, 2020.

武漢の発生数を健全性チェックしましょう。

plt.plot(raw_incidence.Wuhan, '.-')
plt.title('Wuhan incidence counts over 1/10/20 - 02/08/20')
plt.show()

png

ここまでは順調ですね。これで、初期人口がカウントされます。

raw_population

また、どのエントリが武漢であるかを確認して記録しましょう。

raw_population['City'][169]
'Wuhan'
WUHAN_IDX = 169

そしてここに、異なる都市間のモビリティマトリックスが表示されます。これは、最初の14日間に異なる都市間を移動する人々の数のプロキシです。これは、2018年の旧正月シーズンにTencentから提供されたGPSレコードから導き出されたものです。 Li et alは、2020年シーズン中のモビリティを、未知の(推論の対象となる)定数係数$ \ theta $にこれを掛けたものとしてモデル化します。

raw_mobility

最後に、これらすべてを前処理して、消費できるnumpy配列にしましょう。

# The given populations are only "initial" because of intercity mobility during
# the holiday season.
initial_population = raw_population['Population'].to_numpy().astype(np.float32)

モビリティデータを[L、L、T]型のテンソルに変換します。ここで、Lは位置の数、Tはタイムステップの数です。

daily_mobility_matrices = []
for i in range(1, 15):
  day_mobility = raw_mobility[raw_mobility['Day'] == i]

  # Make a matrix of daily mobilities.
  z = pd.crosstab(
      day_mobility.Origin, 
      day_mobility.Destination, 
      values=day_mobility['Mobility Index'], aggfunc='sum', dropna=False)

  # Include every city, even if there are no rows for some in the raw data on
  # some day.  This uses the sort order of `raw_population`.
  z = z.reindex(index=raw_population['City'], columns=raw_population['City'], 
                fill_value=0)
  # Finally, fill any missing entries with 0. This means no mobility.
  z = z.fillna(0)
  daily_mobility_matrices.append(z.to_numpy())

mobility_matrix_over_time = np.stack(daily_mobility_matrices, axis=-1).astype(
    np.float32)

最後に、観察された感染を取り、[L、T]テーブルを作成します。

# Remove the date parameter and take the first 14 days.
observed_daily_infectious_count = raw_incidence.to_numpy()[:14, 1:]
observed_daily_infectious_count = np.transpose(
    observed_daily_infectious_count).astype(np.float32)

そして、私たちが望む形になったことを再確認してください。念のため、私たちは375の都市と14日間で活動しています。

print('Mobility Matrix over time should have shape (375, 375, 14): {}'.format(
    mobility_matrix_over_time.shape))
print('Observed Infectious should have shape (375, 14): {}'.format(
    observed_daily_infectious_count.shape))
print('Initial population should have shape (375): {}'.format(
    initial_population.shape))
Mobility Matrix over time should have shape (375, 375, 14): (375, 375, 14)
Observed Infectious should have shape (375, 14): (375, 14)
Initial population should have shape (375): (375,)

状態とパラメーターの定義

モデルの定義を始めましょう。私たちが再現しているモデルは、 SEIRモデルの変形です。この場合、次の時変状態があります。

  • $ S $:各都市でこの病気にかかりやすい人の数。
  • $ E $:各都市でこの病気にさらされているがまだ感染していない人の数。生物学的には、これは病気にかかることに相当し、すべての曝露された人々が最終的に感染性になります。
  • $ I ^ u $:各都市で感染性があるが文書化されていない人の数。モデルでは、これは実際には「文書化されない」ことを意味します。
  • $ I ^ r $:感染性があり、そのように文書化されている各都市の人々の数。 Li et alは遅延を報告しているので、$ I ^ r $は実際には、「ケースは将来のある時点で文書化されるほど深刻である」などに対応します。

以下に示すように、アンサンブル調整済みカルマンフィルター(EAKF)を時間的に前方に実行することにより、これらの状態を推測します。 EAKFの状態ベクトルは、これらの量ごとに1つの都市インデックス付きベクトルです。

モデルには、次の推定可能なグローバルな時不変パラメーターがあります。

  • $ \ beta $:文書化された感染者による感染率。
  • $ \ mu $:文書化されていない感染者による相対的な感染率。これは、製品$ \ mu \ beta $を通じて機能します。
  • $ \ theta $:都市間移動係数。これは、モビリティデータの過少報告(および2018年から2020年までの人口増加)を補正する1より大きい係数です。
  • $ Z $:平均潜伏期間(つまり、「露出」状態の時間)。
  • $ \ alpha $:これは(最終的に)文書化されるのに十分なほど深刻な感染の割合です。
  • $ D $:感染の平均期間(つまり、いずれかの「感染」状態の時間)。

状態のEAKFの周りの反復フィルタリングループを使用して、これらのパラメーターの点推定を推測します。

モデルは、推論されていない定数にも依存します。

  • $ M $:都市間モビリティマトリックス。これは時変であり、与えられたと推定されます。都市間の実際の人口移動を与えるために、推定パラメータ$ \ theta $によってスケーリングされることを思い出してください。
  • $ N $:各都市の総人数。初期母集団は与えられたとおりに取得され、母集団の時間変動は移動度数$ \ theta M $から計算されます。

まず、状態とパラメーターを保持するためのデータ構造をいくつか示します。

SEIRComponents = collections.namedtuple(
  typename='SEIRComponents',
  field_names=[
    'susceptible',              # S
    'exposed',                  # E
    'documented_infectious',    # I^r
    'undocumented_infectious',  # I^u
    # This is the count of new cases in the "documented infectious" compartment.
    # We need this because we will introduce a reporting delay, between a person
    # entering I^r and showing up in the observable case count data.
    # This can't be computed from the cumulative `documented_infectious` count,
    # because some portion of that population will move to the 'recovered'
    # state, which we aren't tracking explicitly.
    'daily_new_documented_infectious'])

ModelParams = collections.namedtuple(
    typename='ModelParams',
    field_names=[
      'documented_infectious_tx_rate',             # Beta
      'undocumented_infectious_tx_relative_rate',  # Mu
      'intercity_underreporting_factor',           # Theta
      'average_latency_period',                    # Z
      'fraction_of_documented_infections',         # Alpha
      'average_infection_duration'                 # D
    ]
)

また、パラメーターの値に対するLi etalの境界をコーディングします。

PARAMETER_LOWER_BOUNDS = ModelParams(
    documented_infectious_tx_rate=0.8,
    undocumented_infectious_tx_relative_rate=0.2,
    intercity_underreporting_factor=1.,
    average_latency_period=2.,
    fraction_of_documented_infections=0.02,
    average_infection_duration=2.
)

PARAMETER_UPPER_BOUNDS = ModelParams(
    documented_infectious_tx_rate=1.5,
    undocumented_infectious_tx_relative_rate=1.,
    intercity_underreporting_factor=1.75,
    average_latency_period=5.,
    fraction_of_documented_infections=1.,
    average_infection_duration=5.
)

SEIRダイナミクス

ここでは、パラメーターと状態の関係を定義します。

Li et al(補足資料、eqns 1-5)の時間ダイナミクス方程式は次のとおりです。

$ \ frac {dS_i} {dt} =-\ beta \ frac {S_i I_i ^ r} {N_i}-\ mu \ beta \ frac {S_i I_i ^ u} {N_i} + \ theta \ sum_k \ frac {M_ { ij} S_j} {N_j --I_j ^ r}-+ \ theta \ sum_k \ frac {M_ {ji} S_j} {N_i --I_i ^ r} $

$ \ frac {dE_i} {dt} = \ beta \ frac {S_i I_i ^ r} {N_i} + \ mu \ beta \ frac {S_i I_i ^ u} {N_i}-\ frac {E_i} {Z} + \ theta \ sum_k \ frac {M_ {ij} E_j} {N_j --I_j ^ r}-+ \ theta \ sum_k \ frac {M_ {ji} E_j} {N_i --I_i ^ r} $

$ \ frac {dI ^ r_i} {dt} = \ alpha \ frac {E_i} {Z}-\ frac {I_i ^ r} {D} $

$ \ frac {dI ^ u_i} {dt} =(1- \ alpha)\ frac {E_i} {Z}-\ frac {I_i ^ u} {D} + \ theta \ sum_k \ frac {M_ {ij} I_j ^ u} {N_j --I_j ^ r}-+ \ theta \ sum_k \ frac {M_ {ji} I ^ u_j} {N_i --I_i ^ r} $

$ N_i = N_i + \ theta \ sum_j M_ {ij}-\ theta \ sum_j M_ {ji} $

念のため、$ i $と$ j $の添え字は都市のインデックスを付けます。これらの方程式は、病気の時間発展をモデル化します。

  • より多くの感染につながる感染者との接触;
  • 「曝露された」状態から「感染性」状態の1つへの病気の進行。
  • 「感染性」状態から回復への疾患の進行。これは、モデル化された集団からの除去によってモデル化されます。
  • 曝露された、または文書化されていない感染者を含む、都市間の移動。そして
  • 都市間の移動による毎日の都市人口の時間変化。

Li et alに続いて、最終的に報告されるほど深刻な症例を持つ人々は都市間を移動しないと仮定します。

また、Li et alに従って、これらのダイナミクスを項ごとのポアソンノイズの影響を受けるものとして扱います。つまり、各項は実際にはポアソンのレートであり、サンプルから真の変化が得られます。ポアソンサンプルを(加算するのではなく)減算してもポアソン分布の結果が得られないため、ポアソンノイズは項ごとに異なります。

これらのダイナミクスを古典的な4次のルンゲクッタ積分器に合わせて進化させますが、最初にそれらを計算する関数を定義しましょう(ポアソンノイズのサンプリングを含む)。

def sample_state_deltas(
    state, population, mobility_matrix, params, seed, is_deterministic=False):
  """Computes one-step change in state, including Poisson sampling.

  Note that this is coded to support vectorized evaluation on arbitrary-shape
  batches of states.  This is useful, for example, for running multiple
  independent replicas of this model to compute credible intervals for the
  parameters.  We refer to the arbitrary batch shape with the conventional
  `B` in the parameter documentation below.  This function also, of course,
  supports broadcasting over the batch shape.

  Args:
    state: A `SEIRComponents` tuple with fields Tensors of shape
      B + [num_locations] giving the current disease state.
    population: A Tensor of shape B + [num_locations] giving the current city
      populations.
    mobility_matrix: A Tensor of shape B + [num_locations, num_locations] giving
      the current baseline inter-city mobility.
    params: A `ModelParams` tuple with fields Tensors of shape B giving the
      global parameters for the current EAKF run.
    seed: Initial entropy for pseudo-random number generation.  The Poisson
      sampling is repeatable by supplying the same seed.
    is_deterministic: A `bool` flag to turn off Poisson sampling if desired.

  Returns:
    delta: A `SEIRComponents` tuple with fields Tensors of shape
      B + [num_locations] giving the one-day changes in the state, according
      to equations 1-4 above (including Poisson noise per Li et al).
  """
  undocumented_infectious_fraction = state.undocumented_infectious / population
  documented_infectious_fraction = state.documented_infectious / population

  # Anyone not documented as infectious is considered mobile
  mobile_population = (population - state.documented_infectious)
  def compute_outflow(compartment_population):
    raw_mobility = tf.linalg.matvec(
        mobility_matrix, compartment_population / mobile_population)
    return params.intercity_underreporting_factor * raw_mobility
  def compute_inflow(compartment_population):
    raw_mobility = tf.linalg.matmul(
        mobility_matrix,
        (compartment_population / mobile_population)[..., tf.newaxis],
        transpose_a=True)
    return params.intercity_underreporting_factor * tf.squeeze(
        raw_mobility, axis=-1)

  # Helper for sampling the Poisson-variate terms.
  seeds = samplers.split_seed(seed, n=11)
  if is_deterministic:
    def sample_poisson(rate):
      return rate
  else:
    def sample_poisson(rate):
      return tfd.Poisson(rate=rate).sample(seed=seeds.pop())

  # Below are the various terms called U1-U12 in the paper. We combined the
  # first two, which should be fine; both are poisson so their sum is too, and
  # there's no risk (as there could be in other terms) of going negative.
  susceptible_becoming_exposed = sample_poisson(
      state.susceptible *
      (params.documented_infectious_tx_rate *
       documented_infectious_fraction +
       (params.undocumented_infectious_tx_relative_rate *
        params.documented_infectious_tx_rate) *
       undocumented_infectious_fraction))  # U1 + U2

  susceptible_population_inflow = sample_poisson(
      compute_inflow(state.susceptible))  # U3
  susceptible_population_outflow = sample_poisson(
      compute_outflow(state.susceptible))  # U4

  exposed_becoming_documented_infectious = sample_poisson(
      params.fraction_of_documented_infections *
      state.exposed / params.average_latency_period)  # U5
  exposed_becoming_undocumented_infectious = sample_poisson(
      (1 - params.fraction_of_documented_infections) *
      state.exposed / params.average_latency_period)  # U6

  exposed_population_inflow = sample_poisson(
      compute_inflow(state.exposed))  # U7
  exposed_population_outflow = sample_poisson(
      compute_outflow(state.exposed))  # U8

  documented_infectious_becoming_recovered = sample_poisson(
      state.documented_infectious /
      params.average_infection_duration)  # U9
  undocumented_infectious_becoming_recovered = sample_poisson(
      state.undocumented_infectious /
      params.average_infection_duration)  # U10

  undocumented_infectious_population_inflow = sample_poisson(
      compute_inflow(state.undocumented_infectious))  # U11
  undocumented_infectious_population_outflow = sample_poisson(
      compute_outflow(state.undocumented_infectious))  # U12

  # The final state_deltas
  return SEIRComponents(
      # Equation [1]
      susceptible=(-susceptible_becoming_exposed +
                   susceptible_population_inflow +
                   -susceptible_population_outflow),
      # Equation [2]
      exposed=(susceptible_becoming_exposed +
               -exposed_becoming_documented_infectious +
               -exposed_becoming_undocumented_infectious +
               exposed_population_inflow +
               -exposed_population_outflow),
      # Equation [3]
      documented_infectious=(
          exposed_becoming_documented_infectious +
          -documented_infectious_becoming_recovered),
      # Equation [4]
      undocumented_infectious=(
          exposed_becoming_undocumented_infectious +
          -undocumented_infectious_becoming_recovered +
          undocumented_infectious_population_inflow +
          -undocumented_infectious_population_outflow),
      # New to-be-documented infectious cases, subject to the delayed
      # observation model.
      daily_new_documented_infectious=exposed_becoming_documented_infectious)

これがインテグレーターです。これは完全に標準ですが、PRNGシードをsample_state_deltas関数にsample_state_deltasて、ルンゲクッタ法が要求する各部分ステップで独立したポアソンノイズを取得する点がsample_state_deltasます。

@tf.function(autograph=False)
def rk4_one_step(state, population, mobility_matrix, params, seed):
  """Implement one step of RK4, wrapped around a call to sample_state_deltas."""
  # One seed for each RK sub-step
  seeds = samplers.split_seed(seed, n=4)

  deltas = tf.nest.map_structure(tf.zeros_like, state)
  combined_deltas = tf.nest.map_structure(tf.zeros_like, state)

  for a, b in zip([1., 2, 2, 1.], [6., 3., 3., 6.]):
    next_input = tf.nest.map_structure(
        lambda x, delta, a=a: x + delta / a, state, deltas)
    deltas = sample_state_deltas(
        next_input,
        population,
        mobility_matrix,
        params,
        seed=seeds.pop(), is_deterministic=False)
    combined_deltas = tf.nest.map_structure(
        lambda x, delta, b=b: x + delta / b, combined_deltas, deltas)

  return tf.nest.map_structure(
      lambda s, delta: s + tf.round(delta),
      state, combined_deltas)

初期化

ここでは、論文から初期化スキームを実装します。

Li et alに続いて、私たちの推論スキームは、反復フィルタリング外部ループ(IF-EAKF)に囲まれた、アンサンブル調整カルマンフィルター内部ループになります。計算上、これは3種類の初期化が必要であることを意味します。

  • 内部EAKFの初期状態
  • 最初のEAKFの初期パラメータでもある外部IFの初期パラメータ
  • あるIF反復から次の反復へのパラメーターの更新。これは、最初のEAKF以外の各EAKFの初期パラメーターとして機能します。
def initialize_state(num_particles, num_batches, seed):
  """Initialize the state for a batch of EAKF runs.

  Args:
    num_particles: `int` giving the number of particles for the EAKF.
    num_batches: `int` giving the number of independent EAKF runs to
      initialize in a vectorized batch.
    seed: PRNG entropy.

  Returns:
    state: A `SEIRComponents` tuple with Tensors of shape [num_particles,
      num_batches, num_cities] giving the initial conditions in each
      city, in each filter particle, in each batch member.
  """
  num_cities = mobility_matrix_over_time.shape[-2]
  state_shape = [num_particles, num_batches, num_cities]
  susceptible = initial_population * np.ones(state_shape, dtype=np.float32)
  documented_infectious = np.zeros(state_shape, dtype=np.float32)
  daily_new_documented_infectious = np.zeros(state_shape, dtype=np.float32)

  # Following Li et al, initialize Wuhan with up to 2000 people exposed
  # and another up to 2000 undocumented infectious.
  rng = np.random.RandomState(seed[0] % (2**31 - 1))
  wuhan_exposed = rng.randint(
      0, 2001, [num_particles, num_batches]).astype(np.float32)
  wuhan_undocumented_infectious = rng.randint(
      0, 2001, [num_particles, num_batches]).astype(np.float32)

  # Also following Li et al, initialize cities adjacent to Wuhan with three
  # days' worth of additional exposed and undocumented-infectious cases,
  # as they may have traveled there before the beginning of the modeling
  # period.
  exposed = 3 * mobility_matrix_over_time[
      WUHAN_IDX, :, 0] * wuhan_exposed[
          ..., np.newaxis] / initial_population[WUHAN_IDX]
  undocumented_infectious = 3 * mobility_matrix_over_time[
      WUHAN_IDX, :, 0] * wuhan_undocumented_infectious[
          ..., np.newaxis] / initial_population[WUHAN_IDX]

  exposed[..., WUHAN_IDX] = wuhan_exposed
  undocumented_infectious[..., WUHAN_IDX] = wuhan_undocumented_infectious

  # Following Li et al, we do not remove the inital exposed and infectious
  # persons from the susceptible population.
  return SEIRComponents(
      susceptible=tf.constant(susceptible),
      exposed=tf.constant(exposed),
      documented_infectious=tf.constant(documented_infectious),
      undocumented_infectious=tf.constant(undocumented_infectious),
      daily_new_documented_infectious=tf.constant(daily_new_documented_infectious))

def initialize_params(num_particles, num_batches, seed):
  """Initialize the global parameters for the entire inference run.

  Args:
    num_particles: `int` giving the number of particles for the EAKF.
    num_batches: `int` giving the number of independent EAKF runs to
      initialize in a vectorized batch.
    seed: PRNG entropy.

  Returns:
    params: A `ModelParams` tuple with fields Tensors of shape
      [num_particles, num_batches] giving the global parameters
      to use for the first batch of EAKF runs.
  """
  # We have 6 parameters. We'll initialize with a Sobol sequence,
  # covering the hyper-rectangle defined by our parameter limits.
  halton_sequence = tfp.mcmc.sample_halton_sequence(
      dim=6, num_results=num_particles * num_batches, seed=seed)
  halton_sequence = tf.reshape(
      halton_sequence, [num_particles, num_batches, 6])
  halton_sequences = tf.nest.pack_sequence_as(
      PARAMETER_LOWER_BOUNDS, tf.split(
          halton_sequence, num_or_size_splits=6, axis=-1))
  def interpolate(minval, maxval, h):
    return (maxval - minval) * h + minval
  return tf.nest.map_structure(
      interpolate,
      PARAMETER_LOWER_BOUNDS, PARAMETER_UPPER_BOUNDS, halton_sequences)

def update_params(num_particles, num_batches,
                  prev_params, parameter_variance, seed):
  """Update the global parameters between EAKF runs.

  Args:
    num_particles: `int` giving the number of particles for the EAKF.
    num_batches: `int` giving the number of independent EAKF runs to
      initialize in a vectorized batch.
    prev_params: A `ModelParams` tuple of the parameters used for the previous
      EAKF run.
    parameter_variance: A `ModelParams` tuple specifying how much to drift
      each parameter.
    seed: PRNG entropy.

  Returns:
    params: A `ModelParams` tuple with fields Tensors of shape
      [num_particles, num_batches] giving the global parameters
      to use for the next batch of EAKF runs.
  """
  # Initialize near the previous set of parameters. This is the first step
  # in Iterated Filtering.
  seeds = tf.nest.pack_sequence_as(
      prev_params, samplers.split_seed(seed, n=len(prev_params)))
  return tf.nest.map_structure(
      lambda x, v, seed: x + tf.math.sqrt(v) * tf.random.stateless_normal([
          num_particles, num_batches, 1], seed=seed),
      prev_params, parameter_variance, seeds)

遅延

このモデルの重要な機能の1つは、感染が開始よりも遅く報告されるという事実を明確に考慮していることです。つまり、$ t $日に$ E $コンパートメントから$ I ^ r $コンパートメントに移動した人は、後日まで、報告された観察可能なケース数に表示されない可能性があります。

遅延はガンマ分布であると想定しています。 Li et alに従って、形状に1.85を使用し、レートをパラメーター化して、9日間の平均レポート遅延を生成します。

def raw_reporting_delay_distribution(gamma_shape=1.85, reporting_delay=9.):
  return tfp.distributions.Gamma(
      concentration=gamma_shape, rate=gamma_shape / reporting_delay)

観測は離散的であるため、生の(連続的な)遅延を最も近い日に切り上げます。また、データの範囲が有限であるため、1人の遅延分布は、残りの日数にわたって分類されます。したがって、代わりに多項遅延確率を事前に計算することにより、$ O(I ^ r)$ガンマをサンプリングするよりも効率的に都市ごとの予測観測値を計算できます。

def reporting_delay_probs(num_timesteps, gamma_shape=1.85, reporting_delay=9.):
  gamma_dist = raw_reporting_delay_distribution(gamma_shape, reporting_delay)
  multinomial_probs = [gamma_dist.cdf(1.)]
  for k in range(2, num_timesteps + 1):
    multinomial_probs.append(gamma_dist.cdf(k) - gamma_dist.cdf(k - 1))
  # For samples that are larger than T.
  multinomial_probs.append(gamma_dist.survival_function(num_timesteps))
  multinomial_probs = tf.stack(multinomial_probs)
  return multinomial_probs

これらの遅延を、毎日記録されている新しい感染数に実際に適用するためのコードは次のとおりです。

def delay_reporting(
    daily_new_documented_infectious, num_timesteps, t, multinomial_probs, seed):
  # This is the distribution of observed infectious counts from the current
  # timestep.

  raw_delays = tfd.Multinomial(
      total_count=daily_new_documented_infectious,
      probs=multinomial_probs).sample(seed=seed)

  # The last bucket is used for samples that are out of range of T + 1. Thus
  # they are not going to be observable in this model.
  clipped_delays = raw_delays[..., :-1]

  # We can also remove counts that are such that t + i >= T.
  clipped_delays = clipped_delays[..., :num_timesteps - t]
  # We finally shift everything by t. That means prepending with zeros.
  return tf.concat([
      tf.zeros(
          tf.concat([
              tf.shape(clipped_delays)[:-1], [t]], axis=0),
          dtype=clipped_delays.dtype),
      clipped_delays], axis=-1)

推論

まず、推論のためのいくつかのデータ構造を定義します。

特に、推論を実行しながら状態とパラメーターを一緒にパッケージ化する反復フィルタリングを実行する必要があります。そこで、 ParameterStatePairオブジェクトを定義します。

また、サイド情報をモデルにパッケージ化します。

ParameterStatePair = collections.namedtuple(
    'ParameterStatePair', ['state', 'params'])

# Info that is tracked and mutated but should not have inference performed over.
SideInfo = collections.namedtuple(
    'SideInfo', [
        # Observations at every time step.
        'observations_over_time',
        'initial_population',
        'mobility_matrix_over_time',
        'population',
        # Used for variance of measured observations.
        'actual_reported_cases',
        # Pre-computed buckets for the multinomial distribution.
        'multinomial_probs',
        'seed',
    ])

# Cities can not fall below this fraction of people
MINIMUM_CITY_FRACTION = 0.6

# How much to inflate the covariance by.
INFLATION_FACTOR = 1.1

INFLATE_FN = tfes.inflate_by_scaled_identity_fn(INFLATION_FACTOR)

これは、アンサンブルカルマンフィルター用にパッケージ化された完全な観測モデルです。

興味深い機能は、レポートの遅延です(以前のように計算されます)。アップストリームモデルは、各タイムステップで各都市のdaily_new_documented_infectiousdaily_new_documented_infectiousします。

# We observe the observed infections.
def observation_fn(t, state_params, extra):
  """Generate reported cases.

  Args:
    state_params: A `ParameterStatePair` giving the current parameters
      and state.
    t: Integer giving the current time.
    extra: A `SideInfo` carrying auxiliary information.

  Returns:
    observations: A Tensor of predicted observables, namely new cases
      per city at time `t`.
    extra: Update `SideInfo`.
  """
  # Undo padding introduced in `inference`.
  daily_new_documented_infectious = state_params.state.daily_new_documented_infectious[..., 0]
  # Number of people that we have already committed to become
  # observed infectious over time.
  # shape: batch + [num_particles, num_cities, time]
  observations_over_time = extra.observations_over_time
  num_timesteps = observations_over_time.shape[-1]

  seed, new_seed = samplers.split_seed(extra.seed, salt='reporting delay')

  daily_delayed_counts = delay_reporting(
      daily_new_documented_infectious, num_timesteps, t,
      extra.multinomial_probs, seed)
  observations_over_time = observations_over_time + daily_delayed_counts

  extra = extra._replace(
      observations_over_time=observations_over_time,
      seed=new_seed)

  # Actual predicted new cases, re-padded.
  adjusted_observations = observations_over_time[..., t][..., tf.newaxis]
  # Finally observations have variance that is a function of the true observations:
  return tfd.MultivariateNormalDiag(
      loc=adjusted_observations,
      scale_diag=tf.math.maximum(
          2., extra.actual_reported_cases[..., t][..., tf.newaxis] / 2.)), extra

ここでは、遷移ダイナミクスを定義します。セマンティック作業はすでに完了しています。ここでは、EAKFフレームワーク用にパッケージ化し、Li et alに従って、都市の人口が小さくなりすぎないようにクリップします。

def transition_fn(t, state_params, extra):
  """SEIR dynamics.

  Args:
    state_params: A `ParameterStatePair` giving the current parameters
      and state.
    t: Integer giving the current time.
    extra: A `SideInfo` carrying auxiliary information.

  Returns:
    state_params: A `ParameterStatePair` predicted for the next time step.
    extra: Updated `SideInfo`.
  """
  mobility_t = extra.mobility_matrix_over_time[..., t]
  new_seed, rk4_seed = samplers.split_seed(extra.seed, salt='Transition')
  new_state = rk4_one_step(
      state_params.state,
      extra.population,
      mobility_t,
      state_params.params,
      seed=rk4_seed)

  # Make sure population doesn't go below MINIMUM_CITY_FRACTION.
  new_population = (
      extra.population + state_params.params.intercity_underreporting_factor * (
          # Inflow
          tf.reduce_sum(mobility_t, axis=-2) - 
          # Outflow
          tf.reduce_sum(mobility_t, axis=-1)))
  new_population = tf.where(
      new_population < MINIMUM_CITY_FRACTION * extra.initial_population,
      extra.initial_population * MINIMUM_CITY_FRACTION,
      new_population)

  extra = extra._replace(population=new_population, seed=new_seed)

  # The Ensemble Kalman Filter code expects the transition function to return a distribution.
  # As the dynamics and noise are encapsulated above, we construct a `JointDistribution` that when
  # sampled, returns the values above.

  new_state = tfd.JointDistributionNamed(
      model=tf.nest.map_structure(lambda x: tfd.VectorDeterministic(x), new_state))
  params = tfd.JointDistributionNamed(
      model=tf.nest.map_structure(lambda x: tfd.VectorDeterministic(x), state_params.params))

  state_params = tfd.JointDistributionNamed(
      model=ParameterStatePair(state=new_state, params=params))

  return state_params, extra

最後に、推論方法を定義します。これは2つのループであり、外側のループは反復フィルタリングであり、内側のループはアンサンブル調整カルマンフィルターです。

# Use tf.function to speed up EAKF prediction and updates.
ensemble_kalman_filter_predict = tf.function(
    tfes.ensemble_kalman_filter_predict, autograph=False)
ensemble_adjustment_kalman_filter_update = tf.function(
    tfes.ensemble_adjustment_kalman_filter_update, autograph=False)

def inference(
    num_ensembles,
    num_batches,
    num_iterations,
    actual_reported_cases,
    mobility_matrix_over_time,
    seed=None,
    # This is how much to reduce the variance by in every iterative
    # filtering step.
    variance_shrinkage_factor=0.9,
    # Days before infection is reported.
    reporting_delay=9.,
    # Shape parameter of Gamma distribution.
    gamma_shape_parameter=1.85):
  """Inference for the Shaman, et al. model.

  Args:
    num_ensembles: Number of particles to use for EAKF.
    num_batches: Number of batches of IF-EAKF to run.
    num_iterations: Number of iterations to run iterative filtering.
    actual_reported_cases: `Tensor` of shape `[L, T]` where `L` is the number
      of cities, and `T` is the timesteps.
    mobility_matrix_over_time: `Tensor` of shape `[L, L, T]` which specifies the
      mobility between locations over time.
    variance_shrinkage_factor: Python `float`. How much to reduce the
      variance each iteration of iterated filtering.
    reporting_delay: Python `float`. How many days before the infection
      is reported.
    gamma_shape_parameter: Python `float`. Shape parameter of Gamma distribution
      of reporting delays.

  Returns:
    result: A `ModelParams` with fields Tensors of shape [num_batches],
      containing the inferred parameters at the final iteration.
  """
  print('Starting inference.')
  num_timesteps = actual_reported_cases.shape[-1]
  params_per_iter = []

  multinomial_probs = reporting_delay_probs(
      num_timesteps, gamma_shape_parameter, reporting_delay)

  seed = samplers.sanitize_seed(seed, salt='Inference')

  for i in range(num_iterations):
    start_if_time = time.time()
    seeds = samplers.split_seed(seed, n=4, salt='Initialize')
    if params_per_iter:
      parameter_variance = tf.nest.map_structure(
          lambda minval, maxval: variance_shrinkage_factor ** (
              2 * i) * (maxval - minval) ** 2 / 4.,
          PARAMETER_LOWER_BOUNDS, PARAMETER_UPPER_BOUNDS)
      params_t = update_params(
          num_ensembles,
          num_batches,
          prev_params=params_per_iter[-1],
          parameter_variance=parameter_variance,
          seed=seeds.pop())
    else:
      params_t = initialize_params(num_ensembles, num_batches, seed=seeds.pop())

    state_t = initialize_state(num_ensembles, num_batches, seed=seeds.pop())
    population_t = sum(x for x in state_t)
    observations_over_time = tf.zeros(
        [num_ensembles,
         num_batches,
         actual_reported_cases.shape[0], num_timesteps])

    extra = SideInfo(
        observations_over_time=observations_over_time,
        initial_population=tf.identity(population_t),
        mobility_matrix_over_time=mobility_matrix_over_time,
        population=population_t,
        multinomial_probs=multinomial_probs,
        actual_reported_cases=actual_reported_cases,
        seed=seeds.pop())

    # Clip states
    state_t = clip_state(state_t, population_t)
    params_t = clip_params(params_t, seed=seeds.pop())

    # Accrue the parameter over time. We'll be averaging that
    # and using that as our MLE estimate.
    params_over_time = tf.nest.map_structure(
        lambda x: tf.identity(x), params_t)

    state_params = ParameterStatePair(state=state_t, params=params_t)

    eakf_state = tfes.EnsembleKalmanFilterState(
        step=tf.constant(0), particles=state_params, extra=extra)

    for j in range(num_timesteps):
      seeds = samplers.split_seed(eakf_state.extra.seed, n=3)

      extra = extra._replace(seed=seeds.pop())

      # Predict step.

      # Inflate and clip.
      new_particles = INFLATE_FN(eakf_state.particles)
      state_t = clip_state(new_particles.state, eakf_state.extra.population)
      params_t = clip_params(new_particles.params, seed=seeds.pop())
      eakf_state = eakf_state._replace(
          particles=ParameterStatePair(params=params_t, state=state_t))

      eakf_predict_state = ensemble_kalman_filter_predict(eakf_state, transition_fn)

      # Clip the state and particles.
      state_params = eakf_predict_state.particles
      state_t = clip_state(
          state_params.state, eakf_predict_state.extra.population)
      state_params = ParameterStatePair(state=state_t, params=state_params.params)

      # We preprocess the state and parameters by affixing a 1 dimension. This is because for
      # inference, we treat each city as independent. We could also introduce localization by
      # considering cities that are adjacent.
      state_params = tf.nest.map_structure(lambda x: x[..., tf.newaxis], state_params)
      eakf_predict_state = eakf_predict_state._replace(particles=state_params)

      # Update step.

      eakf_update_state = ensemble_adjustment_kalman_filter_update(
          eakf_predict_state,
          actual_reported_cases[..., j][..., tf.newaxis],
          observation_fn)

      state_params = tf.nest.map_structure(
          lambda x: x[..., 0], eakf_update_state.particles)

      # Clip to ensure parameters / state are well constrained.
      state_t = clip_state(
          state_params.state, eakf_update_state.extra.population)

      # Finally for the parameters, we should reduce over all updates. We get
      # an extra dimension back so let's do that.
      params_t = tf.nest.map_structure(
          lambda x, y: x + tf.reduce_sum(y[..., tf.newaxis] - x, axis=-2, keepdims=True),
          eakf_predict_state.particles.params, state_params.params)
      params_t = clip_params(params_t, seed=seeds.pop())
      params_t = tf.nest.map_structure(lambda x: x[..., 0], params_t)

      state_params = ParameterStatePair(state=state_t, params=params_t)
      eakf_state = eakf_update_state
      eakf_state = eakf_state._replace(particles=state_params)

      # Flatten and collect the inferred parameter at time step t.
      params_over_time = tf.nest.map_structure(
          lambda s, x: tf.concat([s, x], axis=-1), params_over_time, params_t)

    est_params = tf.nest.map_structure(
        # Take the average over the Ensemble and over time.
        lambda x: tf.math.reduce_mean(x, axis=[0, -1])[..., tf.newaxis],
        params_over_time)
    params_per_iter.append(est_params)
    print('Iterated Filtering {} / {} Ran in: {:.2f} seconds'.format(
        i, num_iterations, time.time() - start_if_time))

  return tf.nest.map_structure(
      lambda x: tf.squeeze(x, axis=-1), params_per_iter[-1])

最終的な詳細:パラメーターと状態のクリッピングは、パラメーターが範囲内にあり、負でないことを確認することで構成されます。

def clip_state(state, population):
  """Clip state to sensible values."""
  state = tf.nest.map_structure(
      lambda x: tf.where(x < 0, 0., x), state)

  # If S > population, then adjust as well.
  susceptible = tf.where(state.susceptible > population, population, state.susceptible)
  return SEIRComponents(
      susceptible=susceptible,
      exposed=state.exposed,
      documented_infectious=state.documented_infectious,
      undocumented_infectious=state.undocumented_infectious,
      daily_new_documented_infectious=state.daily_new_documented_infectious)

def clip_params(params, seed):
  """Clip parameters to bounds."""
  def _clip(p, minval, maxval):
    return tf.where(
        p < minval,
        minval * (1. + 0.1 * tf.random.stateless_uniform(p.shape, seed=seed)),
        tf.where(p > maxval,
                 maxval * (1. - 0.1 * tf.random.stateless_uniform(
                     p.shape, seed=seed)), p))
  params = tf.nest.map_structure(
      _clip, params, PARAMETER_LOWER_BOUNDS, PARAMETER_UPPER_BOUNDS)

  return params

すべて一緒に実行する

# Let's sample the parameters.
#
# NOTE: Li et al. run inference 1000 times, which would take a few hours.
# Here we run inference 30 times (in a single, vectorized batch).
best_parameters = inference(
    num_ensembles=300,
    num_batches=30,
    num_iterations=10,
    actual_reported_cases=observed_daily_infectious_count,
    mobility_matrix_over_time=mobility_matrix_over_time)
Starting inference.
Iterated Filtering 0 / 10 Ran in: 26.65 seconds
Iterated Filtering 1 / 10 Ran in: 28.69 seconds
Iterated Filtering 2 / 10 Ran in: 28.06 seconds
Iterated Filtering 3 / 10 Ran in: 28.48 seconds
Iterated Filtering 4 / 10 Ran in: 28.57 seconds
Iterated Filtering 5 / 10 Ran in: 28.35 seconds
Iterated Filtering 6 / 10 Ran in: 28.35 seconds
Iterated Filtering 7 / 10 Ran in: 28.19 seconds
Iterated Filtering 8 / 10 Ran in: 28.58 seconds
Iterated Filtering 9 / 10 Ran in: 28.23 seconds

私たちの推論の結果。すべてのグローバルnum_batches最尤値をプロットして、 num_batches独立した推論の実行全体での変動を示します。これは、補足資料の表S1に対応しています。

fig, axs = plt.subplots(2, 3)
axs[0, 0].boxplot(best_parameters.documented_infectious_tx_rate,
                  whis=(2.5,97.5), sym='')
axs[0, 0].set_title(r'$\beta$')

axs[0, 1].boxplot(best_parameters.undocumented_infectious_tx_relative_rate,
                  whis=(2.5,97.5), sym='')
axs[0, 1].set_title(r'$\mu$')

axs[0, 2].boxplot(best_parameters.intercity_underreporting_factor,
                  whis=(2.5,97.5), sym='')
axs[0, 2].set_title(r'$\theta$')

axs[1, 0].boxplot(best_parameters.average_latency_period,
                  whis=(2.5,97.5), sym='')
axs[1, 0].set_title(r'$Z$')

axs[1, 1].boxplot(best_parameters.fraction_of_documented_infections,
                  whis=(2.5,97.5), sym='')
axs[1, 1].set_title(r'$\alpha$')

axs[1, 2].boxplot(best_parameters.average_infection_duration,
                  whis=(2.5,97.5), sym='')
axs[1, 2].set_title(r'$D$')
plt.tight_layout()

png