fixing merge conflicts with Bens branch
arik-shurygin committed Jan 29, 2024
2 parents 98fa2bf + cf7e04b commit 24a52d9
Showing 6 changed files with 338 additions and 33 deletions.
6 changes: 3 additions & 3 deletions R/hhs_hospitalization_formatter.R
@@ -105,7 +105,7 @@ dat_agegroup <- bind_rows(dat_pedia, dat_adult_grouped) |>
mutate(incidence = new_admission / population * 1e5) # per 100000

dat_agegroup_subset <- dat_agegroup |>
filter(date >= ymd("2022-02-13"), date <= ymd("2022-10-29")) |>
filter(date >= ymd("2022-02-20"), date <= ymd("2022-10-29")) |>
mutate(
week = epiweek(date)
) |>
@@ -127,7 +127,7 @@ dat_agegroup_weekly <- dat_agegroup_subset |>

# Plot
dat_agegroup |>
filter(date >= ymd("2022-02-12"), date <= ymd("2022-10-29")) |>
filter(date >= ymd("2022-02-20"), date <= ymd("2022-11-05")) |>
ggplot() +
geom_line(aes(x = date, y = incidence, colour = agegroup)) +
theme_bw()
@@ -144,4 +144,4 @@ dat_agegroup_weekly |>
# Output
dat_agegroup_subset |>
select(-population) |>
data.table::fwrite("./data/hospital_220213_220108.csv")
data.table::fwrite("./data/hospitalization-data/hospital_220220_221105.csv")
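The output filename in this hunk encodes a date window as two `yymmdd` stamps. A minimal sketch of deriving such a name from the same dates that drive a filter, so the two cannot drift apart, is shown below. It is written in Python rather than R, and `output_name` and `incidence_per_100k` are hypothetical helpers for illustration, not functions from the repository.

```python
from datetime import date


def output_name(start: date, end: date, prefix: str = "hospital") -> str:
    """Build a name like hospital_220220_221105.csv from a date window.

    Hypothetical helper: the R script in this commit hardcodes the name.
    """
    if start > end:
        raise ValueError("start must not be after end")
    return f"{prefix}_{start:%y%m%d}_{end:%y%m%d}.csv"


def incidence_per_100k(new_admissions: float, population: float) -> float:
    """Admissions per 100,000 people, mirroring new_admission / population * 1e5."""
    return new_admissions / population * 1e5


print(output_name(date(2022, 2, 20), date(2022, 11, 5)))
print(incidence_per_100k(50, 1_000_000))
```

Note that in the diff the subset written to disk is filtered up to 2022-10-29, while the filename stamp (221105) matches the end of the plotting window; generating the name from the filter dates would make any such difference visible.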
21 changes: 13 additions & 8 deletions README.md
@@ -9,11 +9,16 @@

This repository is for the design and implementation of a Scenarios forecasting model, built by the Scenarios team within CFA-Predict.

This code aims to combine a number of different codebases to forecast different covid scenarios with a Compartmental Mechanistic ODE model modeling multiple competing covid variants. The aim of this model is to provide enough flexibility for its users to explore a variety of scenarios, but also making certain design decisions that allow for fast computation and fitting as well as code readability.
Currently, we aim to use this code to forecast different disease transmission scenarios with a compartmental mechanistic ODE model. This model is under development with transmission of SARS-CoV-2 as our primary focus. We plan to apply this model to the transmission of other respiratory viruses such as influenza and RSV. We aim to give users of the code enough flexibility to explore a variety of scenarios, while making design decisions that allow for fast computation and fitting as well as code readability.

[//]: # (This code aims to combine a number of different codebases to forecast different covid scenarios with a Compartmental Mechanistic ODE model modeling multiple competing covid variants. The aim of this model is to provide enough flexibility for its users to explore a variety of scenarios, but also making certain design decisions that allow for fast computation and fitting as well as code readability.)


What this model is:

a compartmental mechanistic ODE model capible of dynamic age binning, waning, vaccination scenarios, introduction of new variants, transmission structures, and timing estimation. TODO
a compartmental mechanistic ODE model that accounts for age structure, immunity history, vaccination, immunity waning and multiple variants.

[//]: # (capable of dynamic age binning, waning, vaccination scenarios, introduction of new variants, transmission structures, and timing estimation. TODO)

What this model is not:

@@ -33,11 +38,11 @@ In order to run this model and get basic results follow these steps:

Here is an example script of a basic run of 100 days without inference of parameters, saving the simulation as an image to output/example.png:
```
from model_odes.seir_model_v5 import seirw_ode
from model_odes.seip_model import seip_model
from mechanistic_compartments import build_basic_mechanistic_model
from config.config_base import ConfigBase
solution = build_basic_mechanistic_model(ConfigBase()).run(seirw_ode, tf=100.0, show=True, save=True, save_path="output/example.png")
solution = build_basic_mechanistic_model(ConfigBase()).run(seip_model, tf=100.0, show=True, save=True, save_path="output/example.png")
```

To create your own scenario, and modify parameters such as strain R0 and vaccination rate follow these steps:
@@ -49,19 +54,19 @@ To create your own scenario, and modify parameters such as strain R0 and vaccina
5. Run almost the same script as above, replacing your ConfigBase import with ConfigScenario.

```
from model_odes.seir_model_v5 import seirw_ode
from model_odes.seip_model import seip_model
from mechanistic_compartments import build_basic_mechanistic_model
from config.config_scenario_example import ConfigScenario
solution = build_basic_mechanistic_model(ConfigScenario()).run(seirw_ode, tf=100.0, show=True, save=True, save_path="output/example_scenario.png")
solution = build_basic_mechanistic_model(ConfigScenario()).run(seip_model, tf=100.0, show=True, save=True, save_path="output/example_scenario.png")
```
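
Step 5 swaps `ConfigBase` for a scenario config. The sketch below illustrates one way such an override could work, using stand-in classes. `ConfigBaseStandIn`, `ConfigScenarioStandIn`, and every default value are assumptions made for illustration and do not reproduce the repository's `config` module; only the keyword-override style (for example `POP_SIZE=...`) is taken from the fitting script in this commit.

```python
class ConfigBaseStandIn:
    """Stand-in for a base config: defaults that keyword arguments may override."""

    def __init__(self, **overrides):
        # Illustrative defaults only; not the repository's real values.
        self.POP_SIZE = 20000
        self.INFECTIOUS_PERIOD = 5.0
        self.STRAIN_SPECIFIC_R0 = [1.5, 1.5, 1.5]
        for name, value in overrides.items():
            # Reject misspelled field names instead of silently adding them.
            if not hasattr(self, name):
                raise AttributeError(f"unknown config field: {name}")
            setattr(self, name, value)


class ConfigScenarioStandIn(ConfigBaseStandIn):
    """A scenario changes a few fields and inherits everything else."""

    def __init__(self, **overrides):
        scenario = {"STRAIN_SPECIFIC_R0": [1.2, 2.5, 3.0]}
        scenario.update(overrides)  # caller-supplied values win
        super().__init__(**scenario)


cfg = ConfigScenarioStandIn(POP_SIZE=3.28e8)
print(cfg.POP_SIZE, cfg.INFECTIOUS_PERIOD, cfg.STRAIN_SPECIFIC_R0)
```

Keeping scenario-specific values in a subclass means the run script only needs a different import, which is the workflow the steps above describe.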

Before you run your own experiments, it is best to understand how the model is initialized. Rather than looking through the model matrices yourself, you can use a Shiny application, created by the Scenarios team, that visualizes the model's initial state.
Simply run `visualizer_app.py` and navigate to http://localhost:8000/ and play with the data yourself.

## Data Sources

The model is fed the following data sources:
The model (as described in the example script) is fed the following data sources:
1. data/demographic-data/contact_matricies : contact matrices sourced from Dina Mistry's past work in this [GitHub project](https://github.com/mobs-lab/mixing-patterns).
2. data/serological-data/* : serology data sourced from: [data.cdc.gov](https://data.cdc.gov/Laboratory-Surveillance/Nationwide-Commercial-Laboratory-Seroprevalence-Su/d2tw-32xv)
3. data/sim_data_*.sqlite: ABM data sourced from Tom Hladish's work found [here](https://github.com/tjhladish/covid-abm)
@@ -81,7 +86,7 @@ Thomas Hladish, Lead Data Scientist, [email protected], CDC/IOD/ORR/CFA

Ariel Shurygin, Data Scientist, [email protected], CDC/IOD/ORR/CFA

Ben Kok Toh, Data Scientist, [email protected], CDC/IOD/ORR/CFA
Kok Ben Toh, Data Scientist, [email protected], CDC/IOD/ORR/CFA

Michael Batista, Data Scientist, [email protected], CDC/IOD/ORR/CFA

198 changes: 198 additions & 0 deletions exp/fitting_us_20230213_20231029.py
@@ -0,0 +1,198 @@
# %%
import copy

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import numpyro
import pandas as pd
from cycler import cycler
from inference import infer_model
from jax.random import PRNGKey
from numpyro.infer import MCMC, NUTS

from config.config_base import ConfigBase
from mechanistic_compartments import build_basic_mechanistic_model
from model_odes.seip_model import seip_ode

# Use 5 host devices, one per MCMC chain (num_chains=5 below)
numpyro.set_host_device_count(5)
jax.config.update("jax_enable_x64", True)
# %%
# Observations
obs_df = pd.read_csv("./data/hospitalization-data/hospital_220220_221105.csv")
obs_incidence = obs_df.groupby(["date"])["new_admission_7"].apply(np.array)
obs_incidence = jnp.array(obs_incidence.tolist())
obs_incidence = obs_incidence[0:200,]

fig, ax = plt.subplots(1)
ax.plot(np.asarray(obs_incidence), label=["0-17", "18-49", "50-64", "65+"])
fig.legend()
ax.set_title("Observed data")
plt.show()

# %%
# Config to US population sizes
pop = 3.28e8
init_inf_prop = 0.03
intro_perc = 0.02
cb = ConfigBase(
POP_SIZE=pop,
INFECTIOUS_PERIOD=7.0,
INITIAL_INFECTIONS=init_inf_prop * pop,
INTRODUCTION_PERCENTAGE=intro_perc,
INTRODUCTION_SCALE=10,
NUM_WANING_COMPARTMENTS=5,
)
model = build_basic_mechanistic_model(cb)
model.VAX_EFF_MATRIX = jnp.array(
[
[0, 0.29, 0.58], # delta
[0, 0.24, 0.48], # omicron1
[0, 0.19, 0.38], # BA1.1
]
)


# %%
# MCMC specifications for "cold run"
nuts = NUTS(
infer_model,
dense_mass=True,
max_tree_depth=7,
init_strategy=numpyro.infer.init_to_median(),
target_accept_prob=0.80,
# find_heuristic_step_size=True,
)
mcmc = MCMC(
nuts,
num_warmup=500,
num_samples=500,
num_chains=5,
progress_bar=True,
)

# %%
# mcmc.warmup(
# rng_key=PRNGKey(8811967),
# collect_warmup=True,
# incidence=obs_incidence,
# model=model,
# )

mcmc.run(
rng_key=PRNGKey(8811968),
incidence=obs_incidence,
model=model,
)

# %%
samp = mcmc.get_samples(group_by_chain=True)
fig, axs = plt.subplots(2, 2)
axs[0, 0].set_title("Intro Time")
axs[0, 0].plot(np.transpose(samp["INTRO_TIME"]))
axs[0, 1].set_title("Intro Percentage")
axs[0, 1].plot(np.transpose(samp["INTRO_PERC"]))
axs[1, 0].set_title("R02")
axs[1, 0].plot(np.transpose(samp["r0_2"]))
axs[1, 1].set_title("R03")
axs[1, 1].plot(np.transpose(samp["r0_3"]), label=range(1, 6))
fig.legend()
plt.show()

# %%
# Take median of runs and check if fit is good
fitted_medians = {k: jnp.median(v[:, -1], axis=0) for k, v in samp.items()}
cb1 = ConfigBase(
POP_SIZE=pop,
INFECTIOUS_PERIOD=7.0,
INITIAL_INFECTIONS=fitted_medians["INITIAL_INFECTIONS"],
INTRODUCTION_PERCENTAGE=fitted_medians["INTRO_PERC"],
INTRODUCTION_TIMES=[fitted_medians["INTRO_TIME"]],
INTRODUCTION_SCALE=fitted_medians["INTRO_SCALE"],
NUM_WANING_COMPARTMENTS=5,
)
cb1.STRAIN_SPECIFIC_R0 = jnp.append(
jnp.array([1.2]),
jnp.append(fitted_medians["r0_2"], fitted_medians["r0_3"]),
)
model1 = build_basic_mechanistic_model(cb1)
imm = fitted_medians["imm_factor"]
ihr1 = fitted_medians["ihr"]
ihr_mult = fitted_medians["ihr_mult"]
model1.CROSSIMMUNITY_MATRIX = jnp.array(
[
[
0.0, # 000
1.0, # 001
1.0, # 010
1.0, # 011
1.0, # 100
1.0, # 101
1.0, # 110
1.0, # 111
],
[0.0, imm, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
[
0.0,
imm**2,
imm,
imm,
1.0,
1.0,
1.0,
1.0,
],
]
)
model1.VAX_EFF_MATRIX = jnp.array(
[
[0, 0.29, 0.58], # delta
[0, 0.24, 0.48], # omicron1
[0, 0.19, 0.38], # BA1.1
]
)

solution1 = model1.run(
seip_ode,
tf=250,
show=True,
save=False,
# plot_commands=["S[:, 0, :, :]", "BA1.1", "omicron", "delta"],
log_scale=True,
)

model_incidence = jnp.sum(solution1.ys[3], axis=4)
model_incidence_0 = jnp.diff(model_incidence[:, :, 0, 0], axis=0)

model_incidence_1 = jnp.sum(model_incidence, axis=(2, 3))
model_incidence_1 = jnp.diff(model_incidence_1, axis=0)
model_incidence_1 -= model_incidence_0

sim_incidence = model_incidence_0 * ihr1 + model_incidence_1 * ihr1 * ihr_mult

colors = ["#e41a1c", "#377eb8", "#4daf4a", "#984ea3"]
fig, ax = plt.subplots(1)
ax.set_prop_cycle(cycler(color=colors))
ax.plot(sim_incidence, label=["0-17", "18-49", "50-64", "65+"])
ax.plot(
obs_incidence,
label=["0-17 (obs)", "18-49 (obs)", "50-64 (obs)", "65+ (obs)"],
linestyle="dashed",
)
fig.legend()
ax.set_title("Observed vs fitted")
plt.show()

# %%
# Check correlation matrix of the posterior samples
samp_all = copy.deepcopy(mcmc.get_samples(group_by_chain=False))
ihr_dict = {
"ihr_" + str(i): v for i, v in enumerate(jnp.transpose(samp_all["ihr"]))
}
del samp_all["ihr"]
samp_all.update(ihr_dict)

samp_df = pd.DataFrame(samp_all)
samp_df.corr()
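
The observed-versus-fitted cell in this script turns a cumulative compartment into per-step incidence with `jnp.diff`, then weights two infection groups by different hospitalization ratios. A plain-Python sketch of that arithmetic for a single age group follows. The function names and toy numbers are made up for illustration, and treating the `[0, 0]` slice as "group 0" and everything else as "the rest" is a simplification of the model's state layout, not a description of it.

```python
def diff(cumulative):
    """Successive differences along time, like jnp.diff on a 1-D series."""
    return [b - a for a, b in zip(cumulative, cumulative[1:])]


def simulated_admissions(cum_group0, cum_all, ihr, ihr_mult):
    """Per-step admissions from two cumulative infection series.

    Group 0 is hospitalized at rate ihr; all remaining infections are
    hospitalized at the reduced rate ihr * ihr_mult.
    """
    inc_group0 = diff(cum_group0)
    # Incidence outside group 0 is total incidence minus group-0 incidence.
    inc_rest = [t - g for t, g in zip(diff(cum_all), inc_group0)]
    return [g * ihr + r * ihr * ihr_mult for g, r in zip(inc_group0, inc_rest)]


# Toy cumulative counts at three time points (hypothetical values).
cum_group0 = [0.0, 100.0, 250.0]
cum_all = [0.0, 300.0, 850.0]
admissions = simulated_admissions(cum_group0, cum_all, ihr=0.1, ihr_mult=0.5)
print(admissions)
```

Differencing shortens the series by one element, so three cumulative time points yield two incidence values; any comparison against observations has to account for that offset.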
102 changes: 102 additions & 0 deletions exp/inference.py
@@ -0,0 +1,102 @@
import copy

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

from mechanistic_compartments import BasicMechanisticModel
from model_odes.seip_model import seip_ode


def infer_model(incidence, model: BasicMechanisticModel):
    m = copy.deepcopy(model)

    # Parameters
    r0_2_dist = dist.TransformedDistribution(
        dist.Beta(50, 750), dist.transforms.AffineTransform(2.0, 8.0)
    )
    r0_3_dist = dist.TransformedDistribution(
        dist.Beta(2, 14), dist.transforms.AffineTransform(2.0, 8.0)
    )

    r0_2 = numpyro.sample("r0_2", r0_2_dist)
    # r0_2 = numpyro.deterministic("r0_2", 2.5)
    r0_3 = numpyro.sample("r0_3", r0_3_dist)

    introduction_time_dist = dist.TransformedDistribution(
        dist.Beta(30, 70), dist.transforms.AffineTransform(0.0, 100)
    )
    introduction_scale_dist = dist.TransformedDistribution(
        dist.Beta(50, 50), dist.transforms.AffineTransform(5.0, 10.0)
    )
    introduction_time = numpyro.sample("INTRO_TIME", introduction_time_dist)
    introduction_perc = numpyro.sample("INTRO_PERC", dist.Beta(20, 980))
    introduction_scale = numpyro.sample("INTRO_SCALE", introduction_scale_dist)

    # Very correlated with R0_3 (might be better fixed than estimated)
    imm = numpyro.sample("imm_factor", dist.Beta(700, 300))
    # imm = numpyro.deterministic("imm_factor", 0.7)

    m.STRAIN_SPECIFIC_R0 = jnp.array([1.2, r0_2, r0_3])
    m.INTRODUCTION_TIMES_SAMPLE = [introduction_time]
    m.INTRODUCTION_PERCENTAGE = introduction_perc
    m.INTRODUCTION_SCALE = introduction_scale
    m.CROSSIMMUNITY_MATRIX = jnp.array(
        [
            [
                0.0,  # 000
                1.0,  # 001
                1.0,  # 010
                1.0,  # 011
                1.0,  # 100
                1.0,  # 101
                1.0,  # 110
                1.0,  # 111
            ],
            [0.0, imm, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
            [
                0.0,
                imm**2,
                imm,
                imm,
                1.0,
                1.0,
                1.0,
                1.0,
            ],
        ]
    )

    sol = m.run(
        seip_ode,
        tf=len(incidence),
        sample_dist_dict={
            "INITIAL_INFECTIONS": dist.TransformedDistribution(
                dist.Beta(30, 970),
                dist.transforms.AffineTransform(0.0, model.POP_SIZE),
            )
        },
    )
    model_incidence = jnp.sum(sol.ys[3], axis=4)
    model_incidence_0 = jnp.diff(model_incidence[:, :, 0, 0], axis=0)

    model_incidence_1 = jnp.sum(model_incidence, axis=(2, 3))
    model_incidence_1 = jnp.diff(model_incidence_1, axis=0)
    model_incidence_1 -= model_incidence_0

    with numpyro.plate("num_age", 4):
        ihr = numpyro.sample("ihr", dist.Beta(1, 9))

    # IHR multiplier is very correlated with IHR (duh, might be better fixed)
    # ihr_mult = numpyro.sample("ihr_mult", dist.Beta(100, 900))
    ihr_mult = numpyro.deterministic("ihr_mult", 0.15)

    sim_incidence = (
        model_incidence_0 * ihr + model_incidence_1 * ihr * ihr_mult
    )

    numpyro.sample(
        "incidence",
        dist.Poisson(sim_incidence),
        obs=incidence,
    )
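
Several priors in `infer_model` are Beta distributions passed through `AffineTransform(loc, scale)`, which maps a draw x in [0, 1] to loc + scale * x. The sketch below, using only the standard library, works out the support and mean that each of those choices implies. `scaled_beta_summary` is a hypothetical helper, and the numbers come from the prior parameters in the code, not from any fitted result.

```python
def scaled_beta_summary(a, b, loc=0.0, scale=1.0):
    """Support and mean of loc + scale * X, where X ~ Beta(a, b)."""
    mean = loc + scale * a / (a + b)  # E[X] = a / (a + b) for a Beta
    return {"low": loc, "high": loc + scale, "mean": mean}


# Parameters copied from the priors in infer_model.
priors = {
    "r0_2": scaled_beta_summary(50, 750, 2.0, 8.0),
    "r0_3": scaled_beta_summary(2, 14, 2.0, 8.0),
    "INTRO_TIME": scaled_beta_summary(30, 70, 0.0, 100),
    "INTRO_SCALE": scaled_beta_summary(50, 50, 5.0, 10.0),
    "INTRO_PERC": scaled_beta_summary(20, 980),
    "imm_factor": scaled_beta_summary(700, 300),
}
for name, summary in priors.items():
    print(name, summary)
```

Under these parameters the prior means for `r0_2` and `imm_factor` are 2.5 and 0.7, which coincide with the fixed values in the commented-out `numpyro.deterministic` lines. Both R0 priors are confined to the interval from 2 to 10.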