1515import networkx as nx
1616import numpy as np
1717from numpy import linalg
18- from numpy import random
1918import scipy .integrate
2019import pandas as pd
2120import matplotlib .pyplot as plt
@@ -49,7 +48,7 @@ def __init__(self, compartments=None, seed=None, rng=None):
4948 self .population = None
5049 self .orig_comps = None
5150 self .demographics = False
52- self .params = {}
51+ self .params = utils . Parameters ()
5352
5453 if seed is None :
5554 seed = int (time .time ()) + os .getpid ()
@@ -62,24 +61,6 @@ def __init__(self, compartments=None, seed=None, rng=None):
6261 if compartments is not None :
6362 self .transitions .add_nodes_from ([comp for comp in compartments ])
6463
65- def _compute_rate (self , rate : Union [str , float ]) -> float :
66- """
67- Compute the rate from a string
68-
69- Parameters:
70- - rate: string
71- Rate of the transition
72-
73- Returns:
74- float
75- The computed rate
76- """
77- if rate in self .params :
78- if isinstance (self .params [rate ], float ):
79- return self .params [rate ]
80- else :
81- return eval (self .params [rate ], {}, self .params )
82-
8364 def add_interaction (
8465 self ,
8566 source : str ,
@@ -108,9 +89,9 @@ def add_interaction(
10889 if rate is not None :
10990 count = len (self .params ) + 1
11091 rate_key = "rate" + str (count )
111- self .define_parameters ( rate_key = rate )
92+ self .params [ rate_key ] = rate
11293 else :
113- self .define_parameters (** rates )
94+ self .params . define_parameters (** rates )
11495 rates = list (rates .keys ())
11596 rate_key = rates [0 ]
11697
@@ -137,21 +118,22 @@ def add_spontaneous(
137118 if rate is not None :
138119 count = len (self .params ) + 1
139120 rate_key = "rate" + str (count )
140- self .define_parameters ( rate_key = rate )
121+ self .params [ rate_key , rate ]
141122 else :
142- self .define_parameters (** rates )
123+ self .params . define_parameters (** rates )
143124 rates = list (rates .keys ())
144125 rate_key = rates [0 ]
145126
146127 self .transitions .add_edge (source , target , rate = rate_key )
147128
148- def add_viral_generation (self ,
149- source :str ,
150- target :str ,
151- source_rate :Union [float , str , None ] = None ,
152- target_rate :Union [float , str , None ] = None ,
153- ** rates
154- ) -> None :
129+ def add_viral_generation (
130+ self ,
131+ source : str ,
132+ target : str ,
133+ source_rate : Union [float , str , None ] = None ,
134+ target_rate : Union [float , str , None ] = None ,
135+ ** rates ,
136+ ) -> None :
155137 """
156138 Add a viral generation transition
157139
@@ -169,20 +151,25 @@ def add_viral_generation(self,
169151 if source_rate is not None and target_rate is not None :
170152 count = len (self .params ) + 1
171153 rate_key = "rate" + str (count )
172- self .define_parameters ( rate_key = source_rate )
154+ self .params [ rate_key ] = source_rate
173155
174- rate_key = "rate" + str (count + 1 )
175- self .define_parameters ( ** { rate_key : target_rate })
176- else :
177- self .define_parameters (** rates )
156+ rate_key = "rate" + str (count + 1 )
157+ self .params [ rate_key ] = target_rate
158+ else :
159+ self .params . define_parameters (** rates )
178160 rates = list (rates .keys ())
179161 source_rate = rates [0 ]
180162 target_rate = rates [1 ]
181163
182- self .transitions .add_edge (source , target , source_rate = source_rate , target_rate = target_rate )
164+ self .transitions .add_edge (
165+ source , target , rate = source_rate
166+ )
167+ self .transitions .add_edge (
168+ source , target , rate = target_rate
169+ )
183170
184171 def add_birth_rate (
185- self , comps : Union [List , None ] = None , rate : Union [float , None ] = None , ** rates
172+ self , comps : Union [List , None ] = None , rate : Union [float , None ] = None , fixed = False , ** rates
186173 ) -> None :
187174 """
188175 Add a birth rate to one or more compartments
@@ -212,6 +199,7 @@ def add_birth_rate(
212199 self .transitions .add_node (comp )
213200
214201 self .transitions .nodes [comp ]["birth" ] = rate_key
202+ self .transitions .nodes [comp ]["fixed" ] = fixed
215203
216204 def add_death_rate (
217205 self , comps : Union [List , None ] = None , rate : Union [None , float ] = None , ** rates
@@ -246,7 +234,12 @@ def add_death_rate(
246234 self .transitions .nodes [comp ]["death" ] = rate_key
247235
248236 def add_vaccination (
249- self , source : str , target : str , start : int , rate : Union [None , float , str ], ** rates
237+ self ,
238+ source : str ,
239+ target : str ,
240+ start : int ,
241+ rate : Union [None , float , str ],
242+ ** rates ,
250243 ) -> None :
251244 """
252245 Add a vaccination transition between two compartments
@@ -275,27 +268,6 @@ def add_vaccination(
275268
276269 self .transitions .add_edge (source , target , rate = rate_key , start = start )
277270
278- def define_parameters (self , ** kwargs ) -> None :
279- """
280- Define one or more parameter for the model
281-
282- Parameters:
283- - kwargs: keyword arguments
284- Named parameters for the model
285-
286- Returns:
287- None
288- """
289- for key , value in kwargs .items ():
290- if isinstance (value , str ):
291- try :
292- # Convert floats written as strings to float
293- value = float (value .strip ())
294- except :
295- pass
296-
297- self .params [key ] = value
298-
299271 def add_age_structure (self , matrix : List , population : List ) -> List [List ]:
300272 """
301273 Add a vaccination transition between two compartments
@@ -384,7 +356,7 @@ def _new_cases(self, time: float, population: np.ndarray, pos: Dict) -> np.ndarr
384356 target = edge [1 ]
385357 trans = edge [2 ]
386358
387- rate_val = self ._compute_rate ( trans [ "rate" ]) #self. params[trans["rate"]]
359+ rate_val = self .params [trans ["rate" ]]
388360 rate = rate_val * population [pos [source ]]
389361
390362 if "start" in trans and trans ["start" ] >= time :
@@ -412,11 +384,15 @@ def _new_cases(self, time: float, population: np.ndarray, pos: Dict) -> np.ndarr
412384 comp_id = pos [comp ]
413385
414386 if "birth" in data :
415- births = population [comp_id ] * data ["birth" ]
387+ if "fixed" in data :
388+ births = self .params [data ["birth" ]]
389+ else :
390+ births = population [comp_id ] * self .params [data ["birth" ]]
391+
416392 diff [comp_id ] += births
417393
418394 if "death" in data :
419- deaths = population [comp_id ] * data ["death" ]
395+ deaths = population [comp_id ] * self . params [ data ["death" ] ]
420396 diff [comp_id ] -= deaths
421397
422398 return diff
@@ -543,7 +519,7 @@ def simulate(
543519 source = pos [comp ]
544520 target = pos [node_j ]
545521
546- rate = self ._compute_rate ( data ["rate" ]) # self.params[data["rate"]]
522+ rate = self .params [ data ["rate" ]] # self.params[data["rate"]]
547523
548524 if "start" in data and data ["start" ] >= t :
549525 continue
@@ -561,7 +537,7 @@ def simulate(
561537
562538 prob [source ] = 1 - np .sum (prob )
563539
564- delta = random .multinomial (pop [source ], prob )
540+ delta = self . rng .multinomial (pop [source ], prob )
565541 delta [source ] = 0
566542
567543 changes = np .sum (delta )
@@ -580,11 +556,15 @@ def simulate(
580556 comp_id = pos [comp ]
581557
582558 if "birth" in data :
583- births = self .rng .binomial (pop [comp_id ], self ._compute_rate (data ["birth" ]))
559+ births = self .rng .binomial (
560+ pop [comp_id ], self .params [data ["birth" ]]
561+ )
584562 new_pop [comp_id ] += births
585563
586564 if "death" in data :
587- deaths = self .rng .binomial (pop [comp_id ], self ._compute_rate (data ["death" ]))
565+ deaths = self .rng .binomial (
566+ pop [comp_id ], self .params [data ["death" ]]
567+ )
588568 new_pop [comp_id ] -= deaths
589569
590570 values .append (new_pop )
@@ -886,7 +866,7 @@ def R0(self) -> Union[float, None]:
886866
887867 try :
888868 for node_i , node_j , data in self .transitions .edges (data = True ):
889- rate = self ._compute_rate ( data ["rate" ]) # self.params[data["rate"]]
869+ rate = self .params [ data ["rate" ]] # self.params[data["rate"]]
890870
891871 if "agent" in data :
892872 target = pos [node_j ]
0 commit comments