Skip to content

Commit 2ec8007

Browse files
committed
working viral model
1 parent b83d244 commit 2ec8007

6 files changed

Lines changed: 228 additions & 32 deletions

File tree

poetry.lock

Lines changed: 83 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ dependencies = [
2222
"tqdm>=4",
2323
"pyyaml (>=6.0.2,<7.0.0)",
2424
"seaborn (>=0.13.2,<0.14.0)",
25+
"scikit-learn (>=1.6.1,<2.0.0)",
2526
]
2627
[project.urls]
2728
Homepage = "https://github.com/DataForScience/epidemik"

src/epidemik/EpiModel.py

Lines changed: 54 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def add_interaction(
6767
target: str,
6868
agent: str,
6969
rate: Union[None, float, str] = None,
70+
norm=True,
7071
**rates,
7172
) -> None:
7273
"""
@@ -79,7 +80,11 @@ def add_interaction(
7980
Name of the target compartment
8081
- agent: string
8182
Name of the agent
82-
- params: string
83+
- rate: float, str, None
84+
Rate of the interaction
85+
- norm: bool
86+
Whether to normalize the transition rate or not
87+
- rates:
8388
Named parameters for the interaction
8489
8590
Returns:
@@ -95,7 +100,10 @@ def add_interaction(
95100
rates = list(rates.keys())
96101
rate_key = rates[0]
97102

98-
self.transitions.add_edge(source, target, agent=agent, rate=rate_key)
103+
if agent not in self.transitions.nodes:
104+
self.transitions.add_node(agent)
105+
106+
self.transitions.add_edge(source, target, agent=agent, rate=rate_key, norm=norm)
99107

100108
def add_spontaneous(
101109
self, source: str, target: str, rate: Union[None, float, str] = None, **rates
@@ -162,14 +170,19 @@ def add_viral_generation(
162170
target_rate = rates[1]
163171

164172
self.transitions.add_edge(
165-
source, target, rate=source_rate
173+
source, target, rate=source_rate, viral_source=True, viral_target=False,
166174
)
167175
self.transitions.add_edge(
168-
source, target, rate=target_rate
176+
source, target, rate=target_rate, viral_source=False, viral_target=True
169177
)
170178

171179
def add_birth_rate(
172-
self, comps: Union[List, None] = None, rate: Union[float, None] = None, fixed=False, **rates
180+
self,
181+
comps: Union[List, None] = None,
182+
rate: Union[float, None] = None,
183+
fixed=False,
184+
global_rate=True,
185+
**rates
173186
) -> None:
174187
"""
175188
Add a birth rate to one or more compartments
@@ -200,9 +213,13 @@ def add_birth_rate(
200213

201214
self.transitions.nodes[comp]["birth"] = rate_key
202215
self.transitions.nodes[comp]["fixed"] = fixed
216+
self.transitions.nodes[comp]["global"] = global_rate
203217

204218
def add_death_rate(
205-
self, comps: Union[List, None] = None, rate: Union[None, float] = None, **rates
219+
self,
220+
comps: Union[List, None] = None,
221+
rate: Union[None, float] = None,
222+
**rates
206223
) -> None:
207224
"""
208225
Add a birth rate to one or more compartments
@@ -327,10 +344,10 @@ def _new_cases(self, time: float, population: np.ndarray, pos: Dict) -> np.ndarr
327344
Internal function used by integration routine
328345
329346
Parameters:
330-
- population: numpy array
331-
Current population of each compartment
332347
- time: float
333348
Current time
349+
- population: numpy array
350+
Current population of each compartment
334351
- pos: dict
335352
Dictionary mapping compartment names to indices
336353
@@ -364,36 +381,45 @@ def _new_cases(self, time: float, population: np.ndarray, pos: Dict) -> np.ndarr
364381

365382
if "agent" in trans:
366383
agent = trans["agent"]
384+
rate *= population[pos[agent]]
367385

368-
if self.population is None:
369-
rate *= population[pos[agent]] / N
370-
else:
371-
rate *= population[pos[agent]] / N[agent]
372-
386+
if trans["norm"]:
387+
if self.population is None:
388+
rate /= N
389+
else:
390+
rate /= N[agent]
391+
373392
if self.seasonality is not None:
374393
curr_t = int(time) % 365
375394
season = float(self.seasonality[curr_t])
376395
rate *= season
377396

378-
diff[pos[source]] -= rate
379-
diff[pos[target]] += rate
397+
if "viral_source" not in trans or trans["viral_source"]:
398+
diff[pos[source]] -= rate
399+
# Make sure viral generations are asymetric
400+
if "viral_target" not in trans or trans["viral_target"]:
401+
diff[pos[target]] += rate
380402

381-
# Population dynamics
382-
if self.demographics:
383-
for comp, data in self.transitions.nodes(data=True):
384-
comp_id = pos[comp]
403+
# Population dynamics
404+
if self.demographics:
405+
for comp, data in self.transitions.nodes(data=True):
406+
comp_id = pos[comp]
385407

386-
if "birth" in data:
387-
if "fixed" in data:
388-
births = self.params[data["birth"]]
408+
if "birth" in data:
409+
if "fixed" in data and data["fixed"]:
410+
births = self.params[data["birth"]]
411+
else:
412+
if data["global"]:
413+
total_population = population.sum()
414+
births = total_population * self.params[data["birth"]]
389415
else:
390416
births = population[comp_id] * self.params[data["birth"]]
391-
392-
diff[comp_id] += births
393417

394-
if "death" in data:
395-
deaths = population[comp_id] * self.params[data["death"]]
396-
diff[comp_id] -= deaths
418+
diff[comp_id] += births
419+
420+
if "death" in data:
421+
deaths = population[comp_id] * self.params[data["death"]]
422+
diff[comp_id] -= deaths
397423

398424
return diff
399425

@@ -616,7 +642,7 @@ def integrate(
616642

617643
population[pos[comp_age]] = n[i]
618644

619-
time = np.arange(t_min, t_min + timesteps, 1)
645+
time = np.arange(t_min, t_min + timesteps)
620646

621647
self.seasonality = seasonality
622648
values = pd.DataFrame(

tests/test_EpiModel.py

Lines changed: 65 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,33 @@
11
import unittest
22
from epidemik import EpiModel
3-
import logging
4-
3+
from sklearn.linear_model import LinearRegression
4+
import numpy as np
55

66
class EpiModelTestCase(unittest.TestCase):
77
def setUp(self):
8-
self.SIR = EpiModel()
8+
99
self.beta = 0.3
1010
self.mu = 0.1
11+
self.birth = 0.3
12+
self.death = 0.3
13+
self.fixed_birth = 10
14+
15+
self.SIR = EpiModel()
1116
self.SIR.add_interaction("S", "I", "I", beta=self.beta)
1217
self.SIR.add_spontaneous("I", "R", mu=self.mu)
1318

19+
self.birth_test = EpiModel()
20+
self.birth_test.add_interaction("S", "I", "I", beta=self.beta)
21+
self.birth_test.add_birth_rate("S", b=self.birth)
22+
23+
self.fixed_birth_test = EpiModel()
24+
self.fixed_birth_test.add_interaction("S", "I", "I", beta=self.beta)
25+
self.fixed_birth_test.add_birth_rate("S", b=self.fixed_birth, fixed=True, global_rate=False)
26+
27+
self.death_test = EpiModel()
28+
self.death_test.add_interaction("S", "I", "I", beta=self.beta)
29+
self.death_test.add_death_rate(d=self.death)
30+
1431
def test_R0(self):
1532
self.assertEqual(self.SIR.R0(), 3.0, "incorrect R0")
1633

@@ -32,3 +49,48 @@ def test_edges(self):
3249
self.assertEqual(self.SIR.params["mu"], self.mu)
3350
self.assertEqual(edge[0], "I")
3451
self.assertEqual(edge[1], "R")
52+
53+
def test_birth(self):
54+
self.assertEqual(self.birth_test.transitions.nodes['S']["birth"], "b")
55+
self.assertIn("b", self.birth_test.params.keys())
56+
self.assertEqual(self.birth_test.params["b"], self.birth)
57+
58+
def test_birth_rate(self):
59+
self.birth_test.integrate(10, S=990, I=10)
60+
values = self.birth_test.values_
61+
values['total'] = values.sum(axis=1)
62+
values = values.reset_index()
63+
64+
lm = LinearRegression()
65+
lm.fit(values['index'].values.reshape(-1, 1), np.log(values['total']))
66+
67+
self.assertAlmostEqual(lm.coef_[0], self.birth, delta=0.01)
68+
69+
def test_fixed_birth_rate(self):
70+
self.fixed_birth_test.integrate(10, S=990, I=10)
71+
values = self.fixed_birth_test.values_
72+
values['total'] = values.sum(axis=1)
73+
values = values.reset_index()
74+
75+
lm = LinearRegression()
76+
lm.fit(values['index'].values.reshape(-1, 1), values['total'])
77+
78+
self.assertAlmostEqual(lm.coef_[0], 10, delta=0.01)
79+
80+
def test_death(self):
81+
self.assertEqual(self.death_test.transitions.nodes['S']["death"], "d")
82+
self.assertEqual(self.death_test.transitions.nodes['I']["death"], "d")
83+
84+
self.assertIn("d", self.death_test.params.keys())
85+
self.assertEqual(self.death_test.params["d"], self.death)
86+
87+
def test_death_rate(self):
88+
self.death_test.integrate(10, S=990, I=10)
89+
values = self.death_test.values_
90+
values['total'] = values.sum(axis=1)
91+
values = values.reset_index()
92+
93+
lm = LinearRegression()
94+
lm.fit(values['index'].values.reshape(-1, 1), np.log(values['total']))
95+
96+
self.assertAlmostEqual(lm.coef_[0], -self.death, delta=0.01)

tests/test_MetaEpiModel.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,5 @@ def test_travel(self):
4646
def test_integrate(self):
4747
with self.assertRaises(NotImplementedError) as _:
4848
self.SIR.integrate()
49+
50+

tests/test_NetworkEpiModel.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import unittest
2+
from epidemik import EpiModel, NetworkEpiModel
3+
import networkx as nx
4+
5+
class NetworkEpiModelTestCase(unittest.TestCase):
6+
def setUp(self):
7+
self.N = 300
8+
self.G_full = nx.erdos_renyi_graph(self.N, p=1.)
9+
self.beta = 0.05
10+
self.SI_full = NetworkEpiModel(self.G_full)
11+
12+
def test_named_parameters(self):
13+
self.SI_full.add_interaction("S", "I", "I", beta=self.beta)
14+
self.assertIn("beta",
15+
self.SI_full.params,
16+
"The parameter beta should be in the params dictionary")
17+
18+
def test_unnamed_parameters(self):
19+
self.SI_full.add_interaction("S", "I", "I", self.beta)
20+
self.assertIn("rate1",
21+
self.SI_full.params,
22+
"The parameter rate1 should be in the params dictionary")
23+

0 commit comments

Comments
 (0)