@@ -180,8 +180,8 @@ def add_birth_rate(
180180 self ,
181181 comps : Union [List , None ] = None ,
182182 rate : Union [float , None ] = None ,
183- fixed = False ,
184- global_rate = True ,
183+ fixed : bool = False ,
184+ global_rate : bool = True ,
185185 ** rates
186186 ) -> None :
187187 """
@@ -212,13 +212,15 @@ def add_birth_rate(
212212 self .transitions .add_node (comp )
213213
214214 self .transitions .nodes [comp ]["birth" ] = rate_key
215- self .transitions .nodes [comp ]["fixed " ] = fixed
216- self .transitions .nodes [comp ]["global " ] = global_rate
215+ self .transitions .nodes [comp ]["fixed_birth " ] = fixed
216+ self .transitions .nodes [comp ]["global_birth " ] = global_rate
217217
218218 def add_death_rate (
219219 self ,
220220 comps : Union [List , None ] = None ,
221221 rate : Union [None , float ] = None ,
222+ fixed : bool = False ,
223+ global_rate : bool = False ,
222224 ** rates
223225 ) -> None :
224226 """
@@ -249,6 +251,8 @@ def add_death_rate(
249251 self .transitions .add_node (comp )
250252
251253 self .transitions .nodes [comp ]["death" ] = rate_key
254+ self .transitions .nodes [comp ]["fixed_death" ] = fixed
255+ self .transitions .nodes [comp ]["global_death" ] = global_rate
252256
253257 def add_vaccination (
254258 self ,
@@ -406,10 +410,10 @@ def _new_cases(self, time: float, population: np.ndarray, pos: Dict) -> np.ndarr
406410 comp_id = pos [comp ]
407411
408412 if "birth" in data :
409- if "fixed " in data and data ["fixed " ]:
413+ if "fixed_birth " in data and data ["fixed_birth " ]:
410414 births = self .params [data ["birth" ]]
411415 else :
412- if data ["global " ]:
416+ if data ["global_birth " ]:
413417 total_population = population .sum ()
414418 births = total_population * self .params [data ["birth" ]]
415419 else :
@@ -418,9 +422,18 @@ def _new_cases(self, time: float, population: np.ndarray, pos: Dict) -> np.ndarr
418422 diff [comp_id ] += births
419423
420424 if "death" in data :
421- deaths = population [comp_id ] * self .params [data ["death" ]]
425+ if "fixed_death" in data and data ["fixed_death" ]:
426+ deaths = self .params [data ["death" ]]
427+ else :
428+ if data ["global_death" ]:
429+ total_population = population .sum ()
430+ deaths = total_population * self .params [data ["death" ]]
431+ else :
432+ deaths = population [comp_id ] * self .params [data ["death" ]]
433+
422434 diff [comp_id ] -= deaths
423435
436+
424437 return diff
425438
426439 def plot (
0 commit comments