Skip to content

Commit

Permalink
Merge pull request #38 from cdcent/epoch-transitions
Browse files Browse the repository at this point in the history
Epoch transitions
  • Loading branch information
arik-shurygin authored Jan 29, 2024
2 parents cf7e04b + ac55d03 commit c0dc835
Show file tree
Hide file tree
Showing 5 changed files with 401 additions and 34 deletions.
23 changes: 13 additions & 10 deletions config/config_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ def __init__(self, **kwargs) -> None:
self.HOSP_PATH = "data/hospital_220213_220108.csv"
# model initialization date DO NOT CHANGE
self.INIT_DATE = datetime.date(2022, 2, 11)
# if running epochs, this value will be number of days after the INIT_DATE the current epoch begins
# 0 if you are initializing.
self.DAYS_AFTER_INIT_DATE = 0
self.MINIMUM_AGE = 0
# age limits for each age bin in the model, begining with minimum age
# values are exclusive in upper bound. so [0,18) means 0-17, 18+
Expand Down Expand Up @@ -128,6 +131,15 @@ def __init__(self, **kwargs) -> None:
self.MCMC_PROGRESS_BAR = True
self.MODEL_RAND_SEED = 8675309

# this are all the strains currently supported, historical and future
self.all_strains_supported = [
"wildtype",
"alpha",
"delta",
"omicron",
"BA2/BA5",
]

# now update all parameters from kwargs, overriding the defaults if they are explicitly set
self.__dict__.update(kwargs)
self.GIT_HASH = (
Expand All @@ -154,17 +166,8 @@ def __init__(self, **kwargs) -> None:
["W" + str(idx) for idx in range(self.NUM_WANING_COMPARTMENTS)],
start=0,
)

# this are all the strains currently supported, historical and future
all_strains = [
"wildtype",
"alpha",
"delta",
"omicron",
"BA2/BA5",
]
# it often does not make sense to differentiate between wildtype and alpha, so combine strains here
self.STRAIN_NAMES = all_strains[5 - self.NUM_STRAINS :]
self.STRAIN_NAMES = self.all_strains_supported[-self.NUM_STRAINS :]
self.STRAIN_NAMES[0] = "pre-" + self.STRAIN_NAMES[1]
# in each compartment that is strain stratified we use strain indexes to improve readability.
# omicron will always be index=2 if num_strains >= 3. In a two strain model we must combine alpha and delta together.
Expand Down
56 changes: 56 additions & 0 deletions config/config_epoch_2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import jax.numpy as jnp

from config.config_base import ConfigBase


class ConfigEpoch(ConfigBase):
"""
This is an example Config file for a particular scenario,
in which we want to test a 2 strain model with inital R0 of 1.5 for eachs train, and no vaccination.
Through inheritance this class will inherit all non-listed parameters from ConfigBase, and can even add its own!
"""

def __init__(self) -> None:
self.SCENARIO_NAME = "Epoch 2, omicron, BA2, XBB"
# set scenario parameters here
self.STRAIN_SPECIFIC_R0 = jnp.array([1.8, 3.0, 3.0]) # R0s
self.STRAIN_INTERACTIONS = jnp.array(
[
[1.0, 0.7, 0.49], # omicron
[0.7, 1.0, 0.7], # BA2
[0.49, 0.7, 1.0], # XBB
]
)
self.VAX_EFF_MATRIX = jnp.array(
[
[0, 0.34, 0.68], # omicron
[0, 0.24, 0.48], # BA2
[0, 0.14, 0.28], # XBB
]
)
self.all_strains_supported = [
"wildtype",
"alpha",
"delta",
"omicron",
"BA2/BA5",
"XBB1.5",
]
# specifies the number of days after the model INIT date this epoch occurs
self.DAYS_AFTER_INIT_DATE = 250
# DO NOT CHANGE THE FOLLOWING TWO LINES
super().__init__(**self.__dict__)
# Do not add any scenario parameters below, may create inconsistent state

def assert_valid_values(self):
"""
a function designed to be called after all parameters are initalized, does a series of reasonable checks
to ensure values are within expected ranges and no parameters directly contradict eachother.
Raises
----------
Assert Error:
if user supplies invalid parameters, short description will be provided as to why the parameter is wrong.
"""
super().assert_valid_values()
assert True, "any new parameters should be tested here"
93 changes: 86 additions & 7 deletions mechanistic_compartments.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def __init__(self, **kwargs):

# GENERATE CROSS IMMUNITY MATRIX with protection from STRAIN_INTERACTIONS most recent infected strain.
if self.CROSSIMMUNITY_MATRIX is None:
self.build_cross_immunity_matrix()
self.load_cross_immunity_matrix()
# if not given, load population fractions based on observed census data into self
if not self.INITIAL_POPULATION_FRACTIONS:
self.load_initial_population_fractions()
Expand Down Expand Up @@ -303,7 +303,7 @@ def vaccination_rate(self, t):
"""
return jnp.exp(
utils.VAX_FUNCTION(
t,
t + self.DAYS_AFTER_INIT_DATE,
self.VAX_MODEL_KNOT_LOCATIONS,
self.VAX_MODEL_BASE_EQUATIONS,
self.VAX_MODEL_KNOTS,
Expand Down Expand Up @@ -528,6 +528,7 @@ def run(
max_steps=int(1e6),
)
self.solution = solution
self.solution_final_state = tuple(y[-1] for y in solution.ys)
save_path = (
save_path if save else None
) # dont set a save path if we dont want to save
Expand Down Expand Up @@ -610,6 +611,7 @@ def plot_diffrax_solution(
plot_labels: list[str] = None,
save_path: str = None,
log_scale: bool = None,
start_date: datetime.date = None,
fig: plt.figure = None,
ax: plt.axis = None,
):
Expand All @@ -635,6 +637,8 @@ def plot_diffrax_solution(
log_scale : bool, optional
whether or not to exclusively show the log or unlogged version of the plot, by default include both
in a stacked subplot.
start_date : date, optional
the start date of the x axis of the plot. Defaults to model.INIT_DATE + model.DAYS_AFTER_INIT_DATE
fig: matplotlib.pyplot.figure
if this plot is part of a larger subplots, pass the figure object here, otherwise one is created
ax: matplotlib.pyplot.axis
Expand All @@ -645,6 +649,11 @@ def plot_diffrax_solution(
fig, ax : matplotlib.Figure/axis object
objects containing the matplotlib figure and axis for further modifications if needed.
"""
# default start date is based on the model INIT date and in the case of epochs, days after initialization
if start_date is None:
start_date = self.INIT_DATE + datetime.timedelta(
days=self.DAYS_AFTER_INIT_DATE
)
plot_commands = [x.strip() for x in plot_commands]
if fig is None or ax is None:
fig, ax = plt.subplots(
Expand Down Expand Up @@ -680,11 +689,8 @@ def plot_diffrax_solution(
# if we explicitly set plot_labels, override the default ones.
label = plot_labels[idx] if plot_labels is not None else label
days = list(range(len(timeline)))
# incidence is aggregated weekly, so our array increases 7 days at a time
# if command == "incidence":
# days = [day * 7 for day in days]
x_axis = [
self.INIT_DATE + datetime.timedelta(days=day) for day in days
start_date + datetime.timedelta(days=day) for day in days
]
if command == "incidence":
# plot both logged and unlogged version by default
Expand Down Expand Up @@ -1162,12 +1168,13 @@ def load_init_infection_infected_and_exposed_dist_via_abm(self):
if self.INITIAL_INFECTIONS is None:
self.INITIAL_INFECTIONS = self.POP_SIZE * proportion_infected

def build_cross_immunity_matrix(self):
def load_cross_immunity_matrix(self):
"""
Loads the Crossimmunity matrix given the strain interactions matrix.
Strain interactions matrix is a matrix of shape (num_strains, num_strains) representing the relative immune escape risk
of those who are being challenged by a strain in dim 0 but have recovered from a strain in dim 1.
Neither the strain interactions matrix nor the crossimmunity matrix take into account waning.
Updates
----------
self.CROSSIMMUNITY_MATRIX:
Expand Down Expand Up @@ -1368,6 +1375,78 @@ def default(self, obj):
else: # if given empty file, just return JSON string
return json.dumps(self.config_file, indent=4, cls=CustomEncoder)

def collapse_strains(
self,
from_strain: str,
to_strain: str,
new_config: config,
):
"""
Modifies `self` such that all infections, infection histories, and enums that refer to the strain in `from_strain`
now point to `to_strain`. Number of strains are preserved, shifting all strain indexes left by 1
to make space for this new most-recent strain. New config is loaded to update strain specific values and indexes.
Example
----------
self.STRAIN_IDX["delta"] -> 0
self.STRAIN_IDX["omicron"] -> 1
self.STRAIN_IDX["BA2/BA5"] -> 2
self.collapse_strains("omicron", "delta") #collapses omicron and delta strains
self.STRAIN_IDX["delta"] -> 0
self.STRAIN_IDX["omicron"] -> *0*
self.STRAIN_IDX["BA2/BA5"] -> *1*
self.STRAIN_IDX[_] -> *2*
Parameters
----------
from_strain: str
the strain name of the strain being collapsed, whos references will be rerouted.
to_strain: str
the strain name of the strain being joined with from_strain, typically the oldest strain.
Modifies
----------
self.INITIAL_STATE
all compartments within initial state will be modified with new initial states
whos infection histories line up with the collapse strains and the new state.
all parameters within `new_config` will be used to override parameters within `self`
"""
from_strain_idx = self.STRAIN_IDX[from_strain]
to_strain_idx = self.STRAIN_IDX[to_strain]
(
immune_state_converter,
strain_converter,
) = utils.combined_strains_mapping(
from_strain_idx,
to_strain_idx,
self.NUM_STRAINS,
)
return_state = []
for idx, compartment in enumerate(self.INITIAL_STATE):
# we dont have a strain axis if are in the S compartment, otherwise we do
strain_axis = idx != self.IDX.S
strain_combined_compartment = utils.combine_strains(
compartment,
immune_state_converter,
strain_converter,
self.NUM_STRAINS,
strain_axis=strain_axis,
)
return_state.append(strain_combined_compartment)

# people who are actively infected with `from_strain` need to be combined together as well
self.INITIAL_STATE = tuple(return_state)
self.config_file = new_config
# use the new config to update things like STRAIN_IDX enum and strain_interactions matrix.
self.__dict__.update(**new_config.__dict__)
# end with some minor update tasks because our init_date likely changed
# along with our strain_interactions matrix
self.load_cross_immunity_matrix()
self.load_vaccination_model()
self.load_external_i_distributions()
self.load_contact_matrix()


def build_basic_mechanistic_model(config: config):
"""
Expand Down
16 changes: 0 additions & 16 deletions model_odes/seip_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,22 +126,6 @@ def seip_ode(state, t, parameters):
# slice across age, strain, and wane. vaccination updates the vax column and also moves all to w0.
# ex: diagonal movement from 1 shot in 4th waning compartment to 2 shots 0 waning compartment s[:, 0, 1, 3] -> s[:, 0, 2, 0]
vax_counts = s * p.VACCINATION_RATES(t)[:, jnp.newaxis, :, jnp.newaxis]

# for vaccine_count in range(p.MAX_VAX_COUNT + 1):
# # num of people who had vaccine_count shots and then are getting 1 more
# s_vax_count = vax_counts[:, :, vaccine_count, :]
# # people who just got vaccinated/recovered wont get another shot for at least 1 waning compartment time.
# s_vax_count = s_vax_count.at[:, :, 0].set(0)
# # sum all the people getting vaccines, across waning bins since they will be put in w0
# vax_gained = jnp.sum(s_vax_count, axis=(-1))
# # if people already at the max counted vaccinations, dont move them, only update waning
# if vaccine_count == p.MAX_VAX_COUNT:
# ds = ds.at[:, :, vaccine_count, 0].add(vax_gained)
# else: # increment num_vaccines by 1, waning reset
# ds = ds.at[:, :, vaccine_count + 1, 0].add(vax_gained)
# # we moved everyone into their correct compartment, now remove them from their starting position
# ds = ds.at[:, :, vaccine_count, :].add(-s_vax_count)

vax_counts = vax_counts.at[:, :, :, 0].set(0)
vax_gained = jnp.sum(vax_counts, axis=(-1))
ds = ds.at[:, :, p.MAX_VAX_COUNT, 0].add(vax_gained[:, :, p.MAX_VAX_COUNT])
Expand Down
Loading

0 comments on commit c0dc835

Please sign in to comment.