Modellare la diffusione del COVID-19 in Europa e l'effetto degli interventi

Concesso in licenza con la licenza MIT

Visualizza su TensorFlow.org Esegui in Google Colab Visualizza la fonte su GitHub Scarica taccuino

Per rallentare la diffusione di COVID-19 all'inizio del 2020, i paesi europei hanno adottato interventi non farmaceutici come la chiusura di attività non essenziali, l'isolamento di singoli casi, divieti di viaggio e altre misure per incoraggiare il distanziamento sociale. L'Imperial College COVID-19 Response Team ha analizzato l'efficacia di queste misure nel loro articolo "La stima del numero di infezioni e l'impatto degli interventi non farmaceutici on COVID-19 in 11 paesi europei" , usando un modello gerarchico bayesiano combinato con un meccanicistica modello epidemiologico.

Questo Colab contiene un'implementazione TensorFlow Probability (TFP) di tale analisi, organizzata come segue:

  • La "configurazione del modello" definisce il modello epidemiologico per la trasmissione della malattia e i decessi risultanti, la distribuzione bayesiana precedente sui parametri del modello e la distribuzione del numero di decessi in base ai valori dei parametri.
  • La "preelaborazione dei dati" carica i dati sui tempi e il tipo di interventi in ciascun paese, i conteggi dei decessi nel tempo e i tassi di mortalità stimati per le persone infette.
  • "Model inference" costruisce un modello gerarchico bayesiano ed esegue Hamiltonian Monte Carlo (HMC) per campionare dalla distribuzione a posteriori sui parametri.
  • "Risultati" mostra distribuzioni predittive a posteriori per quantità di interesse come morti previste e morti controfattuali in assenza di interventi.

Il documento ha trovato prove che i paesi erano riusciti a ridurre il numero di nuove infezioni trasmesse da ogni persona infetta (\(R_t\)), ma che gli intervalli credibili conteneva \(R_t=1\) (il valore oltre il quale l'epidemia continua a diffondersi) e che era prematuro trarre conclusioni solide sull'efficacia degli interventi. Il codice di Stan per la carta è degli autori Github repository, e questo Colab riproduce versione 2 .

pip3 install -q git+git://github.com/arviz-devs/arviz.git
pip3 install -q tf-nightly tfp-nightly

Importazioni

1 Configurazione del modello

1.1 Modello meccanicistico per contagi e decessi

Il modello di infezione simula il numero di infezioni in ciascun paese nel tempo. I dati di input sono la tempistica e il tipo di interventi, la dimensione della popolazione e i casi iniziali. I parametri controllano l'efficacia degli interventi e il tasso di trasmissione della malattia. Il modello per il numero previsto di decessi applica un tasso di mortalità alle infezioni previste.

Il modello di infezione esegue una convoluzione delle precedenti infezioni giornaliere con la distribuzione dell'intervallo seriale (la distribuzione nel numero di giorni tra l'infezione e l'infezione di qualcun altro). Ad ogni passo, il numero di nuove infezioni al momento \(t\), \(n_t\), è calcolato come

\begin{equation} \sum_{i=0}^{t-1} n_i \mu_t \text{p} (\text{catturato da qualcuno infetto a } i | \text{appena infettato a } t) \end{ equation} dove \(\mu_t=R_t\) e la probabilità condizionale è memorizzato in conv_serial_interval , definito di seguito.

Il modello per le morti attese esegue una convoluzione delle infezioni giornaliere e la distribuzione dei giorni tra l'infezione e la morte. Cioè, i decessi attesi il giorno \(t\) è calcolato come

\ begin {equation} \ sum_ {i = 0} ^ {t-1} n_i \ text {p (la morte il giorno \(t\)| infezione giorno \(i\))} \ end {equation} dove è memorizzata la probabilità condizionata in conv_fatality_rate , definito di seguito.

from tensorflow_probability.python.internal import broadcast_util as bu

def predict_infections(
    intervention_indicators, population, initial_cases, mu, alpha_hier,
    conv_serial_interval, initial_days, total_days):
  """Predict the number of infections by forward-simulation.

  Args:
    intervention_indicators: Binary array of shape
      `[num_countries, total_days, num_interventions]`, in which `1` indicates
      the intervention is active in that country at that time and `0` indicates
      otherwise.
    population: Vector of length `num_countries`. Population of each country.
    initial_cases: Array of shape `[batch_size, num_countries]`. Number of cases
      in each country at the start of the simulation.
    mu: Array of shape `[batch_size, num_countries]`. Initial reproduction rate
      (R_0) by country.
    alpha_hier: Array of shape `[batch_size, num_interventions]` representing
      the effectiveness of interventions.
    conv_serial_interval: Array of shape
      `[total_days - initial_days, total_days]` output from
      `make_conv_serial_interval`. Convolution kernel for serial interval
      distribution.
    initial_days: Integer, number of sequential days to seed infections after
      the 10th death in a country. (N0 in the authors' Stan code.)
    total_days: Integer, number of days of observed data plus days to forecast.
      (N2 in the authors' Stan code.)
  Returns:
    predicted_infections: Array of shape
      `[total_days, batch_size, num_countries]`. (Batched) predicted number of
      infections over time and by country.
  """
  alpha = alpha_hier - tf.cast(np.log(1.05) / 6.0, DTYPE)

  # Multiply the effectiveness of each intervention in each country (alpha)
  # by the indicator variable for whether the intervention was active and sum
  # over interventions, yielding an array of shape
  # [total_days, batch_size, num_countries] that represents the total effectiveness of
  # all interventions in each country on each day (for a batch of data).
  linear_prediction = tf.einsum(
      'ijk,...k->j...i', intervention_indicators, alpha)

  # Adjust the reproduction rate per country downward, according to the
  # effectiveness of the interventions.
  rt = mu * tf.exp(-linear_prediction, name='reproduction_rate')

  # Initialize storage array for daily infections and seed it with initial
  # cases.
  daily_infections = tf.TensorArray(
      dtype=DTYPE, size=total_days, element_shape=initial_cases.shape)
  for i in range(initial_days):
    daily_infections = daily_infections.write(i, initial_cases)

  # Initialize cumulative cases.
  init_cumulative_infections = initial_cases * initial_days

  # Simulate forward for total_days days.
  cond = lambda i, *_: i < total_days
  def body(i, prev_daily_infections, prev_cumulative_infections):
    # The probability distribution over days j that someone infected on day i
    # caught the virus from someone infected on day j.
    p_infected_on_day = tf.gather(
        conv_serial_interval, i - initial_days, axis=0)

    # Multiply p_infected_on_day by the number previous infections each day and
    # by mu, and sum to obtain new infections on day i. Mu is adjusted by
    # the fraction of the population already infected, so that the population
    # size is the upper limit on the number of infections.
    prev_daily_infections_array = prev_daily_infections.stack()
    to_sum = prev_daily_infections_array * bu.left_justified_expand_dims_like(
        p_infected_on_day, prev_daily_infections_array)
    convolution = tf.reduce_sum(to_sum, axis=0)
    rt_adj = (
        (population - prev_cumulative_infections) / population
        ) * tf.gather(rt, i)
    new_infections = rt_adj * convolution

    # Update the prediction array and the cumulative number of infections.
    daily_infections = prev_daily_infections.write(i, new_infections)
    cumulative_infections = prev_cumulative_infections + new_infections
    return i + 1, daily_infections, cumulative_infections

  _, daily_infections_final, last_cumm_sum = tf.while_loop(
      cond, body,
      (initial_days, daily_infections, init_cumulative_infections),
      maximum_iterations=(total_days - initial_days))
  return daily_infections_final.stack()

def predict_deaths(predicted_infections, ifr_noise, conv_fatality_rate):
  """Expected number of reported deaths by country, by day.

  Args:
    predicted_infections: Array of shape
      `[total_days, batch_size, num_countries]` output from
      `predict_infections`.
    ifr_noise: Array of shape `[batch_size, num_countries]`. Noise in Infection
      Fatality Rate (IFR).
    conv_fatality_rate: Array of shape
      `[total_days - 1, total_days, num_countries]`. Convolutional kernel for
      calculating fatalities, output from `make_conv_fatality_rate`.
  Returns:
    predicted_deaths: Array of shape `[total_days, batch_size, num_countries]`.
      (Batched) predicted number of deaths over time and by country.
  """
  # Multiply the number of infections on day j by the probability of death
  # on day i given infection on day j, and sum over j. This yields the expected
  result_remainder = tf.einsum(
      'i...j,kij->k...j', predicted_infections, conv_fatality_rate) * ifr_noise

  # Concatenate the result with a vector of zeros so that the first day is
  # included.
  result_temp = 1e-15 * predicted_infections[:1]
  return tf.concat([result_temp, result_remainder], axis=0)

1.2 Priorità sui valori dei parametri

Qui definiamo la distribuzione a priori congiunta sui parametri del modello. Si presume che molti dei valori dei parametri siano indipendenti, in modo che il precedente possa essere espresso come:

\(\text p(\tau, y, \psi, \kappa, \mu, \alpha) = \text p(\tau)\text p(y|\tau)\text p(\psi)\text p(\kappa)\text p(\mu|\kappa)\text p(\alpha)\text p(\epsilon)\)

in quale:

  • \(\tau\) è il parametro tasso condivisa della distribuzione esponenziale rispetto al numero di casi iniziali per paese, \(y = y_1, ... y_{\text{num_countries} }\).
  • \(\psi\) è un parametro nella distribuzione binomiale negativa per numero di morti.
  • \(\kappa\) è il parametro di scala condivisa della distribuzione HalfNormal sopra il numero di riproduzione iniziale in ogni paese, \(\mu = \mu_1, ..., \mu_{\text{num_countries} }\) (che indica il numero di casi aggiuntivi trasmessi da ciascuna persona infetta).
  • \(\alpha = \alpha_1, ..., \alpha_6\) è l'efficacia di ciascuna delle sei interventi.
  • \(\epsilon\) (chiamato ifr_noise nel codice, dopo il codice Stan degli autori) è il rumore nel Infection Fatality Rate (IFR).

Esprimiamo questo modello come TFP JointDistribution, un tipo di distribuzione TFP che consente l'espressione di modelli grafici probabilistici.

def make_jd_prior(num_countries, num_interventions):
  return tfd.JointDistributionSequentialAutoBatched([
      # Rate parameter for the distribution of initial cases (tau).
      tfd.Exponential(rate=tf.cast(0.03, DTYPE)),

      # Initial cases for each country.
      lambda tau: tfd.Sample(
          tfd.Exponential(rate=tf.cast(1, DTYPE) / tau),
          sample_shape=num_countries),

      # Parameter in Negative Binomial model for deaths (psi).
      tfd.HalfNormal(scale=tf.cast(5, DTYPE)),

      # Parameter in the distribution over the initial reproduction number, R_0
      # (kappa).
      tfd.HalfNormal(scale=tf.cast(0.5, DTYPE)),

      # Initial reproduction number, R_0, for each country (mu).
      lambda kappa: tfd.Sample(
          tfd.TruncatedNormal(loc=3.28, scale=kappa, low=1e-5, high=1e5),
          sample_shape=num_countries),

      # Impact of interventions (alpha; shared for all countries).
      tfd.Sample(
          tfd.Gamma(tf.cast(0.1667, DTYPE), 1), sample_shape=num_interventions),

      # Multiplicative noise in Infection Fatality Rate.
      tfd.Sample(
          tfd.TruncatedNormal(
              loc=tf.cast(1., DTYPE), scale=0.1, low=1e-5, high=1e5),
              sample_shape=num_countries)
  ])

1.3 Probabilità di decessi osservati in base ai valori dei parametri

Le esprime modello probabilità \(p(\text{deaths} | \tau, y, \psi, \kappa, \mu, \alpha, \epsilon)\). Applica i modelli per il numero di infezioni e decessi attesi condizionati da parametri e assume che i decessi effettivi seguano una distribuzione binomiale negativa.

def make_likelihood_fn(
    intervention_indicators, population, deaths,
    infection_fatality_rate, initial_days, total_days):

  # Create a mask for the initial days of simulated data, as they are not
  # counted in the likelihood.
  observed_deaths = tf.constant(deaths.T[np.newaxis, ...], dtype=DTYPE)
  mask_temp = deaths != -1
  mask_temp[:, :START_DAYS] = False
  observed_deaths_mask = tf.constant(mask_temp.T[np.newaxis, ...])

  conv_serial_interval = make_conv_serial_interval(initial_days, total_days)
  conv_fatality_rate = make_conv_fatality_rate(
      infection_fatality_rate, total_days)

  def likelihood_fn(tau, initial_cases, psi, kappa, mu, alpha_hier, ifr_noise):
    # Run models for infections and expected deaths
    predicted_infections = predict_infections(
        intervention_indicators, population, initial_cases, mu, alpha_hier,
        conv_serial_interval, initial_days, total_days)
    e_deaths_all_countries = predict_deaths(
        predicted_infections, ifr_noise, conv_fatality_rate)

    # Construct the Negative Binomial distribution for deaths by country.
    mu_m = tf.transpose(e_deaths_all_countries, [1, 0, 2])
    psi_m = psi[..., tf.newaxis, tf.newaxis]
    probs = tf.clip_by_value(mu_m / (mu_m + psi_m), 1e-9, 1.)
    likelihood_elementwise = tfd.NegativeBinomial(
        total_count=psi_m, probs=probs).log_prob(observed_deaths)
    return tf.reduce_sum(
        tf.where(observed_deaths_mask,
                likelihood_elementwise,
                tf.zeros_like(likelihood_elementwise)),
        axis=[-2, -1])

  return likelihood_fn

1.4 Probabilità di morte per infezione

Questa sezione calcola la distribuzione dei decessi nei giorni successivi all'infezione. Presuppone che il tempo dall'infezione alla morte sia la somma di due quantità gamma-variate, che rappresentano il tempo dall'infezione all'insorgenza della malattia e il tempo dall'inizio alla morte. La distribuzione time-to-morte è combinato con infezione Fatality velocità di trasmissione dati da Verity et al. (2020) per calcolare la probabilità di morte nei giorni dopo l'infezione.

def daily_fatality_probability(infection_fatality_rate, total_days):
  """Computes the probability of death `d` days after infection."""

  # Convert from alternative Gamma parametrization and construct distributions
  # for number of days from infection to onset and onset to death.
  concentration1 = tf.cast((1. / 0.86)**2, DTYPE)
  rate1 = concentration1 / 5.1
  concentration2 = tf.cast((1. / 0.45)**2, DTYPE)
  rate2 = concentration2 / 18.8
  infection_to_onset = tfd.Gamma(concentration=concentration1, rate=rate1)
  onset_to_death = tfd.Gamma(concentration=concentration2, rate=rate2)

  # Create empirical distribution for number of days from infection to death.
  inf_to_death_dist = tfd.Empirical(
      infection_to_onset.sample([5e6]) + onset_to_death.sample([5e6]))

  # Subtract the CDF value at day i from the value at day i + 1 to compute the
  # probability of death on day i given infection on day 0, and given that
  # death (not recovery) is the outcome.
  times = np.arange(total_days + 1., dtype=DTYPE) + 0.5
  cdf = inf_to_death_dist.cdf(times).numpy()
  f_before_ifr = cdf[1:] - cdf[:-1]
  # Explicitly set the zeroth value to the empirical cdf at time 1.5, to include
  # the mass between time 0 and time .5.
  f_before_ifr[0] = cdf[1]

  # Multiply the daily fatality rates conditional on infection and eventual
  # death (f_before_ifr) by the infection fatality rates (probability of death
  # given intection) to obtain the probability of death on day i conditional
  # on infection on day 0.
  return infection_fatality_rate[..., np.newaxis] * f_before_ifr

def make_conv_fatality_rate(infection_fatality_rate, total_days):
  """Computes the probability of death on day `i` given infection on day `j`."""
  p_fatal_all_countries = daily_fatality_probability(
      infection_fatality_rate, total_days)

  # Use the probability of death d days after infection in each country
  # to build an array of shape [total_days - 1, total_days, num_countries],
  # where the element [i, j, c] is the probability of death on day i+1 given
  # infection on day j in country c.
  conv_fatality_rate = np.zeros(
      [total_days - 1, total_days, p_fatal_all_countries.shape[0]])
  for n in range(1, total_days):
    conv_fatality_rate[n - 1, 0:n, :] = (
        p_fatal_all_countries[:, n - 1::-1]).T
  return tf.constant(conv_fatality_rate, dtype=DTYPE)

1.5 Intervallo seriale

L'intervallo seriale è il tempo tra i casi successivi in ​​una catena di trasmissione della malattia e si presume che sia distribuito Gamma. Usiamo la distribuzione dell'intervallo di serie per calcolare la probabilità che una persona infetta il giorno \(i\) preso il virus da una persona precedentemente infettato il giorno \(j\) (il conv_serial_interval argomento per predict_infections ).

def make_conv_serial_interval(initial_days, total_days):
  """Construct the convolutional kernel for infection timing."""

  g = tfd.Gamma(tf.cast(1. / (0.62**2), DTYPE), 1./(6.5*0.62**2))
  g_cdf = g.cdf(np.arange(total_days, dtype=DTYPE))

  # Approximate the probability mass function for the number of days between
  # successive infections.
  serial_interval = g_cdf[1:] - g_cdf[:-1]

  # `conv_serial_interval` is an array of shape
  # [total_days - initial_days, total_days] in which entry [i, j] contains the
  # probability that an individual infected on day i + initial_days caught the
  # virus from someone infected on day j.
  conv_serial_interval = np.zeros([total_days - initial_days, total_days])
  for n in range(initial_days, total_days):
    conv_serial_interval[n - initial_days, 0:n] = serial_interval[n - 1::-1]
  return tf.constant(conv_serial_interval, dtype=DTYPE)

2 Pretrattamento dei dati

COUNTRIES = [
    'Austria',
    'Belgium',
    'Denmark',
    'France',
    'Germany',
    'Italy',
    'Norway',
    'Spain',
    'Sweden',
    'Switzerland',
    'United_Kingdom'
]

2.1 Recupero e pre-elaborazione dei dati degli interventi

2.2 Recuperare i dati del caso/decesso e partecipare agli interventi

2.3 Recupero ed elaborazione dell'indice di mortalità infetti e dei dati sulla popolazione

2.4 Pre-elaborazione dei dati specifici del paese

# Model up to 75 days of data for each country, starting 30 days before the
# tenth cumulative death.
START_DAYS = 30
MAX_DAYS = 102
COVARIATE_COLUMNS = any_intervention_list + ['any_intervention']

# Initialize an array for number of deaths.
deaths = -np.ones((num_countries, MAX_DAYS), dtype=DTYPE)

# Assuming every intervention is still inplace in the unobserved future
num_interventions = len(COVARIATE_COLUMNS)
intervention_indicators = np.ones((num_countries, MAX_DAYS, num_interventions))

first_days = {}
for i, c in enumerate(COUNTRIES):
  c_data = data.loc[c]

  # Include data only after 10th death in a country.
  mask = c_data['deaths'].cumsum() >= 10

  # Get the date that the epidemic starts in a country.
  first_day = c_data.index[mask][0] - pd.to_timedelta(START_DAYS, 'days')
  c_data = c_data.truncate(before=first_day)

  # Truncate the data after 28 March 2020 for comparison with Flaxman et al.
  c_data = c_data.truncate(after='2020-03-28')

  c_data = c_data.iloc[:MAX_DAYS]
  days_of_data = c_data.shape[0]
  deaths[i, :days_of_data] = c_data['deaths']
  intervention_indicators[i, :days_of_data] = c_data[
    COVARIATE_COLUMNS].to_numpy()
  first_days[c] = first_day

# Number of sequential days to seed infections after the 10th death in a
# country. (N0 in authors' Stan code.)
INITIAL_DAYS = 6

# Number of days of observed data plus days to forecast. (N2 in authors' Stan
# code.)
TOTAL_DAYS = deaths.shape[1]

3 Inferenza del modello

Flaxman et al. (2020) usata Stan a campione dal parametro posteriori con hamiltoniana Monte Carlo (HMC) e il No-U-Turn Sampler (NUTS).

Qui, applichiamo HMC con adattamento della dimensione del passo a doppia media. Utilizziamo un'esecuzione pilota di HMC per il precondizionamento e l'inizializzazione.

L'inferenza viene eseguita in pochi minuti su una GPU.

3.1 Costruire priorità e probabilità per il modello

jd_prior = make_jd_prior(num_countries, num_interventions)
likelihood_fn = make_likelihood_fn(
    intervention_indicators, population_value, deaths,
    infection_fatality_rate, INITIAL_DAYS, TOTAL_DAYS)

3.2 Utilità

def get_bijectors_from_samples(samples, unconstraining_bijectors, batch_axes):
  """Fit bijectors to the samples of a distribution.

  This fits a diagonal covariance multivariate Gaussian transformed by the
  `unconstraining_bijectors` to the provided samples. The resultant
  transformation can be used to precondition MCMC and other inference methods.
  """
  state_std = [    
      tf.math.reduce_std(bij.inverse(x), axis=batch_axes)
      for x, bij in zip(samples, unconstraining_bijectors)
  ]
  state_mu = [
      tf.math.reduce_mean(bij.inverse(x), axis=batch_axes)
      for x, bij in zip(samples, unconstraining_bijectors)
  ]
  return [tfb.Chain([cb, tfb.Shift(sh), tfb.Scale(sc)])
          for cb, sh, sc in zip(unconstraining_bijectors, state_mu, state_std)]

def generate_init_state_and_bijectors_from_prior(nchain, unconstraining_bijectors):
  """Creates an initial MCMC state, and bijectors from the prior."""
  prior_samples = jd_prior.sample(4096)

  bijectors = get_bijectors_from_samples(
      prior_samples, unconstraining_bijectors, batch_axes=0)

  init_state = [
    bij(tf.zeros([nchain] + list(s), DTYPE))
    for s, bij in zip(jd_prior.event_shape, bijectors)
  ]

  return init_state, bijectors
@tf.function(autograph=False, experimental_compile=True)
def sample_hmc(
    init_state,
    step_size,
    target_log_prob_fn,
    unconstraining_bijectors,
    num_steps=500,
    burnin=50,
    num_leapfrog_steps=10):

    def trace_fn(_, pkr):
        return {
            'target_log_prob': pkr.inner_results.inner_results.accepted_results.target_log_prob,
            'diverging': ~(pkr.inner_results.inner_results.log_accept_ratio > -1000.),
            'is_accepted': pkr.inner_results.inner_results.is_accepted,
            'step_size': [tf.exp(s) for s in pkr.log_averaging_step],
        }

    hmc = tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn,
        step_size=step_size,
        num_leapfrog_steps=num_leapfrog_steps)

    hmc = tfp.mcmc.TransformedTransitionKernel(
        inner_kernel=hmc,
        bijector=unconstraining_bijectors)

    hmc = tfp.mcmc.DualAveragingStepSizeAdaptation(
        hmc,
        num_adaptation_steps=int(burnin * 0.8),
        target_accept_prob=0.8,
        decay_rate=0.5)

    # Sampling from the chain.
    return tfp.mcmc.sample_chain(
        num_results=burnin + num_steps,
        current_state=init_state,
        kernel=hmc,
        trace_fn=trace_fn)

3.3 Definire i biiettori dello spazio degli eventi

HMC è più efficiente quando il campionamento da una distribuzione gaussiana multivariata isotropa ( Mangoubi & Smith (2017) ), quindi il primo passo è quello di pre-condizionare la densità obiettivo di guardare come molto simile a quella possibile.

Innanzitutto, trasformiamo le variabili vincolate (ad esempio, non negative) in uno spazio non vincolato, richiesto da HMC. Inoltre, utilizziamo il biiettore SinhArcsinh per manipolare la pesantezza delle code della densità target trasformata; vogliamo questi per cadere o meno come \(e^{-x^2}\).

unconstraining_bijectors = [
    tfb.Chain([tfb.Scale(tf.constant(1 / 0.03, DTYPE)), tfb.Softplus(),
                tfb.SinhArcsinh(tailweight=tf.constant(1.85, DTYPE))]), # tau
    tfb.Chain([tfb.Scale(tf.constant(1 / 0.03, DTYPE)), tfb.Softplus(),
                tfb.SinhArcsinh(tailweight=tf.constant(1.85, DTYPE))]), # initial_cases
    tfb.Softplus(), # psi
    tfb.Softplus(), # kappa
    tfb.Softplus(), # mu
    tfb.Chain([tfb.Scale(tf.constant(0.4, DTYPE)), tfb.Softplus(),
                tfb.SinhArcsinh(skewness=tf.constant(-0.2, DTYPE), tailweight=tf.constant(2., DTYPE))]), # alpha
    tfb.Softplus(), # ifr_noise
]

3.4 Esecuzione pilota HMC

Per prima cosa eseguiamo HMC precondizionato dal prior, inizializzato da 0 nello spazio trasformato. Non usiamo i campioni precedenti per inizializzare la catena poiché in pratica questi spesso si traducono in catene bloccate a causa di numeri scarsi.

%%time

nchain = 32

target_log_prob_fn = lambda *x: jd_prior.log_prob(*x) + likelihood_fn(*x)
init_state, bijectors = generate_init_state_and_bijectors_from_prior(nchain, unconstraining_bijectors)

# Each chain gets its own step size.
step_size = [tf.fill([nchain] + [1] * (len(s.shape) - 1), tf.constant(0.01, DTYPE)) for s in init_state]

burnin = 200
num_steps = 100

pilot_samples, pilot_sampler_stat = sample_hmc(
    init_state,
    step_size,
    target_log_prob_fn,
    bijectors,
    num_steps=num_steps,
    burnin=burnin,
    num_leapfrog_steps=10)
CPU times: user 56.8 s, sys: 2.34 s, total: 59.1 s
Wall time: 1min 1s

3.5 Visualizza campioni pilota

Stiamo cercando catene bloccate e convergenza oculare. Possiamo fare una diagnostica formale qui, ma non è super necessario dato che è solo una corsa pilota.

import arviz as az
az.style.use('arviz-darkgrid')
var_name = ['tau', 'initial_cases', 'psi', 'kappa', 'mu', 'alpha', 'ifr_noise']

pilot_with_warmup = {k: np.swapaxes(v.numpy(), 1, 0)
                     for k, v in zip(var_name, pilot_samples)}

Osserviamo divergenze durante il riscaldamento, principalmente perché l'adattamento della dimensione del passo a doppia media utilizza una ricerca molto aggressiva per la dimensione del passo ottimale. Una volta disattivato l'adattamento, scompaiono anche le divergenze.

az_trace = az.from_dict(posterior=pilot_with_warmup,
                        sample_stats={'diverging': np.swapaxes(pilot_sampler_stat['diverging'].numpy(), 0, 1)})
az.plot_trace(az_trace, combined=True, compact=True, figsize=(12, 8));

png

plt.plot(pilot_sampler_stat['step_size'][0]);

png

3.6 Esegui HMC

In linea di principio potremmo usare i campioni pilota per l'analisi finale (se li abbiamo eseguiti più a lungo per ottenere la convergenza), ma è un po' più efficiente avviare un'altra corsa HMC, questa volta precondizionata e inizializzata dai campioni pilota.

%%time

burnin = 50
num_steps = 200

bijectors = get_bijectors_from_samples([s[burnin:] for s in pilot_samples],
                                       unconstraining_bijectors=unconstraining_bijectors,
                                       batch_axes=(0, 1))

samples, sampler_stat = sample_hmc(
    [s[-1] for s in pilot_samples],
    [s[-1] for s in pilot_sampler_stat['step_size']],
    target_log_prob_fn,
    bijectors,
    num_steps=num_steps,
    burnin=burnin,
    num_leapfrog_steps=20)
CPU times: user 1min 26s, sys: 3.88 s, total: 1min 30s
Wall time: 1min 32s
plt.plot(sampler_stat['step_size'][0]);

png

3.7 Visualizza campioni

import arviz as az
az.style.use('arviz-darkgrid')
var_name = ['tau', 'initial_cases', 'psi', 'kappa', 'mu', 'alpha', 'ifr_noise']

posterior = {k: np.swapaxes(v.numpy()[burnin:], 1, 0)
             for k, v in zip(var_name, samples)}
posterior_with_warmup = {k: np.swapaxes(v.numpy(), 1, 0)
             for k, v in zip(var_name, samples)}

Calcola il riassunto delle catene. Cerchiamo ESS alto e r_hat vicino a 1.

az.summary(posterior)
az_trace = az.from_dict(posterior=posterior_with_warmup,
                        sample_stats={'diverging': np.swapaxes(sampler_stat['diverging'].numpy(), 0, 1)})
az.plot_trace(az_trace, combined=True, compact=True, figsize=(12, 8));

png

È istruttivo esaminare le funzioni di autocorrelazione in tutte le dimensioni. Stiamo cercando funzioni che scendano rapidamente, ma non così tanto da andare in negativo (che è indicativo di HMC che colpisce una risonanza, che è dannoso per l'ergodicità e può introdurre bias).

with az.rc_context(rc={'plot.max_subplots': None}):
  az.plot_autocorr(posterior, combined=True, figsize=(12, 16), textsize=12);

png

4 risultati

I seguenti trame analizzano le distribuzioni predittive posteriori oltre \(R_t\), numero di morti, e il numero di infezioni, simili all'analisi in Flaxman et al. (2020).

total_num_samples = np.prod(posterior['mu'].shape[:2])

# Calculate R_t given parameter estimates.
def rt_samples_batched(mu, intervention_indicators, alpha):
  linear_prediction = tf.reduce_sum(
      intervention_indicators * alpha[..., np.newaxis, np.newaxis, :], axis=-1)
  rt_hat = mu[..., tf.newaxis] * tf.exp(-linear_prediction, name='rt')
  return rt_hat

alpha_hat = tf.convert_to_tensor(
    posterior['alpha'].reshape(total_num_samples, posterior['alpha'].shape[-1]))
mu_hat = tf.convert_to_tensor(
    posterior['mu'].reshape(total_num_samples, num_countries))
rt_hat = rt_samples_batched(mu_hat, intervention_indicators, alpha_hat)
sampled_initial_cases = posterior['initial_cases'].reshape(
    total_num_samples, num_countries)
sampled_ifr_noise = posterior['ifr_noise'].reshape(
    total_num_samples, num_countries)
psi_hat = posterior['psi'].reshape([total_num_samples])

conv_serial_interval = make_conv_serial_interval(INITIAL_DAYS, TOTAL_DAYS)
conv_fatality_rate = make_conv_fatality_rate(infection_fatality_rate, TOTAL_DAYS)
pred_hat = predict_infections(
    intervention_indicators, population_value, sampled_initial_cases, mu_hat,
    alpha_hat, conv_serial_interval, INITIAL_DAYS, TOTAL_DAYS)
expected_deaths = predict_deaths(pred_hat, sampled_ifr_noise, conv_fatality_rate)

psi_m = psi_hat[np.newaxis, ..., np.newaxis]
probs = tf.clip_by_value(expected_deaths / (expected_deaths + psi_m), 1e-9, 1.)
predicted_deaths = tfd.NegativeBinomial(
    total_count=psi_m, probs=probs).sample()
# Predict counterfactual infections/deaths in the absence of interventions
no_intervention_infections = predict_infections(
    intervention_indicators,
    population_value,
    sampled_initial_cases,
    mu_hat,
    tf.zeros_like(alpha_hat),
    conv_serial_interval,
    INITIAL_DAYS, TOTAL_DAYS)

no_intervention_expected_deaths = predict_deaths(
    no_intervention_infections, sampled_ifr_noise, conv_fatality_rate)
probs = tf.clip_by_value(
    no_intervention_expected_deaths / (no_intervention_expected_deaths + psi_m),
    1e-9, 1.)
no_intervention_predicted_deaths = tfd.NegativeBinomial(
    total_count=psi_m, probs=probs).sample()

4.1 Efficacia degli interventi

Simile alla Figura 4 di Flaxman et al. (2020).

def intervention_effectiveness(alpha):

  alpha_adj = 1. - np.exp(-alpha + np.log(1.05) / 6.)
  alpha_adj_first = (
      1. - np.exp(-alpha - alpha[..., -1:] + np.log(1.05) / 6.))

  fig, ax = plt.subplots(1, 1, figsize=[12, 6])
  intervention_perm = [2, 1, 3, 4, 0]
  percentile_vals = [2.5, 97.5]
  jitter = .2

  for ind in range(5):
    first_low, first_high = tfp.stats.percentile(
        alpha_adj_first[..., ind], percentile_vals)
    low, high = tfp.stats.percentile(
        alpha_adj[..., ind], percentile_vals)

    p_ind = intervention_perm[ind]
    ax.hlines(p_ind, low, high, label='Later Intervention', colors='g')
    ax.scatter(alpha_adj[..., ind].mean(), p_ind, color='g')
    ax.hlines(p_ind + jitter, first_low, first_high,
              label='First Intervention', colors='r')
    ax.scatter(alpha_adj_first[..., ind].mean(), p_ind + jitter, color='r')

    if ind == 0:
      plt.legend(loc='lower right')
  ax.set_yticks(range(5))
  ax.set_yticklabels(
      [any_intervention_list[intervention_perm.index(p)] for p in range(5)])
  ax.set_xlim([-0.01, 1.])
  r = fig.patch
  r.set_facecolor('white') 

intervention_effectiveness(alpha_hat)

png

4.2 Infezioni, decessi e R_t per paese

Simile alla Figura 2 di Flaxman et al. (2020).

png

4.3 Numero giornaliero di morti previste/previste con e senza interventi

png