Skip to content

Commit 43df4f7

Browse files
committed
factor out Parameters class
1 parent 663ad79 commit 43df4f7

2 files changed

Lines changed: 105 additions & 68 deletions

File tree

src/epidemik/EpiModel.py

Lines changed: 48 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import networkx as nx
1616
import numpy as np
1717
from numpy import linalg
18-
from numpy import random
1918
import scipy.integrate
2019
import pandas as pd
2120
import matplotlib.pyplot as plt
@@ -49,7 +48,7 @@ def __init__(self, compartments=None, seed=None, rng=None):
4948
self.population = None
5049
self.orig_comps = None
5150
self.demographics = False
52-
self.params = {}
51+
self.params = utils.Parameters()
5352

5453
if seed is None:
5554
seed = int(time.time()) + os.getpid()
@@ -62,24 +61,6 @@ def __init__(self, compartments=None, seed=None, rng=None):
6261
if compartments is not None:
6362
self.transitions.add_nodes_from([comp for comp in compartments])
6463

65-
def _compute_rate(self, rate: Union[str, float]) -> float:
66-
"""
67-
Compute the rate from a string
68-
69-
Parameters:
70-
- rate: string
71-
Rate of the transition
72-
73-
Returns:
74-
float
75-
The computed rate
76-
"""
77-
if rate in self.params:
78-
if isinstance(self.params[rate], float):
79-
return self.params[rate]
80-
else:
81-
return eval(self.params[rate], {}, self.params)
82-
8364
def add_interaction(
8465
self,
8566
source: str,
@@ -108,9 +89,9 @@ def add_interaction(
10889
if rate is not None:
10990
count = len(self.params) + 1
11091
rate_key = "rate" + str(count)
111-
self.define_parameters(rate_key=rate)
92+
self.params[rate_key]=rate
11293
else:
113-
self.define_parameters(**rates)
94+
self.params.define_parameters(**rates)
11495
rates = list(rates.keys())
11596
rate_key = rates[0]
11697

@@ -137,21 +118,22 @@ def add_spontaneous(
137118
if rate is not None:
138119
count = len(self.params) + 1
139120
rate_key = "rate" + str(count)
140-
self.define_parameters(rate_key=rate)
121+
self.params[rate_key, rate]
141122
else:
142-
self.define_parameters(**rates)
123+
self.params.define_parameters(**rates)
143124
rates = list(rates.keys())
144125
rate_key = rates[0]
145126

146127
self.transitions.add_edge(source, target, rate=rate_key)
147128

148-
def add_viral_generation(self,
149-
source:str,
150-
target:str,
151-
source_rate:Union[float, str, None] = None,
152-
target_rate:Union[float, str, None] = None,
153-
**rates
154-
) -> None:
129+
def add_viral_generation(
130+
self,
131+
source: str,
132+
target: str,
133+
source_rate: Union[float, str, None] = None,
134+
target_rate: Union[float, str, None] = None,
135+
**rates,
136+
) -> None:
155137
"""
156138
Add a viral generation transition
157139
@@ -169,20 +151,25 @@ def add_viral_generation(self,
169151
if source_rate is not None and target_rate is not None:
170152
count = len(self.params) + 1
171153
rate_key = "rate" + str(count)
172-
self.define_parameters(rate_key = source_rate)
154+
self.params[rate_key] = source_rate
173155

174-
rate_key = "rate" + str(count+1)
175-
self.define_parameters(**{rate_key: target_rate})
176-
else:
177-
self.define_parameters(**rates)
156+
rate_key = "rate" + str(count + 1)
157+
self.params[rate_key] = target_rate
158+
else:
159+
self.params.define_parameters(**rates)
178160
rates = list(rates.keys())
179161
source_rate = rates[0]
180162
target_rate = rates[1]
181163

182-
self.transitions.add_edge(source, target, source_rate=source_rate, target_rate=target_rate)
164+
self.transitions.add_edge(
165+
source, target, rate=source_rate
166+
)
167+
self.transitions.add_edge(
168+
source, target, rate=target_rate
169+
)
183170

184171
def add_birth_rate(
185-
self, comps: Union[List, None] = None, rate: Union[float, None] = None, **rates
172+
self, comps: Union[List, None] = None, rate: Union[float, None] = None, fixed=False, **rates
186173
) -> None:
187174
"""
188175
Add a birth rate to one or more compartments
@@ -212,6 +199,7 @@ def add_birth_rate(
212199
self.transitions.add_node(comp)
213200

214201
self.transitions.nodes[comp]["birth"] = rate_key
202+
self.transitions.nodes[comp]["fixed"] = fixed
215203

216204
def add_death_rate(
217205
self, comps: Union[List, None] = None, rate: Union[None, float] = None, **rates
@@ -246,7 +234,12 @@ def add_death_rate(
246234
self.transitions.nodes[comp]["death"] = rate_key
247235

248236
def add_vaccination(
249-
self, source: str, target: str, start: int, rate: Union[None, float, str], **rates
237+
self,
238+
source: str,
239+
target: str,
240+
start: int,
241+
rate: Union[None, float, str],
242+
**rates,
250243
) -> None:
251244
"""
252245
Add a vaccination transition between two compartments
@@ -275,27 +268,6 @@ def add_vaccination(
275268

276269
self.transitions.add_edge(source, target, rate=rate_key, start=start)
277270

278-
def define_parameters(self, **kwargs) -> None:
279-
"""
280-
Define one or more parameter for the model
281-
282-
Parameters:
283-
- kwargs: keyword arguments
284-
Named parameters for the model
285-
286-
Returns:
287-
None
288-
"""
289-
for key, value in kwargs.items():
290-
if isinstance(value, str):
291-
try:
292-
# Convert floats written as strings to float
293-
value = float(value.strip())
294-
except:
295-
pass
296-
297-
self.params[key] = value
298-
299271
def add_age_structure(self, matrix: List, population: List) -> List[List]:
300272
"""
301273
Add a vaccination transition between two compartments
@@ -384,7 +356,7 @@ def _new_cases(self, time: float, population: np.ndarray, pos: Dict) -> np.ndarr
384356
target = edge[1]
385357
trans = edge[2]
386358

387-
rate_val = self._compute_rate(trans["rate"]) #self.params[trans["rate"]]
359+
rate_val = self.params[trans["rate"]]
388360
rate = rate_val * population[pos[source]]
389361

390362
if "start" in trans and trans["start"] >= time:
@@ -412,11 +384,15 @@ def _new_cases(self, time: float, population: np.ndarray, pos: Dict) -> np.ndarr
412384
comp_id = pos[comp]
413385

414386
if "birth" in data:
415-
births = population[comp_id] * data["birth"]
387+
if "fixed" in data:
388+
births = self.params[data["birth"]]
389+
else:
390+
births = population[comp_id] * self.params[data["birth"]]
391+
416392
diff[comp_id] += births
417393

418394
if "death" in data:
419-
deaths = population[comp_id] * data["death"]
395+
deaths = population[comp_id] * self.params[data["death"]]
420396
diff[comp_id] -= deaths
421397

422398
return diff
@@ -543,7 +519,7 @@ def simulate(
543519
source = pos[comp]
544520
target = pos[node_j]
545521

546-
rate = self._compute_rate(data["rate"]) # self.params[data["rate"]]
522+
rate = self.params[data["rate"]] # self.params[data["rate"]]
547523

548524
if "start" in data and data["start"] >= t:
549525
continue
@@ -561,7 +537,7 @@ def simulate(
561537

562538
prob[source] = 1 - np.sum(prob)
563539

564-
delta = random.multinomial(pop[source], prob)
540+
delta = self.rng.multinomial(pop[source], prob)
565541
delta[source] = 0
566542

567543
changes = np.sum(delta)
@@ -580,11 +556,15 @@ def simulate(
580556
comp_id = pos[comp]
581557

582558
if "birth" in data:
583-
births = self.rng.binomial(pop[comp_id], self._compute_rate(data["birth"]))
559+
births = self.rng.binomial(
560+
pop[comp_id], self.params[data["birth"]]
561+
)
584562
new_pop[comp_id] += births
585563

586564
if "death" in data:
587-
deaths = self.rng.binomial(pop[comp_id], self._compute_rate(data["death"]))
565+
deaths = self.rng.binomial(
566+
pop[comp_id], self.params[data["death"]]
567+
)
588568
new_pop[comp_id] -= deaths
589569

590570
values.append(new_pop)
@@ -886,7 +866,7 @@ def R0(self) -> Union[float, None]:
886866

887867
try:
888868
for node_i, node_j, data in self.transitions.edges(data=True):
889-
rate = self._compute_rate(data["rate"]) # self.params[data["rate"]]
869+
rate = self.params[data["rate"]] # self.params[data["rate"]]
890870

891871
if "agent" in data:
892872
target = pos[node_j]

src/epidemik/utils.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
"NotInitialized",
1111
"get_cache_directory",
1212
"NotImplementedError",
13+
"Parameters",
1314
]
1415

1516

@@ -32,6 +33,62 @@ class NotImplementedError(Exception):
3233
EPI_COLORS["D"] = "#8b8b8b"
3334

3435

36+
class Parameters(dict):
37+
def __init__(self):
38+
super().__init__()
39+
self.globals = {"__builtins__": None}
40+
41+
def __setitem__(self, key, value):
42+
self.define_parameters(**{key: value})
43+
44+
def __getitem__(self, key):
45+
return self.compute_parameter(key)
46+
47+
def define_parameters(self, **kwargs) -> None:
48+
"""
49+
Define one or more parameter for the model
50+
51+
Parameters:
52+
- kwargs: keyword arguments
53+
Named parameters for the model
54+
55+
Returns:
56+
None
57+
"""
58+
for key, value in kwargs.items():
59+
if isinstance(value, str):
60+
try:
61+
# Convert floats written as strings to float
62+
value = float(value.strip())
63+
except:
64+
pass
65+
66+
super().__setitem__(key, value)
67+
68+
def compute_parameter(self, param: Union[str]) -> float:
69+
"""
70+
Compute the rate from a string
71+
72+
Parameters:
73+
- rate: string
74+
Rate of the transition
75+
76+
Returns:
77+
float
78+
The computed rate
79+
"""
80+
import logging
81+
82+
if param in self.keys():
83+
if isinstance(super().__getitem__(param), (int, float)):
84+
return super().__getitem__(param)
85+
else:
86+
try:
87+
return eval(super().__getitem__(param), self.globals, self)
88+
except Exception as e:
89+
logging.error(f"Error computing parameter {param}: {e}")
90+
return None
91+
3592
def get_cache_directory():
3693
"""
3794
Return the location of the cache directory for the current platform.

0 commit comments

Comments
 (0)