Skip to content

Commit 663ad79

Browse files
committed
evaluate parameters
1 parent a1f6b62 commit 663ad79

1 file changed

Lines changed: 170 additions & 20 deletions

File tree

src/epidemik/EpiModel.py

Lines changed: 170 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,32 @@ def __init__(self, compartments=None, seed=None, rng=None):
6262
if compartments is not None:
6363
self.transitions.add_nodes_from([comp for comp in compartments])
6464

65-
def add_interaction(self, source: str, target: str, agent: str, **rates) -> None:
65+
def _compute_rate(self, rate: Union[str, float]) -> float:
66+
"""
67+
Compute the rate from a string
68+
69+
Parameters:
70+
- rate: string
71+
Rate of the transition
72+
73+
Returns:
74+
float
75+
The computed rate
76+
"""
77+
if rate in self.params:
78+
if isinstance(self.params[rate], float):
79+
return self.params[rate]
80+
else:
81+
return eval(self.params[rate], {}, self.params)
82+
83+
def add_interaction(
84+
self,
85+
source: str,
86+
target: str,
87+
agent: str,
88+
rate: Union[None, float, str] = None,
89+
**rates,
90+
) -> None:
6691
"""
6792
Add an interaction between two compartments
6893
@@ -80,12 +105,20 @@ def add_interaction(self, source: str, target: str, agent: str, **rates) -> None
80105
None
81106
"""
82107

83-
self.params.update(rates)
84-
rates = list(rates.keys())
108+
if rate is not None:
109+
count = len(self.params) + 1
110+
rate_key = "rate" + str(count)
111+
self.define_parameters(rate_key=rate)
112+
else:
113+
self.define_parameters(**rates)
114+
rates = list(rates.keys())
115+
rate_key = rates[0]
85116

86-
self.transitions.add_edge(source, target, agent=agent, rate=rates[0])
117+
self.transitions.add_edge(source, target, agent=agent, rate=rate_key)
87118

88-
def add_spontaneous(self, source: str, target: str, **rates) -> None:
119+
def add_spontaneous(
120+
self, source: str, target: str, rate: Union[None, float, str] = None, **rates
121+
) -> None:
89122
"""
90123
Add a spontaneous transition between two compartments
91124
@@ -101,12 +134,56 @@ def add_spontaneous(self, source: str, target: str, **rates) -> None:
101134
None
102135
"""
103136

104-
self.params.update(rates)
105-
rates = list(rates.keys())
137+
if rate is not None:
138+
count = len(self.params) + 1
139+
rate_key = "rate" + str(count)
140+
self.define_parameters(rate_key=rate)
141+
else:
142+
self.define_parameters(**rates)
143+
rates = list(rates.keys())
144+
rate_key = rates[0]
145+
146+
self.transitions.add_edge(source, target, rate=rate_key)
147+
148+
def add_viral_generation(self,
149+
source:str,
150+
target:str,
151+
source_rate:Union[float, str, None] = None,
152+
target_rate:Union[float, str, None] = None,
153+
**rates
154+
) -> None:
155+
"""
156+
Add a viral generation transition
106157
107-
self.transitions.add_edge(source, target, rate=rates[0])
158+
Parameters:
159+
- source: string
160+
Name of the source compartment
161+
- target: string
162+
Name of the target compartment
163+
- source_rate: float
164+
Rate of destruction of infected cells
165+
- target_rate: float
166+
Rate of creation of viral particles
167+
"""
108168

109-
def add_birth_rate(self, rate: float, comps: Union[List, None] = None) -> None:
169+
if source_rate is not None and target_rate is not None:
170+
count = len(self.params) + 1
171+
rate_key = "rate" + str(count)
172+
self.define_parameters(rate_key = source_rate)
173+
174+
rate_key = "rate" + str(count+1)
175+
self.define_parameters(**{rate_key: target_rate})
176+
else:
177+
self.define_parameters(**rates)
178+
rates = list(rates.keys())
179+
source_rate = rates[0]
180+
target_rate = rates[1]
181+
182+
self.transitions.add_edge(source, target, source_rate=source_rate, target_rate=target_rate)
183+
184+
def add_birth_rate(
185+
self, comps: Union[List, None] = None, rate: Union[float, None] = None, **rates
186+
) -> None:
110187
"""
111188
Add a birth rate to one or more compartments
112189
@@ -119,13 +196,26 @@ def add_birth_rate(self, rate: float, comps: Union[List, None] = None) -> None:
119196
"""
120197
self.demographics = True
121198

199+
if rate is not None:
200+
count = len(self.params) + 1
201+
rate_key = "rate" + str(count)
202+
self.params[rate_key] = rate
203+
else:
204+
self.params.update(rates)
205+
rate_key = list(rates.keys())[0]
206+
122207
if comps is None:
123208
comps = self.transitions.nodes()
124209

125210
for comp in comps:
126-
self.transitions.nodes[comp]["birth"] = rate
211+
if comp not in self.transitions.nodes:
212+
self.transitions.add_node(comp)
127213

128-
def add_death_rate(self, rate: float, comps: Union[List, None] = None) -> None:
214+
self.transitions.nodes[comp]["birth"] = rate_key
215+
216+
def add_death_rate(
217+
self, comps: Union[List, None] = None, rate: Union[None, float] = None, **rates
218+
) -> None:
129219
"""
130220
Add a birth rate to one or more compartments
131221
@@ -138,14 +228,25 @@ def add_death_rate(self, rate: float, comps: Union[List, None] = None) -> None:
138228
"""
139229
self.demographics = True
140230

231+
if rate is not None:
232+
count = len(self.params) + 1
233+
rate_key = "rate" + str(count)
234+
self.params[rate_key] = rate
235+
else:
236+
self.params.update(rates)
237+
rate_key = list(rates.keys())[0]
238+
141239
if comps is None:
142240
comps = self.transitions.nodes()
143241

144242
for comp in comps:
145-
self.transitions.nodes[comp]["death"] = rate
243+
if comp not in self.transitions.nodes:
244+
self.transitions.add_node(comp)
245+
246+
self.transitions.nodes[comp]["death"] = rate_key
146247

147248
def add_vaccination(
148-
self, source: str, target: str, rate: float, start: int
249+
self, source: str, target: str, start: int, rate: Union[None, float, str], **rates
149250
) -> None:
150251
"""
151252
Add a vaccination transition between two compartments
@@ -163,7 +264,37 @@ def add_vaccination(
163264
Returns:
164265
None
165266
"""
166-
self.transitions.add_edge(source, target, rate=rate, start=start)
267+
268+
if rate is not None:
269+
count = len(self.params) + 1
270+
rate_key = "rate" + str(count)
271+
self.params[rate_key] = rate
272+
else:
273+
self.params.update(rates)
274+
rate_key = list(rates.keys())[0]
275+
276+
self.transitions.add_edge(source, target, rate=rate_key, start=start)
277+
278+
def define_parameters(self, **kwargs) -> None:
279+
"""
280+
Define one or more parameter for the model
281+
282+
Parameters:
283+
- kwargs: keyword arguments
284+
Named parameters for the model
285+
286+
Returns:
287+
None
288+
"""
289+
for key, value in kwargs.items():
290+
if isinstance(value, str):
291+
try:
292+
# Convert floats written as strings to float
293+
value = float(value.strip())
294+
except:
295+
pass
296+
297+
self.params[key] = value
167298

168299
def add_age_structure(self, matrix: List, population: List) -> List[List]:
169300
"""
@@ -253,7 +384,7 @@ def _new_cases(self, time: float, population: np.ndarray, pos: Dict) -> np.ndarr
253384
target = edge[1]
254385
trans = edge[2]
255386

256-
rate_val = self.params[trans["rate"]]
387+
rate_val = self._compute_rate(trans["rate"]) #self.params[trans["rate"]]
257388
rate = rate_val * population[pos[source]]
258389

259390
if "start" in trans and trans["start"] >= time:
@@ -412,7 +543,7 @@ def simulate(
412543
source = pos[comp]
413544
target = pos[node_j]
414545

415-
rate = self.params[data["rate"]]
546+
rate = self._compute_rate(data["rate"]) # self.params[data["rate"]]
416547

417548
if "start" in data and data["start"] >= t:
418549
continue
@@ -449,11 +580,11 @@ def simulate(
449580
comp_id = pos[comp]
450581

451582
if "birth" in data:
452-
births = self.rng.binomial(pop[comp_id], data["birth"])
583+
births = self.rng.binomial(pop[comp_id], self._compute_rate(data["birth"]))
453584
new_pop[comp_id] += births
454585

455586
if "death" in data:
456-
deaths = self.rng.binomial(pop[comp_id], data["death"])
587+
deaths = self.rng.binomial(pop[comp_id], self._compute_rate(data["death"]))
457588
new_pop[comp_id] -= deaths
458589

459590
values.append(new_pop)
@@ -561,17 +692,19 @@ def __repr__(self) -> str:
561692

562693
text += "Parameters:\n"
563694
for rate, value in self.params.items():
564-
text += " %s : %f\n" % (rate, value)
695+
text += " %s : %s\n" % (rate, value)
565696
text += "\n\nTransitions:\n"
566697

567698
for edge in self.transitions.edges(data=True):
568699
source = edge[0]
569700
target = edge[1]
570701
trans = edge[2]
571702

703+
# Interaction
572704
if "agent" in trans:
573705
agent = trans["agent"]
574706
text += " - %s + %s = %s %s\n" % (source, agent, target, trans["rate"])
707+
# Vaccination
575708
elif "start" in trans:
576709
start = trans["start"]
577710
text += " - %s -> %s %s starting at %s days\n" % (
@@ -580,9 +713,26 @@ def __repr__(self) -> str:
580713
rate,
581714
start,
582715
)
716+
# Viral transition
717+
elif "source_rate" in trans:
718+
text += " - %s => %s %s %s" % (
719+
source,
720+
target,
721+
trans["source_rate"],
722+
trans["target_rate"],
723+
)
724+
# Spontaneous
583725
else:
584726
text += " - %s -> %s %s\n" % (source, target, rate)
585727

728+
if self.demographics:
729+
text += "\n\nDemographics:\n"
730+
for comp, data in self.transitions.nodes(data=True):
731+
if "birth" in data:
732+
text += " - -> %s: %s # birth rate\n" % (comp, data["birth"])
733+
if "death" in data:
734+
text += " - %s ->: %s # death rate\n" % (comp, data["death"])
735+
586736
R0 = self.R0()
587737

588738
if R0 is not None:
@@ -736,7 +886,7 @@ def R0(self) -> Union[float, None]:
736886

737887
try:
738888
for node_i, node_j, data in self.transitions.edges(data=True):
739-
rate = self.params[data["rate"]]
889+
rate = self._compute_rate(data["rate"]) # self.params[data["rate"]]
740890

741891
if "agent" in data:
742892
target = pos[node_j]

0 commit comments

Comments
 (0)