@@ -62,7 +62,32 @@ def __init__(self, compartments=None, seed=None, rng=None):
6262 if compartments is not None :
6363 self .transitions .add_nodes_from ([comp for comp in compartments ])
6464
65- def add_interaction (self , source : str , target : str , agent : str , ** rates ) -> None :
65+ def _compute_rate (self , rate : Union [str , float ]) -> float :
66+ """
67+ Compute the rate from a string
68+
69+ Parameters:
70+ - rate: string
71+ Rate of the transition
72+
73+ Returns:
74+ float
75+ The computed rate
76+ """
77+ if rate in self .params :
78+ if isinstance (self .params [rate ], float ):
79+ return self .params [rate ]
80+ else :
81+ return eval (self .params [rate ], {}, self .params )
82+
83+ def add_interaction (
84+ self ,
85+ source : str ,
86+ target : str ,
87+ agent : str ,
88+ rate : Union [None , float , str ] = None ,
89+ ** rates ,
90+ ) -> None :
6691 """
6792 Add an interaction between two compartments
6893
@@ -80,12 +105,20 @@ def add_interaction(self, source: str, target: str, agent: str, **rates) -> None
80105 None
81106 """
82107
83- self .params .update (rates )
84- rates = list (rates .keys ())
108+ if rate is not None :
109+ count = len (self .params ) + 1
110+ rate_key = "rate" + str (count )
111+ self .define_parameters (rate_key = rate )
112+ else :
113+ self .define_parameters (** rates )
114+ rates = list (rates .keys ())
115+ rate_key = rates [0 ]
85116
86- self .transitions .add_edge (source , target , agent = agent , rate = rates [ 0 ] )
117+ self .transitions .add_edge (source , target , agent = agent , rate = rate_key )
87118
88- def add_spontaneous (self , source : str , target : str , ** rates ) -> None :
119+ def add_spontaneous (
120+ self , source : str , target : str , rate : Union [None , float , str ] = None , ** rates
121+ ) -> None :
89122 """
90123 Add a spontaneous transition between two compartments
91124
@@ -101,12 +134,56 @@ def add_spontaneous(self, source: str, target: str, **rates) -> None:
101134 None
102135 """
103136
104- self .params .update (rates )
105- rates = list (rates .keys ())
137+ if rate is not None :
138+ count = len (self .params ) + 1
139+ rate_key = "rate" + str (count )
140+ self .define_parameters (rate_key = rate )
141+ else :
142+ self .define_parameters (** rates )
143+ rates = list (rates .keys ())
144+ rate_key = rates [0 ]
145+
146+ self .transitions .add_edge (source , target , rate = rate_key )
147+
148+ def add_viral_generation (self ,
149+ source :str ,
150+ target :str ,
151+ source_rate :Union [float , str , None ] = None ,
152+ target_rate :Union [float , str , None ] = None ,
153+ ** rates
154+ ) -> None :
155+ """
156+ Add a viral generation transition
106157
107- self .transitions .add_edge (source , target , rate = rates [0 ])
158+ Parameters:
159+ - source: string
160+ Name of the source compartment
161+ - target: string
162+ Name of the target compartment
163+ - source_rate: float
164+ Rate of destruction of infected cells
165+ - target_rate: float
166+ Rate of creation of viral particles
167+ """
108168
109- def add_birth_rate (self , rate : float , comps : Union [List , None ] = None ) -> None :
169+ if source_rate is not None and target_rate is not None :
170+ count = len (self .params ) + 1
171+ rate_key = "rate" + str (count )
172+ self .define_parameters (rate_key = source_rate )
173+
174+ rate_key = "rate" + str (count + 1 )
175+ self .define_parameters (** {rate_key : target_rate })
176+ else :
177+ self .define_parameters (** rates )
178+ rates = list (rates .keys ())
179+ source_rate = rates [0 ]
180+ target_rate = rates [1 ]
181+
182+ self .transitions .add_edge (source , target , source_rate = source_rate , target_rate = target_rate )
183+
184+ def add_birth_rate (
185+ self , comps : Union [List , None ] = None , rate : Union [float , None ] = None , ** rates
186+ ) -> None :
110187 """
111188 Add a birth rate to one or more compartments
112189
@@ -119,13 +196,26 @@ def add_birth_rate(self, rate: float, comps: Union[List, None] = None) -> None:
119196 """
120197 self .demographics = True
121198
199+ if rate is not None :
200+ count = len (self .params ) + 1
201+ rate_key = "rate" + str (count )
202+ self .params [rate_key ] = rate
203+ else :
204+ self .params .update (rates )
205+ rate_key = list (rates .keys ())[0 ]
206+
122207 if comps is None :
123208 comps = self .transitions .nodes ()
124209
125210 for comp in comps :
126- self .transitions .nodes [comp ]["birth" ] = rate
211+ if comp not in self .transitions .nodes :
212+ self .transitions .add_node (comp )
127213
128- def add_death_rate (self , rate : float , comps : Union [List , None ] = None ) -> None :
214+ self .transitions .nodes [comp ]["birth" ] = rate_key
215+
216+ def add_death_rate (
217+ self , comps : Union [List , None ] = None , rate : Union [None , float ] = None , ** rates
218+ ) -> None :
129219 """
130220 Add a birth rate to one or more compartments
131221
@@ -138,14 +228,25 @@ def add_death_rate(self, rate: float, comps: Union[List, None] = None) -> None:
138228 """
139229 self .demographics = True
140230
231+ if rate is not None :
232+ count = len (self .params ) + 1
233+ rate_key = "rate" + str (count )
234+ self .params [rate_key ] = rate
235+ else :
236+ self .params .update (rates )
237+ rate_key = list (rates .keys ())[0 ]
238+
141239 if comps is None :
142240 comps = self .transitions .nodes ()
143241
144242 for comp in comps :
145- self .transitions .nodes [comp ]["death" ] = rate
243+ if comp not in self .transitions .nodes :
244+ self .transitions .add_node (comp )
245+
246+ self .transitions .nodes [comp ]["death" ] = rate_key
146247
147248 def add_vaccination (
148- self , source : str , target : str , rate : float , start : int
249+ self , source : str , target : str , start : int , rate : Union [ None , float , str ], ** rates
149250 ) -> None :
150251 """
151252 Add a vaccination transition between two compartments
@@ -163,7 +264,37 @@ def add_vaccination(
163264 Returns:
164265 None
165266 """
166- self .transitions .add_edge (source , target , rate = rate , start = start )
267+
268+ if rate is not None :
269+ count = len (self .params ) + 1
270+ rate_key = "rate" + str (count )
271+ self .params [rate_key ] = rate
272+ else :
273+ self .params .update (rates )
274+ rate_key = list (rates .keys ())[0 ]
275+
276+ self .transitions .add_edge (source , target , rate = rate_key , start = start )
277+
278+ def define_parameters (self , ** kwargs ) -> None :
279+ """
280+ Define one or more parameter for the model
281+
282+ Parameters:
283+ - kwargs: keyword arguments
284+ Named parameters for the model
285+
286+ Returns:
287+ None
288+ """
289+ for key , value in kwargs .items ():
290+ if isinstance (value , str ):
291+ try :
292+ # Convert floats written as strings to float
293+ value = float (value .strip ())
294+ except :
295+ pass
296+
297+ self .params [key ] = value
167298
168299 def add_age_structure (self , matrix : List , population : List ) -> List [List ]:
169300 """
@@ -253,7 +384,7 @@ def _new_cases(self, time: float, population: np.ndarray, pos: Dict) -> np.ndarr
253384 target = edge [1 ]
254385 trans = edge [2 ]
255386
256- rate_val = self .params [trans ["rate" ]]
387+ rate_val = self ._compute_rate ( trans [ "rate" ]) #self. params[trans["rate"]]
257388 rate = rate_val * population [pos [source ]]
258389
259390 if "start" in trans and trans ["start" ] >= time :
@@ -412,7 +543,7 @@ def simulate(
412543 source = pos [comp ]
413544 target = pos [node_j ]
414545
415- rate = self .params [data ["rate" ]]
546+ rate = self ._compute_rate ( data [ "rate" ]) # self. params[data["rate"]]
416547
417548 if "start" in data and data ["start" ] >= t :
418549 continue
@@ -449,11 +580,11 @@ def simulate(
449580 comp_id = pos [comp ]
450581
451582 if "birth" in data :
452- births = self .rng .binomial (pop [comp_id ], data ["birth" ])
583+ births = self .rng .binomial (pop [comp_id ], self . _compute_rate ( data ["birth" ]) )
453584 new_pop [comp_id ] += births
454585
455586 if "death" in data :
456- deaths = self .rng .binomial (pop [comp_id ], data ["death" ])
587+ deaths = self .rng .binomial (pop [comp_id ], self . _compute_rate ( data ["death" ]) )
457588 new_pop [comp_id ] -= deaths
458589
459590 values .append (new_pop )
@@ -561,17 +692,19 @@ def __repr__(self) -> str:
561692
562693 text += "Parameters:\n "
563694 for rate , value in self .params .items ():
564- text += " %s : %f \n " % (rate , value )
695+ text += " %s : %s \n " % (rate , value )
565696 text += "\n \n Transitions:\n "
566697
567698 for edge in self .transitions .edges (data = True ):
568699 source = edge [0 ]
569700 target = edge [1 ]
570701 trans = edge [2 ]
571702
703+ # Interaction
572704 if "agent" in trans :
573705 agent = trans ["agent" ]
574706 text += " - %s + %s = %s %s\n " % (source , agent , target , trans ["rate" ])
707+ # Vaccination
575708 elif "start" in trans :
576709 start = trans ["start" ]
577710 text += " - %s -> %s %s starting at %s days\n " % (
@@ -580,9 +713,26 @@ def __repr__(self) -> str:
580713 rate ,
581714 start ,
582715 )
716+ # Viral transition
717+ elif "source_rate" in trans :
718+ text += " - %s => %s %s %s" % (
719+ source ,
720+ target ,
721+ trans ["source_rate" ],
722+ trans ["target_rate" ],
723+ )
724+ # Spontaneous
583725 else :
584726 text += " - %s -> %s %s\n " % (source , target , rate )
585727
728+ if self .demographics :
729+ text += "\n \n Demographics:\n "
730+ for comp , data in self .transitions .nodes (data = True ):
731+ if "birth" in data :
732+ text += " - -> %s: %s # birth rate\n " % (comp , data ["birth" ])
733+ if "death" in data :
734+ text += " - %s ->: %s # death rate\n " % (comp , data ["death" ])
735+
586736 R0 = self .R0 ()
587737
588738 if R0 is not None :
@@ -736,7 +886,7 @@ def R0(self) -> Union[float, None]:
736886
737887 try :
738888 for node_i , node_j , data in self .transitions .edges (data = True ):
739- rate = self .params [data ["rate" ]]
889+ rate = self ._compute_rate ( data [ "rate" ]) # self. params[data["rate"]]
740890
741891 if "agent" in data :
742892 target = pos [node_j ]
0 commit comments