Skip to content

Commit 3b7f586

Browse files
committed
fix demographics
1 parent 2ec8007 commit 3b7f586

1 file changed

Lines changed: 20 additions & 7 deletions

File tree

src/epidemik/EpiModel.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -180,8 +180,8 @@ def add_birth_rate(
180180
self,
181181
comps: Union[List, None] = None,
182182
rate: Union[float, None] = None,
183-
fixed=False,
184-
global_rate=True,
183+
fixed: bool = False,
184+
global_rate: bool = True,
185185
**rates
186186
) -> None:
187187
"""
@@ -212,13 +212,15 @@ def add_birth_rate(
212212
self.transitions.add_node(comp)
213213

214214
self.transitions.nodes[comp]["birth"] = rate_key
215-
self.transitions.nodes[comp]["fixed"] = fixed
216-
self.transitions.nodes[comp]["global"] = global_rate
215+
self.transitions.nodes[comp]["fixed_birth"] = fixed
216+
self.transitions.nodes[comp]["global_birth"] = global_rate
217217

218218
def add_death_rate(
219219
self,
220220
comps: Union[List, None] = None,
221221
rate: Union[None, float] = None,
222+
fixed: bool = False,
223+
global_rate: bool = False,
222224
**rates
223225
) -> None:
224226
"""
@@ -249,6 +251,8 @@ def add_death_rate(
249251
self.transitions.add_node(comp)
250252

251253
self.transitions.nodes[comp]["death"] = rate_key
254+
self.transitions.nodes[comp]["fixed_death"] = fixed
255+
self.transitions.nodes[comp]["global_death"] = global_rate
252256

253257
def add_vaccination(
254258
self,
@@ -406,10 +410,10 @@ def _new_cases(self, time: float, population: np.ndarray, pos: Dict) -> np.ndarr
406410
comp_id = pos[comp]
407411

408412
if "birth" in data:
409-
if "fixed" in data and data["fixed"]:
413+
if "fixed_birth" in data and data["fixed_birth"]:
410414
births = self.params[data["birth"]]
411415
else:
412-
if data["global"]:
416+
if data["global_birth"]:
413417
total_population = population.sum()
414418
births = total_population * self.params[data["birth"]]
415419
else:
@@ -418,9 +422,18 @@ def _new_cases(self, time: float, population: np.ndarray, pos: Dict) -> np.ndarr
418422
diff[comp_id] += births
419423

420424
if "death" in data:
421-
deaths = population[comp_id] * self.params[data["death"]]
425+
if "fixed_death" in data and data["fixed_death"]:
426+
deaths = self.params[data["death"]]
427+
else:
428+
if data["global_death"]:
429+
total_population = population.sum()
430+
deaths = total_population * self.params[data["death"]]
431+
else:
432+
deaths = population[comp_id] * self.params[data["death"]]
433+
422434
diff[comp_id] -= deaths
423435

436+
424437
return diff
425438

426439
def plot(

0 commit comments

Comments
 (0)