@@ -67,6 +67,7 @@ def add_interaction(
6767 target : str ,
6868 agent : str ,
6969 rate : Union [None , float , str ] = None ,
70+ norm = True ,
7071 ** rates ,
7172 ) -> None :
7273 """
@@ -79,7 +80,11 @@ def add_interaction(
7980 Name of the target compartment
8081 - agent: string
8182 Name of the agent
82- - params: string
83+ - rate: float, str, None
84+ Rate of the interaction
85+ - norm: bool
86+ Whether to normalize the transition rate or not
87+ - rates:
8388 Named parameters for the interaction
8489
8590 Returns:
@@ -95,7 +100,10 @@ def add_interaction(
95100 rates = list (rates .keys ())
96101 rate_key = rates [0 ]
97102
98- self .transitions .add_edge (source , target , agent = agent , rate = rate_key )
103+ if agent not in self .transitions .nodes :
104+ self .transitions .add_node (agent )
105+
106+ self .transitions .add_edge (source , target , agent = agent , rate = rate_key , norm = norm )
99107
100108 def add_spontaneous (
101109 self , source : str , target : str , rate : Union [None , float , str ] = None , ** rates
@@ -162,14 +170,19 @@ def add_viral_generation(
162170 target_rate = rates [1 ]
163171
164172 self .transitions .add_edge (
165- source , target , rate = source_rate
173+ source , target , rate = source_rate , viral_source = True , viral_target = False ,
166174 )
167175 self .transitions .add_edge (
168- source , target , rate = target_rate
176+ source , target , rate = target_rate , viral_source = False , viral_target = True
169177 )
170178
171179 def add_birth_rate (
172- self , comps : Union [List , None ] = None , rate : Union [float , None ] = None , fixed = False , ** rates
180+ self ,
181+ comps : Union [List , None ] = None ,
182+ rate : Union [float , None ] = None ,
183+ fixed = False ,
184+ global_rate = True ,
185+ ** rates
173186 ) -> None :
174187 """
175188 Add a birth rate to one or more compartments
@@ -200,9 +213,13 @@ def add_birth_rate(
200213
201214 self .transitions .nodes [comp ]["birth" ] = rate_key
202215 self .transitions .nodes [comp ]["fixed" ] = fixed
216+ self .transitions .nodes [comp ]["global" ] = global_rate
203217
204218 def add_death_rate (
205- self , comps : Union [List , None ] = None , rate : Union [None , float ] = None , ** rates
219+ self ,
220+ comps : Union [List , None ] = None ,
221+ rate : Union [None , float ] = None ,
222+ ** rates
206223 ) -> None :
207224 """
208225 Add a birth rate to one or more compartments
@@ -327,10 +344,10 @@ def _new_cases(self, time: float, population: np.ndarray, pos: Dict) -> np.ndarr
327344 Internal function used by integration routine
328345
329346 Parameters:
330- - population: numpy array
331- Current population of each compartment
332347 - time: float
333348 Current time
349+ - population: numpy array
350+ Current population of each compartment
334351 - pos: dict
335352 Dictionary mapping compartment names to indices
336353
@@ -364,36 +381,45 @@ def _new_cases(self, time: float, population: np.ndarray, pos: Dict) -> np.ndarr
364381
365382 if "agent" in trans :
366383 agent = trans ["agent" ]
384+ rate *= population [pos [agent ]]
367385
368- if self .population is None :
369- rate *= population [pos [agent ]] / N
370- else :
371- rate *= population [pos [agent ]] / N [agent ]
372-
386+ if trans ["norm" ]:
387+ if self .population is None :
388+ rate /= N
389+ else :
390+ rate /= N [agent ]
391+
373392 if self .seasonality is not None :
374393 curr_t = int (time ) % 365
375394 season = float (self .seasonality [curr_t ])
376395 rate *= season
377396
378- diff [pos [source ]] -= rate
379- diff [pos [target ]] += rate
397+ if "viral_source" not in trans or trans ["viral_source" ]:
398+ diff [pos [source ]] -= rate
399+ # Make sure viral generations are asymetric
400+ if "viral_target" not in trans or trans ["viral_target" ]:
401+ diff [pos [target ]] += rate
380402
381- # Population dynamics
382- if self .demographics :
383- for comp , data in self .transitions .nodes (data = True ):
384- comp_id = pos [comp ]
403+ # Population dynamics
404+ if self .demographics :
405+ for comp , data in self .transitions .nodes (data = True ):
406+ comp_id = pos [comp ]
385407
386- if "birth" in data :
387- if "fixed" in data :
388- births = self .params [data ["birth" ]]
408+ if "birth" in data :
409+ if "fixed" in data and data ["fixed" ]:
410+ births = self .params [data ["birth" ]]
411+ else :
412+ if data ["global" ]:
413+ total_population = population .sum ()
414+ births = total_population * self .params [data ["birth" ]]
389415 else :
390416 births = population [comp_id ] * self .params [data ["birth" ]]
391-
392- diff [comp_id ] += births
393417
394- if "death" in data :
395- deaths = population [comp_id ] * self .params [data ["death" ]]
396- diff [comp_id ] -= deaths
418+ diff [comp_id ] += births
419+
420+ if "death" in data :
421+ deaths = population [comp_id ] * self .params [data ["death" ]]
422+ diff [comp_id ] -= deaths
397423
398424 return diff
399425
@@ -616,7 +642,7 @@ def integrate(
616642
617643 population [pos [comp_age ]] = n [i ]
618644
619- time = np .arange (t_min , t_min + timesteps , 1 )
645+ time = np .arange (t_min , t_min + timesteps )
620646
621647 self .seasonality = seasonality
622648 values = pd .DataFrame (
0 commit comments