
Modeling COVID-19 spread in Europe and the effect of interventions


To slow the spread of COVID-19 in early 2020, European countries adopted non-pharmaceutical interventions such as closures of non-essential businesses, isolation of individual cases, travel bans, and other measures to encourage social distancing. The Imperial College COVID-19 Response Team analyzed the effectiveness of these measures in their paper "Estimating the number of infections and the impact of non-pharmaceutical interventions on COVID-19 in 11 European countries", using a Bayesian hierarchical model combined with a mechanistic epidemiological model.

This Colab contains a TensorFlow Probability (TFP) implementation of that analysis, organized as follows:

  • "Model setup" defines the epidemiological model for disease transmission and resulting deaths, the Bayesian prior distribution over model parameters, and the distribution of the number of deaths conditional on parameter values.
  • "Data preprocessing" loads data on the timing and type of interventions in each country, counts of deaths over time, and estimated fatality rates for those infected.
  • "Model inference" builds the Bayesian hierarchical model and runs Hamiltonian Monte Carlo (HMC) to sample from the posterior distribution over parameters.
  • "Results" shows posterior predictive distributions for quantities of interest, such as forecasted deaths and counterfactual deaths in the absence of interventions.

The paper found evidence that countries had succeeded in reducing the number of new infections transmitted by each infected person ($R_t$), but that the credible intervals contained $R_t = 1$ (the value above which the epidemic continues to spread), so it was too early to draw strong conclusions about the effectiveness of the interventions. The Stan code for the paper is in the authors' GitHub repository, and this Colab reproduces Version 2.

pip3 install -q git+https://github.com/arviz-devs/arviz.git
pip3 install -q tf-nightly tfp-nightly
 
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
from pprint import pprint

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

%config InlineBackend.figure_format = 'retina'

import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.internal import prefer_static as ps

tf.enable_v2_behavior()

# Globally Enable XLA.
# tf.config.optimizer.set_jit(True)

try:
  physical_devices = tf.config.list_physical_devices('GPU')
  tf.config.experimental.set_memory_growth(physical_devices[0], True)
except:
  # Invalid device or cannot modify virtual devices once initialized.
  pass

tfb = tfp.bijectors
tfd = tfp.distributions

DTYPE = np.float32
 

1 Model setup

1.1 Mechanistic model of infections and deaths

The infection model simulates the number of infections in each country over time. The input data are the timing and type of interventions, the population size, and the initial cases. The parameters control the effectiveness of interventions and the rate of disease transmission. The model for the expected number of deaths applies a fatality rate to the predicted infections.

The infection model performs a convolution of previous daily infections with the serial interval distribution (the distribution over the number of days between becoming infected and infecting someone else). At each time step, the number of new infections on day $t$, $n_t$, is computed as

\begin{equation} \sum_{i=0}^{t-1} n_i \mu_t \, \text{P}(\text{caught the virus from someone infected on day } i \mid \text{newly infected on day } t) \end{equation}

where $\mu_t = R_t$ and the conditional probabilities are stored in conv_serial_interval, defined below.

The model for expected deaths performs a convolution of daily infections with the distribution over the number of days between infection and death. That is, the expected number of deaths on day $t$ is computed as

\begin{equation} \sum_{i=0}^{t-1} n_i \, \text{P}(\text{death on day } t \mid \text{infection on day } i) \end{equation}

where the conditional probabilities are stored in conv_fatality_rate, defined below.
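As a concrete illustration, the two convolutions above can be sketched in NumPy. The pmfs, $R_t$, and IFR below are hypothetical toy placeholders, not the fitted values used later in the notebook:

```python
import numpy as np

# Hypothetical toy inputs, not the notebook's fitted values.
serial_interval_pmf = np.array([0.4, 0.3, 0.2, 0.1])      # P(gap = 1..4 days)
fatality_delay_pmf = np.array([0.0, 0.1, 0.3, 0.4, 0.2])  # P(death d days after infection), d = 1..5
ifr = 0.01        # probability of death given infection
r_t = 1.5         # reproduction number, held constant for simplicity
total_days = 10

# Pad the pmfs so that index g gives P(gap = g days).
si = np.zeros(total_days + 1)
si[1:1 + len(serial_interval_pmf)] = serial_interval_pmf
fd = np.zeros(total_days + 1)
fd[1:1 + len(fatality_delay_pmf)] = fatality_delay_pmf

# Renewal equation: n_t = mu_t * sum_i n_i * P(gap = t - i), with mu_t = R_t.
infections = np.zeros(total_days)
infections[0] = 100.0  # seeded cases
for t in range(1, total_days):
    infections[t] = r_t * np.sum(infections[:t] * si[t - np.arange(t)])

# Expected deaths: convolve infections with the infection-to-death delay,
# scaled by the infection fatality rate.
expected_deaths = np.array([
    ifr * np.sum(infections[:t] * fd[t - np.arange(t)])
    for t in range(total_days)])
```

For example, infections on day 1 come out to 1.5 · 100 · 0.4 = 60, and later days accumulate contributions from all earlier days.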

 from tensorflow_probability.python.mcmc.internal import util as mcmc_util

def predict_infections(
    intervention_indicators, population, initial_cases, mu, alpha_hier,
    conv_serial_interval, initial_days, total_days):
  """Predict the number of infections by forward-simulation.

  Args:
    intervention_indicators: Binary array of shape
      `[num_countries, total_days, num_interventions]`, in which `1` indicates
      the intervention is active in that country at that time and `0` indicates
      otherwise.
    population: Vector of length `num_countries`. Population of each country.
    initial_cases: Array of shape `[batch_size, num_countries]`. Number of cases
      in each country at the start of the simulation.
    mu: Array of shape `[batch_size, num_countries]`. Initial reproduction rate
      (R_0) by country.
    alpha_hier: Array of shape `[batch_size, num_interventions]` representing
      the effectiveness of interventions.
    conv_serial_interval: Array of shape
      `[total_days - initial_days, total_days]` output from
      `make_conv_serial_interval`. Convolution kernel for serial interval
      distribution.
    initial_days: Integer, number of sequential days to seed infections after
      the 10th death in a country. (N0 in the authors' Stan code.)
    total_days: Integer, number of days of observed data plus days to forecast.
      (N2 in the authors' Stan code.)
  Returns:
    predicted_infections: Array of shape
      `[total_days, batch_size, num_countries]`. (Batched) predicted number of
      infections over time and by country.
  """
  alpha = alpha_hier - tf.cast(np.log(1.05) / 6.0, DTYPE)

  # Multiply the effectiveness of each intervention in each country (alpha)
  # by the indicator variable for whether the intervention was active and sum
  # over interventions, yielding an array of shape
  # [total_days, batch_size, num_countries] that represents the total effectiveness of
  # all interventions in each country on each day (for a batch of data).
  linear_prediction = tf.einsum(
      'ijk,...k->j...i', intervention_indicators, alpha)

  # Adjust the reproduction rate per country downward, according to the
  # effectiveness of the interventions.
  rt = mu * tf.exp(-linear_prediction, name='reproduction_rate')

  # Initialize storage array for daily infections and seed it with initial
  # cases.
  daily_infections = tf.TensorArray(
      dtype=DTYPE, size=total_days, element_shape=initial_cases.shape)
  for i in range(initial_days):
    daily_infections = daily_infections.write(i, initial_cases)

  # Initialize cumulative cases.
  init_cumulative_infections = initial_cases * initial_days

  # Simulate forward for total_days days.
  cond = lambda i, *_: i < total_days
  def body(i, prev_daily_infections, prev_cumulative_infections):
    # The probability distribution over days j that someone infected on day i
    # caught the virus from someone infected on day j.
    p_infected_on_day = tf.gather(
        conv_serial_interval, i - initial_days, axis=0)

    # Multiply p_infected_on_day by the number previous infections each day and
    # by mu, and sum to obtain new infections on day i. Mu is adjusted by
    # the fraction of the population already infected, so that the population
    # size is the upper limit on the number of infections.
    prev_daily_infections_array = prev_daily_infections.stack()
    to_sum = prev_daily_infections_array * mcmc_util.left_justified_expand_dims_like(
        p_infected_on_day, prev_daily_infections_array)
    convolution = tf.reduce_sum(to_sum, axis=0)
    rt_adj = (
        (population - prev_cumulative_infections) / population
        ) * tf.gather(rt, i)
    new_infections = rt_adj * convolution

    # Update the prediction array and the cumulative number of infections.
    daily_infections = prev_daily_infections.write(i, new_infections)
    cumulative_infections = prev_cumulative_infections + new_infections
    return i + 1, daily_infections, cumulative_infections

  _, daily_infections_final, last_cumm_sum = tf.while_loop(
      cond, body,
      (initial_days, daily_infections, init_cumulative_infections),
      maximum_iterations=(total_days - initial_days))
  return daily_infections_final.stack()

def predict_deaths(predicted_infections, ifr_noise, conv_fatality_rate):
  """Expected number of reported deaths by country, by day.

  Args:
    predicted_infections: Array of shape
      `[total_days, batch_size, num_countries]` output from
      `predict_infections`.
    ifr_noise: Array of shape `[batch_size, num_countries]`. Noise in Infection
      Fatality Rate (IFR).
    conv_fatality_rate: Array of shape
      `[total_days - 1, total_days, num_countries]`. Convolutional kernel for
      calculating fatalities, output from `make_conv_fatality_rate`.
  Returns:
    predicted_deaths: Array of shape `[total_days, batch_size, num_countries]`.
      (Batched) predicted number of deaths over time and by country.
  """
  # Multiply the number of infections on day j by the probability of death
  # on day i given infection on day j, and sum over j. This yields the expected
  # number of deaths on day i.
  result_remainder = tf.einsum(
      'i...j,kij->k...j', predicted_infections, conv_fatality_rate) * ifr_noise

  # Concatenate the result with a vector of zeros so that the first day is
  # included.
  result_temp = 1e-15 * predicted_infections[:1]
  return tf.concat([result_temp, result_remainder], axis=0)
 

1.2 Priors over parameter values

Here we define the joint prior distribution over the model parameters. Many of the parameter values are assumed to be independent, so that the prior can be expressed as:

$\text{p}(\tau, y, \psi, \kappa, \mu, \alpha, \epsilon) = \text{p}(\tau)\,\text{p}(y \mid \tau)\,\text{p}(\psi)\,\text{p}(\kappa)\,\text{p}(\mu \mid \kappa)\,\text{p}(\alpha)\,\text{p}(\epsilon)$

where:

  • $\tau$ is the shared rate parameter of the Exponential distribution over the number of initial cases in each country, $y = y_1, ..., y_{\text{num\_countries}}$.
  • $\psi$ is a parameter of the Negative Binomial distribution over the number of deaths.
  • $\kappa$ is the shared scale parameter of the HalfNormal distribution over the initial reproduction number in each country, $\mu = \mu_1, ..., \mu_{\text{num\_countries}}$ (indicating the number of additional cases transmitted by each infected person).
  • $\alpha = \alpha_1, ..., \alpha_6$ is the effectiveness of each of the six interventions.
  • $\epsilon$ (called ifr_noise in the code, after the authors' Stan code) is noise in the Infection Fatality Rate (IFR).

We express this model as a TFP JointDistribution, a type of TFP distribution that enables the expression of probabilistic graphical models.

 def make_jd_prior(num_countries, num_interventions):
  return tfd.JointDistributionSequentialAutoBatched([
      # Rate parameter for the distribution of initial cases (tau).
      tfd.Exponential(rate=tf.cast(0.03, DTYPE)),

      # Initial cases for each country.
      lambda tau: tfd.Sample(
          tfd.Exponential(rate=tf.cast(1, DTYPE) / tau),
          sample_shape=num_countries),

      # Parameter in Negative Binomial model for deaths (psi).
      tfd.HalfNormal(scale=tf.cast(5, DTYPE)),

      # Parameter in the distribution over the initial reproduction number, R_0
      # (kappa).
      tfd.HalfNormal(scale=tf.cast(0.5, DTYPE)),

      # Initial reproduction number, R_0, for each country (mu).
      lambda kappa: tfd.Sample(
          tfd.TruncatedNormal(loc=3.28, scale=kappa, low=1e-5, high=1e5),
          sample_shape=num_countries),

      # Impact of interventions (alpha; shared for all countries).
      tfd.Sample(
          tfd.Gamma(tf.cast(0.1667, DTYPE), 1), sample_shape=num_interventions),

      # Multiplicative noise in Infection Fatality Rate.
      tfd.Sample(
          tfd.TruncatedNormal(
              loc=tf.cast(1., DTYPE), scale=0.1, low=1e-5, high=1e5),
              sample_shape=num_countries)
  ])

 

1.3 Likelihood of observed deaths conditional on parameter values

The likelihood model expresses $p(\text{deaths} \mid \tau, y, \psi, \kappa, \mu, \alpha, \epsilon)$. It applies the models for the number of infections and expected deaths conditional on the parameter values, and assumes that the actual number of deaths follows a Negative Binomial distribution.
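To make the parametrization concrete: a Negative Binomial with total_count $\psi$ and probs $p = \mu / (\mu + \psi)$ has mean $\mu$, which is why the code below sets probs = mu_m / (mu_m + psi_m). A minimal pure-Python check of this identity, using arbitrary toy values for $\psi$ and $\mu$:

```python
import math

def nb_log_prob(k, psi, mu):
    """Log pmf of NegativeBinomial(total_count=psi, probs=mu / (mu + psi))."""
    p = mu / (mu + psi)
    # log C(k + psi - 1, k) + k log p + psi log(1 - p), via log-gamma.
    return (math.lgamma(k + psi) - math.lgamma(k + 1.) - math.lgamma(psi)
            + k * math.log(p) + psi * math.log(1. - p))

# Toy values (not fitted): the pmf should sum to ~1 and have mean ~mu.
psi, mu = 5.0, 20.0
pmf = [math.exp(nb_log_prob(k, psi, mu)) for k in range(500)]
total_mass = sum(pmf)
mean = sum(k * p for k, p in enumerate(pmf))
```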

 def make_likelihood_fn(
    intervention_indicators, population, deaths,
    infection_fatality_rate, initial_days, total_days):

  # Create a mask for the initial days of simulated data, as they are not
  # counted in the likelihood.
  observed_deaths = tf.constant(deaths.T[np.newaxis, ...], dtype=DTYPE)
  mask_temp = deaths != -1
  mask_temp[:, :START_DAYS] = False
  observed_deaths_mask = tf.constant(mask_temp.T[np.newaxis, ...])

  conv_serial_interval = make_conv_serial_interval(initial_days, total_days)
  conv_fatality_rate = make_conv_fatality_rate(
      infection_fatality_rate, total_days)

  def likelihood_fn(tau, initial_cases, psi, kappa, mu, alpha_hier, ifr_noise):
    # Run models for infections and expected deaths
    predicted_infections = predict_infections(
        intervention_indicators, population, initial_cases, mu, alpha_hier,
        conv_serial_interval, initial_days, total_days)
    e_deaths_all_countries = predict_deaths(
        predicted_infections, ifr_noise, conv_fatality_rate)

    # Construct the Negative Binomial distribution for deaths by country.
    mu_m = tf.transpose(e_deaths_all_countries, [1, 0, 2])
    psi_m = psi[..., tf.newaxis, tf.newaxis]
    probs = tf.clip_by_value(mu_m / (mu_m + psi_m), 1e-9, 1.)
    likelihood_elementwise = tfd.NegativeBinomial(
        total_count=psi_m, probs=probs).log_prob(observed_deaths)
    return tf.reduce_sum(
        tf.where(observed_deaths_mask,
                likelihood_elementwise,
                tf.zeros_like(likelihood_elementwise)),
        axis=[-2, -1])

  return likelihood_fn

 

1.4 Probability of death given infection

This section computes the distribution over the number of days between infection and death. It assumes the time from infection to death is the sum of two Gamma-distributed quantities, representing the time from infection to symptom onset and the time from onset to death. The time-to-death distribution is combined with Infection Fatality Rate data from Verity et al. (2020) to compute the probability of death on each day following infection.
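A quick Monte Carlo sanity check of this construction (a sketch using numpy's shape/scale Gamma parametrization rather than TFP's concentration/rate): the two components have means 5.1 and 18.8 days, so the summed infection-to-death delay should average about 23.9 days.

```python
import numpy as np

rng = np.random.default_rng(0)

# Means 5.1 and 18.8 days with coefficients of variation 0.86 and 0.45, as in
# the notebook; in numpy's parametrization, shape = 1/cv**2, scale = mean/shape.
shape1 = (1. / 0.86) ** 2
shape2 = (1. / 0.45) ** 2
n = 1_000_000
infection_to_onset = rng.gamma(shape1, 5.1 / shape1, n)
onset_to_death = rng.gamma(shape2, 18.8 / shape2, n)

# Empirical distribution over days from infection to death.
delay = infection_to_onset + onset_to_death
mean_delay = delay.mean()
```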

 def daily_fatality_probability(infection_fatality_rate, total_days):
  """Computes the probability of death `d` days after infection."""

  # Convert from alternative Gamma parametrization and construct distributions
  # for number of days from infection to onset and onset to death.
  concentration1 = tf.cast((1. / 0.86)**2, DTYPE)
  rate1 = concentration1 / 5.1
  concentration2 = tf.cast((1. / 0.45)**2, DTYPE)
  rate2 = concentration2 / 18.8
  infection_to_onset = tfd.Gamma(concentration=concentration1, rate=rate1)
  onset_to_death = tfd.Gamma(concentration=concentration2, rate=rate2)

  # Create empirical distribution for number of days from infection to death.
  inf_to_death_dist = tfd.Empirical(
      infection_to_onset.sample([5e6]) + onset_to_death.sample([5e6]))

  # Subtract the CDF value at day i from the value at day i + 1 to compute the
  # probability of death on day i given infection on day 0, and given that
  # death (not recovery) is the outcome.
  times = np.arange(total_days + 1., dtype=DTYPE) + 0.5
  cdf = inf_to_death_dist.cdf(times).numpy()
  f_before_ifr = cdf[1:] - cdf[:-1]
  # Explicitly set the zeroth value to the empirical cdf at time 1.5, to include
  # the mass between time 0 and time .5.
  f_before_ifr[0] = cdf[1]

  # Multiply the daily fatality rates conditional on infection and eventual
  # death (f_before_ifr) by the infection fatality rates (probability of death
  # given intection) to obtain the probability of death on day i conditional
  # on infection on day 0.
  return infection_fatality_rate[..., np.newaxis] * f_before_ifr

def make_conv_fatality_rate(infection_fatality_rate, total_days):
  """Computes the probability of death on day `i` given infection on day `j`."""
  p_fatal_all_countries = daily_fatality_probability(
      infection_fatality_rate, total_days)

  # Use the probability of death d days after infection in each country
  # to build an array of shape [total_days - 1, total_days, num_countries],
  # where the element [i, j, c] is the probability of death on day i+1 given
  # infection on day j in country c.
  conv_fatality_rate = np.zeros(
      [total_days - 1, total_days, p_fatal_all_countries.shape[0]])
  for n in range(1, total_days):
    conv_fatality_rate[n - 1, 0:n, :] = (
        p_fatal_all_countries[:, n - 1::-1]).T
  return tf.constant(conv_fatality_rate, dtype=DTYPE)

 

1.5 Serial interval

The serial interval is the time between successive cases in a chain of disease transmission, and is assumed to be Gamma distributed. We use the serial interval distribution to compute the probability that a person infected on day $i$ caught the virus from a person previously infected on day $j$ (the conv_serial_interval argument to predict_infections).
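The kernel construction in make_conv_serial_interval below works by reversing the interval pmf into each row of a matrix, so that a row-times-history product reproduces the convolution sum. A small numpy check of that indexing, using a hypothetical 4-day pmf and made-up infection counts:

```python
import numpy as np

# Hypothetical pmf over gaps of 1..4 days (not the fitted Gamma values).
serial_interval = np.array([0.4, 0.3, 0.2, 0.1])
total_days, initial_days = 8, 2

# si_padded[g] = P(gap = g + 1 days), zero beyond the pmf's support.
si_padded = np.pad(serial_interval, (0, total_days))

conv = np.zeros([total_days - initial_days, total_days])
for n in range(initial_days, total_days):
    # Entry [n - initial_days, j] holds P(gap = n - j days): the pmf reversed.
    conv[n - initial_days, 0:n] = si_padded[:n][::-1]

# Check: a kernel row applied to an infection history matches the direct sum.
infections = np.arange(1., total_days + 1.)  # made-up daily infection counts
n = 5
via_kernel = conv[n - initial_days] @ infections
direct = sum(
    infections[j] * (serial_interval[n - j - 1] if n - j <= len(serial_interval) else 0.)
    for j in range(n))
```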

 def make_conv_serial_interval(initial_days, total_days):
  """Construct the convolutional kernel for infection timing."""

  g = tfd.Gamma(tf.cast(1. / (0.62**2), DTYPE), 1./(6.5*0.62**2))
  g_cdf = g.cdf(np.arange(total_days, dtype=DTYPE))

  # Approximate the probability mass function for the number of days between
  # successive infections.
  serial_interval = g_cdf[1:] - g_cdf[:-1]

  # `conv_serial_interval` is an array of shape
  # [total_days - initial_days, total_days] in which entry [i, j] contains the
  # probability that an individual infected on day i + initial_days caught the
  # virus from someone infected on day j.
  conv_serial_interval = np.zeros([total_days - initial_days, total_days])
  for n in range(initial_days, total_days):
    conv_serial_interval[n - initial_days, 0:n] = serial_interval[n - 1::-1]
  return tf.constant(conv_serial_interval, dtype=DTYPE) 
 

2 Data preprocessing

 COUNTRIES = [
    'Austria',
    'Belgium',
    'Denmark',
    'France',
    'Germany',
    'Italy',
    'Norway',
    'Spain',
    'Sweden',
    'Switzerland',
    'United_Kingdom'
]
 
 
raw_interventions = pd.read_csv(
    'https://raw.githubusercontent.com/ImperialCollegeLondon/covid19model/master/data/interventions.csv')

raw_interventions['Date effective'] = pd.to_datetime(
    raw_interventions['Date effective'], dayfirst=True)
interventions = raw_interventions.pivot(index='Country', columns='Type', values='Date effective')

# If any interventions happened after the lockdown, use the date of the lockdown.
for col in interventions.columns:
  idx = interventions[col] > interventions['Lockdown']
  interventions.loc[idx, col] = interventions[idx]['Lockdown']

num_countries = len(COUNTRIES)
 
 
# Load the case data
data = pd.read_csv('https://raw.githubusercontent.com/ImperialCollegeLondon/covid19model/master/data/COVID-19-up-to-date.csv')
# You can also use the dataset directly from the European CDC (from which the ICL model fetches its data):
# data = pd.read_csv('https://opendata.ecdc.europa.eu/covid19/casedistribution/csv')

data['country'] = data['countriesAndTerritories']
data = data[['dateRep', 'cases', 'deaths', 'country']]
data = data[data['country'].isin(COUNTRIES)]
data['dateRep'] = pd.to_datetime(data['dateRep'], format='%d/%m/%Y')

# Add 0/1 features for whether or not each intervention was in place.
data = data.join(interventions, on='country', how='outer')
for col in interventions.columns:
  data[col] = (data['dateRep'] >= data[col]).astype(int)

# Add "any_intervention" 0/1 feature.
any_intervention_list = ['Schools + Universities',
                         'Self-isolating if ill',
                         'Public events',
                         'Lockdown',
                         'Social distancing encouraged']
data['any_intervention'] = (
    data[any_intervention_list].apply(np.sum, 'columns') > 0).astype(int)

# Index by country and date.
data = data.sort_values(by=['country', 'dateRep'])
data = data.set_index(['country', 'dateRep'])
 
 
infected_fatality_ratio = pd.read_csv(
    'https://raw.githubusercontent.com/ImperialCollegeLondon/covid19model/master/data/popt_ifr.csv')

infected_fatality_ratio = infected_fatality_ratio.replace(to_replace='United Kingdom', value='United_Kingdom')
infected_fatality_ratio['Country'] = infected_fatality_ratio.iloc[:, 1]
infected_fatality_ratio = infected_fatality_ratio[infected_fatality_ratio['Country'].isin(COUNTRIES)]
infected_fatality_ratio = infected_fatality_ratio[
  ['Country', 'popt', 'ifr']].set_index('Country')
infected_fatality_ratio = infected_fatality_ratio.sort_index()
infection_fatality_rate = infected_fatality_ratio['ifr'].to_numpy()
population_value = infected_fatality_ratio['popt'].to_numpy()
 

2.4 Preprocess country-specific data

 # Model up to 75 days of data for each country, starting 30 days before the
# tenth cumulative death.
START_DAYS = 30
MAX_DAYS = 102
COVARIATE_COLUMNS = any_intervention_list + ['any_intervention']

# Initialize an array for number of deaths.
deaths = -np.ones((num_countries, MAX_DAYS), dtype=DTYPE)

# Assume every intervention is still in place in the unobserved future.
num_interventions = len(COVARIATE_COLUMNS)
intervention_indicators = np.ones((num_countries, MAX_DAYS, num_interventions))

first_days = {}
for i, c in enumerate(COUNTRIES):
  c_data = data.loc[c]

  # Include data only after 10th death in a country.
  mask = c_data['deaths'].cumsum() >= 10

  # Get the date that the epidemic starts in a country.
  first_day = c_data.index[mask][0] - pd.to_timedelta(START_DAYS, 'days')
  c_data = c_data.truncate(before=first_day)

  # Truncate the data after 28 March 2020 for comparison with Flaxman et al.
  c_data = c_data.truncate(after='2020-03-28')

  c_data = c_data.iloc[:MAX_DAYS]
  days_of_data = c_data.shape[0]
  deaths[i, :days_of_data] = c_data['deaths']
  intervention_indicators[i, :days_of_data] = c_data[
    COVARIATE_COLUMNS].to_numpy()
  first_days[c] = first_day

# Number of sequential days to seed infections after the 10th death in a
# country. (N0 in authors' Stan code.)
INITIAL_DAYS = 6

# Number of days of observed data plus days to forecast. (N2 in authors' Stan
# code.)
TOTAL_DAYS = deaths.shape[1]
 

3 Model inference

Flaxman et al. (2020) used Stan to sample from the parameter posterior with Hamiltonian Monte Carlo (HMC) and the No-U-Turn Sampler (NUTS).

Here, we apply HMC with dual averaging step size adaptation. We use a pilot run of HMC for preconditioning and initialization.

Inference runs in a few minutes on a GPU.

3.1 Build the prior and likelihood for the model

 jd_prior = make_jd_prior(num_countries, num_interventions)
likelihood_fn = make_likelihood_fn(
    intervention_indicators, population_value, deaths,
    infection_fatality_rate, INITIAL_DAYS, TOTAL_DAYS)
 

3.2 Utilities

 def get_bijectors_from_samples(samples, unconstraining_bijectors, batch_axes):
  """Fit bijectors to the samples of a distribution.

  This fits a diagonal covariance multivariate Gaussian transformed by the
  `unconstraining_bijectors` to the provided samples. The resultant
  transformation can be used to precondition MCMC and other inference methods.
  """
  state_std = [    
      tf.math.reduce_std(bij.inverse(x), axis=batch_axes)
      for x, bij in zip(samples, unconstraining_bijectors)
  ]
  state_mu = [
      tf.math.reduce_mean(bij.inverse(x), axis=batch_axes)
      for x, bij in zip(samples, unconstraining_bijectors)
  ]
  return [tfb.Chain([cb, tfb.Shift(sh), tfb.Scale(sc)])
          for cb, sh, sc in zip(unconstraining_bijectors, state_mu, state_std)]

def generate_init_state_and_bijectors_from_prior(nchain, unconstraining_bijectors):
  """Creates an initial MCMC state, and bijectors from the prior."""
  prior_samples = jd_prior.sample(4096)

  bijectors = get_bijectors_from_samples(
      prior_samples, unconstraining_bijectors, batch_axes=0)
  
  init_state = [
    bij(tf.zeros([nchain] + list(s), DTYPE))
    for s, bij in zip(jd_prior.event_shape, bijectors)
  ]
  
  return init_state, bijectors
 
 @tf.function(autograph=False, experimental_compile=True)
def sample_hmc(
    init_state,
    step_size,
    target_log_prob_fn,
    unconstraining_bijectors,
    num_steps=500,
    burnin=50,
    num_leapfrog_steps=10):

    def trace_fn(_, pkr):
        return {
            'target_log_prob': pkr.inner_results.inner_results.accepted_results.target_log_prob,
            'diverging': ~(pkr.inner_results.inner_results.log_accept_ratio > -1000.),
            'is_accepted': pkr.inner_results.inner_results.is_accepted,
            'step_size': [tf.exp(s) for s in pkr.log_averaging_step],
        }
    
    hmc = tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn,
        step_size=step_size,
        num_leapfrog_steps=num_leapfrog_steps)

    hmc = tfp.mcmc.TransformedTransitionKernel(
        inner_kernel=hmc,
        bijector=unconstraining_bijectors)
  
    hmc = tfp.mcmc.DualAveragingStepSizeAdaptation(
        hmc,
        num_adaptation_steps=int(burnin * 0.8),
        target_accept_prob=0.8,
        decay_rate=0.5)

    # Sampling from the chain.
    return tfp.mcmc.sample_chain(
        num_results=burnin + num_steps,
        current_state=init_state,
        kernel=hmc,
        trace_fn=trace_fn)
 

3.3 Define event space bijectors

HMC is most efficient when sampling from an isotropic multivariate Gaussian (Mangoubi & Smith (2017)), so the first step is to precondition the target density to look as much like that as possible.

First and foremost, we transform constrained (e.g., non-negative) variables to an unconstrained space, which HMC requires. Additionally, we employ the SinhArcsinh bijector to manipulate the heaviness of the tails of the transformed target density; we want these to fall off roughly as $e^{-x^2}$.
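A numpy-only sketch of what these transformations do (simplified stand-ins for tfb.Softplus and tfb.SinhArcsinh, not the bijectors themselves): softplus maps all of the reals onto the positives, and the sinh-arcsinh map with tailweight > 1 stretches the tails outward.

```python
import numpy as np

def softplus(x):
    # Maps R onto (0, inf); roughly the identity for large positive x.
    return np.log1p(np.exp(x))

def softplus_inverse(y):
    return np.log(np.expm1(y))

def sinh_arcsinh(x, skewness=0., tailweight=1.):
    # tailweight > 1 fattens the tails; nonzero skewness makes them asymmetric.
    return np.sinh((np.arcsinh(x) + skewness) * tailweight)

x = np.linspace(-3., 3., 7)
positives = softplus(x)
roundtrip = softplus_inverse(positives)
heavy_tails = sinh_arcsinh(x, tailweight=2.)
```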

 unconstraining_bijectors = [
    tfb.Chain([tfb.Scale(tf.constant(1 / 0.03, DTYPE)), tfb.Softplus(),
                tfb.SinhArcsinh(tailweight=tf.constant(1.85, DTYPE))]), # tau
    tfb.Chain([tfb.Scale(tf.constant(1 / 0.03, DTYPE)), tfb.Softplus(),
                tfb.SinhArcsinh(tailweight=tf.constant(1.85, DTYPE))]), # initial_cases
    tfb.Softplus(), # psi
    tfb.Softplus(), # kappa
    tfb.Softplus(), # mu
    tfb.Chain([tfb.Scale(tf.constant(0.4, DTYPE)), tfb.Softplus(),
                tfb.SinhArcsinh(skewness=tf.constant(-0.2, DTYPE), tailweight=tf.constant(2., DTYPE))]), # alpha
    tfb.Softplus(), # ifr_noise
]
 

3.4 Pilot run of HMC

We first run HMC preconditioned by the prior, initialized from 0 in the transformed space. We do not use the prior samples themselves to initialize the chains, as in practice those often result in stuck chains due to poor numerics.

 %%time

nchain = 32

target_log_prob_fn = lambda *x: jd_prior.log_prob(*x) + likelihood_fn(*x)
init_state, bijectors = generate_init_state_and_bijectors_from_prior(nchain, unconstraining_bijectors)

# Each chain gets its own step size.
step_size = [tf.fill([nchain] + [1] * (len(s.shape) - 1), tf.constant(0.01, DTYPE)) for s in init_state]

burnin = 200
num_steps = 100

pilot_samples, pilot_sampler_stat = sample_hmc(
    init_state,
    step_size,
    target_log_prob_fn,
    bijectors,
    num_steps=num_steps,
    burnin=burnin,
    num_leapfrog_steps=10)
 
CPU times: user 56.8 s, sys: 2.34 s, total: 59.1 s
Wall time: 1min 1s

3.5 Visualize the pilot samples

We are looking for stuck chains and eyeballing convergence. We could run formal diagnostics here, but that isn't strictly necessary given that this is just a pilot run.

 import arviz as az
az.style.use('arviz-darkgrid')
 
 var_name = ['tau', 'initial_cases', 'psi', 'kappa', 'mu', 'alpha', 'ifr_noise']

pilot_with_warmup = {k: np.swapaxes(v.numpy(), 1, 0)
                     for k, v in zip(var_name, pilot_samples)}
 

We observe divergences during warmup, mostly because dual averaging step size adaptation performs a very aggressive search for the optimal step size. Once adaptation turns off, the divergences disappear as well.

 az_trace = az.from_dict(posterior=pilot_with_warmup,
                        sample_stats={'diverging': np.swapaxes(pilot_sampler_stat['diverging'].numpy(), 0, 1)})
az.plot_trace(az_trace, combined=True, compact=True, figsize=(12, 8));
 


 plt.plot(pilot_sampler_stat['step_size'][0]);
 


3.6 Run HMC

In principle, we could use the pilot samples for the final analysis (had we run the pilot longer so that it converged), but it is a little more efficient to start another HMC run, this time preconditioned and initialized from the pilot samples.
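The preconditioning fitted by get_bijectors_from_samples amounts to a per-parameter shift and scale estimated from the (unconstrained) pilot samples. A numpy sketch with synthetic stand-in draws (the numbers below are made up, not actual pilot output):

```python
import numpy as np

rng = np.random.default_rng(1)

# Synthetic stand-in for one unconstrained parameter's pilot samples:
# offset and badly scaled relative to a standard normal.
pilot = 4.0 + 2.5 * rng.standard_normal((300, 32))  # [num_steps, nchain]

# Fit shift and scale over both the step and chain axes, mirroring
# get_bijectors_from_samples with batch_axes=(0, 1).
shift = pilot.mean(axis=(0, 1))
scale = pilot.std(axis=(0, 1))

# The preconditioner maps z -> shift + scale * z; applying its inverse
# whitens the pilot samples, which is what makes HMC's job easier.
whitened = (pilot - shift) / scale
```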

 %%time

burnin = 50
num_steps = 200

bijectors = get_bijectors_from_samples([s[burnin:] for s in pilot_samples],
                                       unconstraining_bijectors=unconstraining_bijectors,
                                       batch_axes=(0, 1))

samples, sampler_stat = sample_hmc(
    [s[-1] for s in pilot_samples],
    [s[-1] for s in pilot_sampler_stat['step_size']],
    target_log_prob_fn,
    bijectors,
    num_steps=num_steps,
    burnin=burnin,
    num_leapfrog_steps=20)
 
CPU times: user 1min 26s, sys: 3.88 s, total: 1min 30s
Wall time: 1min 32s

 plt.plot(sampler_stat['step_size'][0]);
 


3.7 Visualize the samples

 import arviz as az
az.style.use('arviz-darkgrid')
 
 var_name = ['tau', 'initial_cases', 'psi', 'kappa', 'mu', 'alpha', 'ifr_noise']

posterior = {k: np.swapaxes(v.numpy()[burnin:], 1, 0)
             for k, v in zip(var_name, samples)}
posterior_with_warmup = {k: np.swapaxes(v.numpy(), 1, 0)
             for k, v in zip(var_name, samples)}
 

Compute a summary of the chains. We are looking for high ESS and r_hat close to 1.
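As a rough illustration of what r_hat measures, here is a simplified potential scale reduction factor (not ArviZ's exact rank-normalized version) applied to synthetic chains:

```python
import numpy as np

def r_hat(chains):
    """Basic potential scale reduction factor; chains has shape
    [num_chains, num_samples]."""
    m, n = chains.shape
    w = chains.var(axis=1, ddof=1).mean()       # mean within-chain variance
    b = n * chains.mean(axis=1).var(ddof=1)     # between-chain variance
    var_estimate = (n - 1) / n * w + b / n
    return np.sqrt(var_estimate / w)

rng = np.random.default_rng(2)
mixed = rng.standard_normal((4, 1000))              # well-mixed chains
stuck = mixed + np.array([[0.], [0.], [0.], [5.]])  # one chain stuck far away

good = r_hat(mixed)   # close to 1
bad = r_hat(stuck)    # well above 1
```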

 az.summary(posterior)
 
 az_trace = az.from_dict(posterior=posterior_with_warmup,
                        sample_stats={'diverging': np.swapaxes(sampler_stat['diverging'].numpy(), 0, 1)})
az.plot_trace(az_trace, combined=True, compact=True, figsize=(12, 8));
 


It is also helpful to look at the autocorrelation functions across all dimensions. We are looking for functions that drop off quickly, but not so quickly that they swing negative (which would indicate HMC hitting a resonance, which is bad for ergodicity and can introduce bias).
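A numpy sketch of the autocorrelation function itself, comparing white noise (drops off immediately) with a sticky AR(1) chain that decays slowly, the way a poorly tuned sampler's trace would (the AR(1) stand-in is hypothetical, not the notebook's output):

```python
import numpy as np

def autocorr(x, max_lag):
    """Normalized autocorrelation of a 1-D chain at lags 0..max_lag."""
    x = x - x.mean()
    acf = np.array([x[:len(x) - k] @ x[k:] for k in range(max_lag + 1)])
    return acf / acf[0]

rng = np.random.default_rng(3)
white = rng.standard_normal(20_000)

# AR(1) chain with strong positive correlation between successive draws.
rho = 0.9
ar1 = np.empty_like(white)
ar1[0] = white[0]
for t in range(1, len(white)):
    ar1[t] = rho * ar1[t - 1] + white[t]

acf_white = autocorr(white, 5)
acf_ar1 = autocorr(ar1, 5)
```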

 with az.rc_context(rc={'plot.max_subplots': None}):
  az.plot_autocorr(posterior, combined=True, figsize=(12, 16), textsize=12);
 


4 Results

The plots below analyze the posterior predictive distributions over $R_t$, the number of deaths, and the number of infections, similar to the analysis in Flaxman et al. (2020).

 total_num_samples = np.prod(posterior['mu'].shape[:2])

# Calculate R_t given parameter estimates.
def rt_samples_batched(mu, intervention_indicators, alpha):
  linear_prediction = tf.reduce_sum(
      intervention_indicators * alpha[..., np.newaxis, np.newaxis, :], axis=-1)
  rt_hat = mu[..., tf.newaxis] * tf.exp(-linear_prediction, name='rt')
  return rt_hat

alpha_hat = tf.convert_to_tensor(
    posterior['alpha'].reshape(total_num_samples, posterior['alpha'].shape[-1]))
mu_hat = tf.convert_to_tensor(
    posterior['mu'].reshape(total_num_samples, num_countries))
rt_hat = rt_samples_batched(mu_hat, intervention_indicators, alpha_hat)
sampled_initial_cases = posterior['initial_cases'].reshape(
    total_num_samples, num_countries)
sampled_ifr_noise = posterior['ifr_noise'].reshape(
    total_num_samples, num_countries)
psi_hat = posterior['psi'].reshape([total_num_samples])

conv_serial_interval = make_conv_serial_interval(INITIAL_DAYS, TOTAL_DAYS)
conv_fatality_rate = make_conv_fatality_rate(infection_fatality_rate, TOTAL_DAYS)
pred_hat = predict_infections(
    intervention_indicators, population_value, sampled_initial_cases, mu_hat,
    alpha_hat, conv_serial_interval, INITIAL_DAYS, TOTAL_DAYS)
expected_deaths = predict_deaths(pred_hat, sampled_ifr_noise, conv_fatality_rate)

psi_m = psi_hat[np.newaxis, ..., np.newaxis]
probs = tf.clip_by_value(expected_deaths / (expected_deaths + psi_m), 1e-9, 1.)
predicted_deaths = tfd.NegativeBinomial(
    total_count=psi_m, probs=probs).sample()

 
 # Predict counterfactual infections/deaths in the absence of interventions
no_intervention_infections = predict_infections(
    intervention_indicators,
    population_value,
    sampled_initial_cases,
    mu_hat,
    tf.zeros_like(alpha_hat),
    conv_serial_interval,
    INITIAL_DAYS, TOTAL_DAYS)

no_intervention_expected_deaths = predict_deaths(
    no_intervention_infections, sampled_ifr_noise, conv_fatality_rate)
probs = tf.clip_by_value(
    no_intervention_expected_deaths / (no_intervention_expected_deaths + psi_m),
    1e-9, 1.)
no_intervention_predicted_deaths = tfd.NegativeBinomial(
    total_count=psi_m, probs=probs).sample()
 

4.1 Effectiveness of interventions

Analogous to Figure 4 of Flaxman et al. (2020).

 def intervention_effectiveness(alpha):
  alpha_adj = 1. - np.exp(-alpha + np.log(1.05) / 6.)
  alpha_adj_first = (
      1. - np.exp(-alpha - alpha[..., -1:] + np.log(1.05) / 6.))

  fig, ax = plt.subplots(1, 1, figsize=[12, 6])
  intervention_perm = [2, 1, 3, 4, 0]
  percentile_vals = [2.5, 97.5]
  jitter = .2

  for ind in range(5):
    first_low, first_high = tfp.stats.percentile(
        alpha_adj_first[..., ind], percentile_vals)
    low, high = tfp.stats.percentile(
        alpha_adj[..., ind], percentile_vals)

    p_ind = intervention_perm[ind]
    ax.hlines(p_ind, low, high, label='Later Intervention', colors='g')
    ax.scatter(alpha_adj[..., ind].mean(), p_ind, color='g')
    ax.hlines(p_ind + jitter, first_low, first_high,
              label='First Intervention', colors='r')
    ax.scatter(alpha_adj_first[..., ind].mean(), p_ind + jitter, color='r')

    if ind == 0:
      plt.legend(loc='lower right')
  ax.set_yticks(range(5))
  ax.set_yticklabels(
      [any_intervention_list[intervention_perm.index(p)] for p in range(5)])
  ax.set_xlim([-0.01, 1.])
  r = fig.patch
  r.set_facecolor('white') 

intervention_effectiveness(alpha_hat)
 


4.2 Infections, deaths, and R_t by country

Analogous to Figure 2 of Flaxman et al. (2020).

 import matplotlib.dates as mdates

plot_quantile = True 
forecast_days = 0 

fig, ax = plt.subplots(11, 3, figsize=(15, 40))

for ind, country in enumerate(COUNTRIES):
  num_days = (pd.to_datetime('2020-03-28') - first_days[country]).days + forecast_days
  dates = [(first_days[country] + i*pd.to_timedelta(1, 'days')).strftime('%m-%d') for i in range(num_days)]
  plot_dates = [dates[i] for i in range(0, num_days, 7)]

  # Plot daily number of infections
  infections = pred_hat[:, :, ind]
  posterior_quantile = np.percentile(infections, [2.5, 25, 50, 75, 97.5], axis=-1)
  ax[ind, 0].plot(
      dates, posterior_quantile[2, :num_days],
      color='b', label='posterior median', lw=2)
  if plot_quantile:
    ax[ind, 0].fill_between(
        dates, posterior_quantile[1, :num_days], posterior_quantile[3, :num_days],
        color='b', label='50% quantile', alpha=.4)
    ax[ind, 0].fill_between(
        dates, posterior_quantile[0, :num_days], posterior_quantile[4, :num_days],
        color='b', label='95% quantile', alpha=.2)

  ax[ind, 0].set_xticks(plot_dates)
  ax[ind, 0].xaxis.set_tick_params(rotation=45)
  ax[ind, 0].set_ylabel('Daily number of infections', fontsize='large')
  ax[ind, 0].set_xlabel('Day', fontsize='large')

  # Plot deaths
  ax[ind, 1].set_title(country)

  samples = predicted_deaths[:, :, ind]
  posterior_quantile = np.percentile(samples, [2.5, 25, 50, 75, 97.5], axis=-1)
  ax[ind, 1].plot(
      range(num_days), posterior_quantile[2, :num_days],
      color='b', label='Posterior median', lw=2)
  if plot_quantile:
    ax[ind, 1].fill_between(
        range(num_days), posterior_quantile[1, :num_days], posterior_quantile[3, :num_days],
        color='b', label='50% quantile', alpha=.4)
    ax[ind, 1].fill_between(
        range(num_days), posterior_quantile[0, :num_days], posterior_quantile[4, :num_days],
        color='b', label='95% quantile', alpha=.2)

  observed = deaths[ind, :]
  observed[observed == -1] = np.nan
  ax[ind, 1].plot(
      dates, observed[:num_days],
      '--o', color='k', markersize=3,
      label='Observed deaths', alpha=.8)
  ax[ind, 1].set_xticks(plot_dates)
  ax[ind, 1].xaxis.set_tick_params(rotation=45)
  ax[ind, 1].set_title(country)
  ax[ind, 1].set_xlabel('Day', fontsize='large')
  ax[ind, 1].set_ylabel('Deaths', fontsize='large')

  # Plot R_t
  samples = np.transpose(rt_hat[:, ind, :])
  posterior_quantile = np.percentile(samples, [2.5, 25, 50, 75, 97.5], axis=-1)
  l1 = ax[ind, 2].plot(
      dates, posterior_quantile[2, :num_days],
      color='g', label='Posterior median', lw=2)
  l2 = ax[ind, 2].fill_between(
      dates, posterior_quantile[1, :num_days], posterior_quantile[3, :num_days],
      color='g', label='50% quantile', alpha=.4)
  if plot_quantile:
    l3 = ax[ind, 2].fill_between(
        dates, posterior_quantile[0, :num_days], posterior_quantile[4, :num_days],
        color='g', label='95% quantile', alpha=.2)

  l4 = ax[ind, 2].hlines(1., dates[0], dates[-1], linestyle='--', label='R == 1')
  ax[ind, 2].set_xlabel('Day', fontsize='large')
  ax[ind, 2].set_ylabel('R_t', fontsize='large')
  ax[ind, 2].set_xticks(plot_dates)
  ax[ind, 2].xaxis.set_tick_params(rotation=45)

fontsize = 'medium'
ax[0, 0].legend(loc='upper left', fontsize=fontsize)
ax[0, 1].legend(loc='upper left', fontsize=fontsize)
ax[0, 2].legend(
  bbox_to_anchor=(1., 1.),
  loc='upper right',
  borderaxespad=0.,
  fontsize=fontsize)

plt.tight_layout();
 


4.3 Daily predicted/forecast deaths, with and without interventions

 plot_quantile = True 
forecast_days = 0 

fig, ax = plt.subplots(4, 3, figsize=(15, 16))
ax = ax.flatten()
fig.delaxes(ax[-1])
for country_index, country in enumerate(COUNTRIES):
  num_days = (pd.to_datetime('2020-03-28') - first_days[country]).days + forecast_days
  dates = [(first_days[country] + i*pd.to_timedelta(1, 'days')).strftime('%m-%d') for i in range(num_days)]
  plot_dates = [dates[i] for i in range(0, num_days, 7)]

  ax[country_index].set_title(country)

  quantile_vals = [.025, .25, .5, .75, .975]
  samples = predicted_deaths[:, :, country_index].numpy()
  quantiles = []

  psi_m = psi_hat[np.newaxis, ..., np.newaxis]
  probs = tf.clip_by_value(expected_deaths / (expected_deaths + psi_m), 1e-9, 1.)
  predicted_deaths_dist = tfd.NegativeBinomial(
    total_count=psi_m, probs=probs)

  posterior_quantile = np.percentile(samples, [2.5, 25, 50, 75, 97.5], axis=-1)
  ax[country_index].plot(
      dates, posterior_quantile[2, :num_days],
      color='b', label='Posterior median', lw=2)
  if plot_quantile:
    ax[country_index].fill_between(
        dates, posterior_quantile[1, :num_days], posterior_quantile[3, :num_days],
        color='b', label='50% quantile', alpha=.4)

  samples_counterfact = no_intervention_predicted_deaths[:, :, country_index]
  posterior_quantile = np.percentile(samples_counterfact, [2.5, 25, 50, 75, 97.5], axis=-1)
  ax[country_index].plot(
      dates, posterior_quantile[2, :num_days],
      color='r', label='Posterior median', lw=2)
  if plot_quantile:
    ax[country_index].fill_between(
        dates, posterior_quantile[1, :num_days], posterior_quantile[3, :num_days],
        color='r', label='50% quantile, no intervention', alpha=.4)

  observed = deaths[country_index, :]
  observed[observed == -1] = np.nan
  ax[country_index].plot(
      dates, observed[:num_days],
      '--o', color='k', markersize=3,
      label='Observed deaths', alpha=.8)
  ax[country_index].set_xticks(plot_dates)
  ax[country_index].xaxis.set_tick_params(rotation=45)
  ax[country_index].set_title(country)
  ax[country_index].set_xlabel('Day', fontsize='large')
  ax[country_index].set_ylabel('Deaths', fontsize='large')
  ax[0].legend(loc='upper left')
plt.tight_layout(pad=1.0);
 


MIT License.

 

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.