-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fixing merge conflicts with Bens branch
- Loading branch information
Showing
6 changed files
with
338 additions
and
33 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,11 +9,16 @@ | |
|
||
This repository is for the design and implementation of a Scenarios forecasting model, built by the Scenarios team within CFA-Predict. | ||
|
||
This code aims to combine a number of different codebases to forecast different covid scenarios with a Compartmental Mechanistic ODE model modeling multiple competing covid variants. The aim of this model is to provide enough flexibility for its users to explore a variety of scenarios, but also making certain design decisions that allow for fast computation and fitting as well as code readability. | ||
Currently, we aim to use this code to forecast different disease tranmission scenarios with a compartmental mechanistic ODE model. This model is under development with transmission of SARS-CoV-2 as our primary focus. We plan to apply this model to the transmission of other respiratory viruses such as influenza and RSV. We aim to provide enough flexibility for the code users to explore a variety of scenarios, but also making certain design decisions that allow for fast computation and fitting as well as code readability. | ||
|
||
[//]: # (This code aims to combine a number of different codebases to forecast different covid scenarios with a Compartmental Mechanistic ODE model modeling multiple competing covid variants. The aim of this model is to provide enough flexibility for its users to explore a variety of scenarios, but also making certain design decisions that allow for fast computation and fitting as well as code readability.) | ||
|
||
|
||
What this model is: | ||
|
||
a compartmental mechanistic ODE model capible of dynamic age binning, waning, vaccination scenarios, introduction of new variants, transmission structures, and timing estimation. TODO | ||
a compartmental mechanistic ODE model that accounts for age structure, immunity history, vaccination, immunity waning and multiple variants. | ||
|
||
[//]: # (capable of dynamic age binning, waning, vaccination scenarios, introduction of new variants, transmission structures, and timing estimation. TODO) | ||
|
||
What this model is not: | ||
|
||
|
@@ -33,11 +38,11 @@ In order to run this model and get basic results follow these steps: | |
|
||
Here is an example script of a basic run of 100 days without inference of parameters, saving the simulation as an image to output/example.png: | ||
``` | ||
from model_odes.seir_model_v5 import seirw_ode | ||
from model_odes.seip_model import seip_model | ||
from mechanistic_compartments import build_basic_mechanistic_model | ||
from config.config_base import ConfigBase | ||
solution = build_basic_mechanistic_model(ConfigBase()).run(seirw_ode, tf=100.0, show=True, save=True, save_path="output/example.png") | ||
solution = build_basic_mechanistic_model(ConfigBase()).run(seip_model, tf=100.0, show=True, save=True, save_path="output/example.png") | ||
``` | ||
|
||
To create your own scenario, and modify parameters such as strain R0 and vaccination rate follow these steps: | ||
|
@@ -49,19 +54,19 @@ To create your own scenario, and modify parameters such as strain R0 and vaccina | |
5. Run almost the same script as above, replacing your ConfigBase import with ConfigScenario. | ||
|
||
``` | ||
from model_odes.seir_model_v5 import seirw_ode | ||
from model_odes.seip_model import seip_model | ||
from mechanistic_compartments import build_basic_mechanistic_model | ||
from config.config_scenario_example import ConfigScenario | ||
solution = build_basic_mechanistic_model(ConfigScenario()).run(seirw_ode, tf=100.0, show=True, save=True, save_path="output/example_scenario.png") | ||
solution = build_basic_mechanistic_model(ConfigScenario()).run(seip_model, tf=100.0, show=True, save=True, save_path="output/example_scenario.png") | ||
``` | ||
|
||
Before you go about running your own experiments it is best to understand how the model is initialized. Rather than looking through the model matricies yourself, the Scenarios team has created a Shiny application allowing for easy data visualization of the model's initial state! | ||
Simply run `visualizer_app.py` and navigate to http://localhost:8000/ and play with the data yourself. | ||
|
||
## Data Sources | ||
|
||
The model is fed the following data sources: | ||
The model (as described in the example script) is fed the following data sources: | ||
1. data/demographic-data/contact_matricies : contact matricies sourced from work done by Dina Minstry's past work in this [Github Project](https://github.com/mobs-lab/mixing-patterns). | ||
2. data/serological-data/* : serology data sourced from: [data.cdc.gov](https://data.cdc.gov/Laboratory-Surveillance/Nationwide-Commercial-Laboratory-Seroprevalence-Su/d2tw-32xv) | ||
3. data/sim_data_*.sqlite: ABM data sourced from Tom Hladish's work found [here](https://github.com/tjhladish/covid-abm) | ||
|
@@ -81,7 +86,7 @@ Thomas Hladish, Lead Data Scientist, [email protected], CDC/IOD/ORR/CFA | |
|
||
Ariel Shurygin, Data Scientist, [email protected], CDC/IOD/ORR/CFA | ||
|
||
Ben Kok Toh, Data Scientist, [email protected], CDC/IOD/ORR/CFA | ||
Kok Ben Toh, Data Scientist, [email protected], CDC/IOD/ORR/CFA | ||
|
||
Michael Batista, Data Scientist, [email protected], CDC/IOD/ORR/CFA | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,198 @@ | ||
# %% | ||
import copy | ||
|
||
import jax.config | ||
import jax.numpy as jnp | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
import numpyro | ||
import pandas as pd | ||
from cycler import cycler | ||
from inference import infer_model | ||
from jax.random import PRNGKey | ||
from numpyro.infer import MCMC, NUTS | ||
|
||
from config.config_base import ConfigBase | ||
from mechanistic_compartments import build_basic_mechanistic_model | ||
from model_odes.seip_model import seip_ode | ||
|
||
# Use 4 cores | ||
numpyro.set_host_device_count(5) | ||
jax.config.update("jax_enable_x64", True) | ||
# %% | ||
# Observations | ||
obs_df = pd.read_csv("./data/hospitalization-data/hospital_220220_221105.csv") | ||
obs_incidence = obs_df.groupby(["date"])["new_admission_7"].apply(np.array) | ||
obs_incidence = jnp.array(obs_incidence.tolist()) | ||
obs_incidence = obs_incidence[0:200,] | ||
|
||
fig, ax = plt.subplots(1) | ||
ax.plot(np.asarray(obs_incidence), label=["0-17", "18-49", "50-64", "65+"]) | ||
fig.legend() | ||
ax.set_title("Observed data") | ||
plt.show() | ||
|
||
# %% | ||
# Config to US population sizes | ||
pop = 3.28e8 | ||
init_inf_prop = 0.03 | ||
intro_perc = 0.02 | ||
cb = ConfigBase( | ||
POP_SIZE=pop, | ||
INFECTIOUS_PERIOD=7.0, | ||
INITIAL_INFECTIONS=init_inf_prop * pop, | ||
INTRODUCTION_PERCENTAGE=intro_perc, | ||
INTRODUCTION_SCALE=10, | ||
NUM_WANING_COMPARTMENTS=5, | ||
) | ||
model = build_basic_mechanistic_model(cb) | ||
model.VAX_EFF_MATRIX = jnp.array( | ||
[ | ||
[0, 0.29, 0.58], # delta | ||
[0, 0.24, 0.48], # omicron1 | ||
[0, 0.19, 0.38], # BA1.1 | ||
] | ||
) | ||
|
||
|
||
# %% | ||
# MCMC specifications for "cold run" | ||
nuts = NUTS( | ||
infer_model, | ||
dense_mass=True, | ||
max_tree_depth=7, | ||
init_strategy=numpyro.infer.init_to_median(), | ||
target_accept_prob=0.80, | ||
# find_heuristic_step_size=True, | ||
) | ||
mcmc = MCMC( | ||
nuts, | ||
num_warmup=500, | ||
num_samples=500, | ||
num_chains=5, | ||
progress_bar=True, | ||
) | ||
|
||
# %% | ||
# mcmc.warmup( | ||
# rng_key=PRNGKey(8811967), | ||
# collect_warmup=True, | ||
# incidence=obs_incidence, | ||
# model=model, | ||
# ) | ||
|
||
mcmc.run( | ||
rng_key=PRNGKey(8811968), | ||
incidence=obs_incidence, | ||
model=model, | ||
) | ||
|
||
# %% | ||
samp = mcmc.get_samples(group_by_chain=True) | ||
fig, axs = plt.subplots(2, 2) | ||
axs[0, 0].set_title("Intro Time") | ||
axs[0, 0].plot(np.transpose(samp["INTRO_TIME"])) | ||
axs[0, 1].set_title("Intro Percentage") | ||
axs[0, 1].plot(np.transpose(samp["INTRO_PERC"])) | ||
axs[1, 0].set_title("R02") | ||
axs[1, 0].plot(np.transpose(samp["r0_2"])) | ||
axs[1, 1].set_title("R03") | ||
axs[1, 1].plot(np.transpose(samp["r0_3"]), label=range(1, 6)) | ||
fig.legend() | ||
plt.show() | ||
|
||
# %% | ||
# Take median of runs and check if fit is good | ||
fitted_medians = {k: jnp.median(v[:, -1], axis=0) for k, v in samp.items()} | ||
cb1 = ConfigBase( | ||
POP_SIZE=pop, | ||
INFECTIOUS_PERIOD=7.0, | ||
INITIAL_INFECTIONS=fitted_medians["INITIAL_INFECTIONS"], | ||
INTRODUCTION_PERCENTAGE=fitted_medians["INTRO_PERC"], | ||
INTRODUCTION_TIMES=[fitted_medians["INTRO_TIME"]], | ||
INTRODUCTION_SCALE=fitted_medians["INTRO_SCALE"], | ||
NUM_WANING_COMPARTMENTS=5, | ||
) | ||
cb1.STRAIN_SPECIFIC_R0 = jnp.append( | ||
jnp.array([1.2]), | ||
jnp.append(fitted_medians["r0_2"], fitted_medians["r0_3"]), | ||
) | ||
model1 = build_basic_mechanistic_model(cb1) | ||
imm = fitted_medians["imm_factor"] | ||
ihr1 = fitted_medians["ihr"] | ||
ihr_mult = fitted_medians["ihr_mult"] | ||
model1.CROSSIMMUNITY_MATRIX = jnp.array( | ||
[ | ||
[ | ||
0.0, # 000 | ||
1.0, # 001 | ||
1.0, # 010 | ||
1.0, # 011 | ||
1.0, # 100 | ||
1.0, # 101 | ||
1.0, # 110 | ||
1.0, # 111 | ||
], | ||
[0.0, imm, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], | ||
[ | ||
0.0, | ||
imm**2, | ||
imm, | ||
imm, | ||
1.0, | ||
1.0, | ||
1.0, | ||
1.0, | ||
], | ||
] | ||
) | ||
model1.VAX_EFF_MATRIX = jnp.array( | ||
[ | ||
[0, 0.29, 0.58], # delta | ||
[0, 0.24, 0.48], # omicron1 | ||
[0, 0.19, 0.38], # BA1.1 | ||
] | ||
) | ||
|
||
solution1 = model1.run( | ||
seip_ode, | ||
tf=250, | ||
show=True, | ||
save=False, | ||
# plot_commands=["S[:, 0, :, :]", "BA1.1", "omicron", "delta"], | ||
log_scale=True, | ||
) | ||
|
||
model_incidence = jnp.sum(solution1.ys[3], axis=4) | ||
model_incidence_0 = jnp.diff(model_incidence[:, :, 0, 0], axis=0) | ||
|
||
model_incidence_1 = jnp.sum(model_incidence, axis=(2, 3)) | ||
model_incidence_1 = jnp.diff(model_incidence_1, axis=0) | ||
model_incidence_1 -= model_incidence_0 | ||
|
||
sim_incidence = model_incidence_0 * ihr1 + model_incidence_1 * ihr1 * ihr_mult | ||
|
||
colors = ["#e41a1c", "#377eb8", "#4daf4a", "#984ea3"] | ||
fig, ax = plt.subplots(1) | ||
ax.set_prop_cycle(cycler(color=colors)) | ||
ax.plot(sim_incidence, label=["0-17", "18-49", "50-64", "65+"]) | ||
ax.plot( | ||
obs_incidence, | ||
label=["0-17 (obs)", "18-49 (obs)", "50-64 (obs)", "65+ (obs)"], | ||
linestyle="dashed", | ||
) | ||
fig.legend() | ||
ax.set_title("Observed vs fitted") | ||
plt.show() | ||
|
||
# %% | ||
# Check covariance matrix | ||
samp_all = copy.deepcopy(mcmc.get_samples(group_by_chain=False)) | ||
ihr_dict = { | ||
"ihr_" + str(i): v for i, v in enumerate(jnp.transpose(samp_all["ihr"])) | ||
} | ||
del samp_all["ihr"] | ||
samp_all.update(ihr_dict) | ||
|
||
samp_df = pd.DataFrame(samp_all) | ||
samp_df.corr() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
import copy | ||
|
||
import jax.numpy as jnp | ||
import numpyro | ||
import numpyro.distributions as dist | ||
|
||
from mechanistic_compartments import BasicMechanisticModel | ||
from model_odes.seip_model import seip_ode | ||
|
||
|
||
def infer_model(incidence, model: BasicMechanisticModel): | ||
m = copy.deepcopy(model) | ||
|
||
# Parameters | ||
r0_2_dist = dist.TransformedDistribution( | ||
dist.Beta(50, 750), dist.transforms.AffineTransform(2.0, 8.0) | ||
) | ||
r0_3_dist = dist.TransformedDistribution( | ||
dist.Beta(2, 14), dist.transforms.AffineTransform(2.0, 8.0) | ||
) | ||
|
||
r0_2 = numpyro.sample("r0_2", r0_2_dist) | ||
# r0_2 = numpyro.deterministic("r0_2", 2.5) | ||
r0_3 = numpyro.sample("r0_3", r0_3_dist) | ||
|
||
introduction_time_dist = dist.TransformedDistribution( | ||
dist.Beta(30, 70), dist.transforms.AffineTransform(0.0, 100) | ||
) | ||
introduction_scale_dist = dist.TransformedDistribution( | ||
dist.Beta(50, 50), dist.transforms.AffineTransform(5.0, 10.0) | ||
) | ||
introduction_time = numpyro.sample("INTRO_TIME", introduction_time_dist) | ||
introduction_perc = numpyro.sample("INTRO_PERC", dist.Beta(20, 980)) | ||
introduction_scale = numpyro.sample("INTRO_SCALE", introduction_scale_dist) | ||
|
||
# Very correlated with R0_3 (might be better fixed than estimated) | ||
imm = numpyro.sample("imm_factor", dist.Beta(700, 300)) | ||
# imm = numpyro.deterministic("imm_factor", 0.7) | ||
|
||
m.STRAIN_SPECIFIC_R0 = jnp.array([1.2, r0_2, r0_3]) | ||
m.INTRODUCTION_TIMES_SAMPLE = [introduction_time] | ||
m.INTRODUCTION_PERCENTAGE = introduction_perc | ||
m.INTRODUCTION_SCALE = introduction_scale | ||
m.CROSSIMMUNITY_MATRIX = jnp.array( | ||
[ | ||
[ | ||
0.0, # 000 | ||
1.0, # 001 | ||
1.0, # 010 | ||
1.0, # 011 | ||
1.0, # 100 | ||
1.0, # 101 | ||
1.0, # 110 | ||
1.0, # 111 | ||
], | ||
[0.0, imm, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], | ||
[ | ||
0.0, | ||
imm**2, | ||
imm, | ||
imm, | ||
1.0, | ||
1.0, | ||
1.0, | ||
1.0, | ||
], | ||
] | ||
) | ||
|
||
sol = m.run( | ||
seip_ode, | ||
tf=len(incidence), | ||
sample_dist_dict={ | ||
"INITIAL_INFECTIONS": dist.TransformedDistribution( | ||
dist.Beta(30, 970), | ||
dist.transforms.AffineTransform(0.0, model.POP_SIZE), | ||
) | ||
}, | ||
) | ||
model_incidence = jnp.sum(sol.ys[3], axis=4) | ||
model_incidence_0 = jnp.diff(model_incidence[:, :, 0, 0], axis=0) | ||
|
||
model_incidence_1 = jnp.sum(model_incidence, axis=(2, 3)) | ||
model_incidence_1 = jnp.diff(model_incidence_1, axis=0) | ||
model_incidence_1 -= model_incidence_0 | ||
|
||
with numpyro.plate("num_age", 4): | ||
ihr = numpyro.sample("ihr", dist.Beta(1, 9)) | ||
|
||
# IHR multiplier is very correlated with IHR (duh, might be better fixed) | ||
# ihr_mult = numpyro.sample("ihr_mult", dist.Beta(100, 900)) | ||
ihr_mult = numpyro.deterministic("ihr_mult", 0.15) | ||
|
||
sim_incidence = ( | ||
model_incidence_0 * ihr + model_incidence_1 * ihr * ihr_mult | ||
) | ||
|
||
numpyro.sample( | ||
"incidence", | ||
dist.Poisson(sim_incidence), | ||
obs=incidence, | ||
) |
Oops, something went wrong.