Skip to content

Commit b83d244

Browse files
committed
fix rate calculation in NetworkEpiModel
1 parent 610d88e commit b83d244

1 file changed

Lines changed: 32 additions & 15 deletions

File tree

src/epidemik/NetworkEpiModel.py

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# @author Bruno Goncalves
44
######################################################
55

6-
from typing import Union
6+
from typing import Union, Optional
77
import networkx as nx
88
import numpy as np
99
from numpy import linalg
@@ -12,53 +12,70 @@
1212
import matplotlib.pyplot as plt
1313
from .EpiModel import EpiModel
1414
from collections import Counter
15-
from .utils import *
15+
from . import utils
1616

1717

1818
class NetworkEpiModel(EpiModel):
19-
def __init__(self, network, compartments=None):
19+
def __init__(self, network:nx.Graph, compartments: Optional[str]=None):
2020
super(NetworkEpiModel, self).__init__(compartments)
2121
self.network = network
2222
self.kavg_ = 2 * network.number_of_edges() / network.number_of_nodes()
2323
self.spontaneous = {}
2424
self.interactions = {}
25-
self.params = {}
25+
self.params = utils.Parameters()
2626

2727
def integrate(self, timesteps, **kwargs):
2828
raise NotImplementedError("Network Models don't support numerical integration")
2929

3030
def add_interaction(
31-
self, source: str, target: str, agent: str, rescale: bool = False, **rates
31+
self, source: str, target: str, agent: str, rate:Optional[float] = None, rescale: bool = False, **rates
3232
) -> None:
33+
3334
if rescale:
34-
rate /= self.kavg_
35+
if rate:
36+
rate /= self.kavg_
37+
else:
38+
rates_names = list(rates.keys())
39+
rates[rates_names[0]] /= self.kavg_
3540

36-
self.params.update(rates)
37-
rate = list(rates.keys())[0]
3841
super(NetworkEpiModel, self).add_interaction(
39-
source, target, agent=agent, rate=rate
42+
source, target, agent=agent, rate=rate, **rates
4043
)
4144

45+
if rate is not None:
46+
count = len(self.params) + 1
47+
rate_key = "rate" + str(count)
48+
else:
49+
self.params.define_parameters(**rates)
50+
rates = list(rates.keys())
51+
rate_key = rates[0]
52+
4253
if source not in self.interactions:
4354
self.interactions[source] = {}
4455

4556
if target not in self.interactions[source]:
4657
self.interactions[source] = {}
4758

48-
self.interactions[source][agent] = {"target": target, "rate": rate}
59+
self.interactions[source][agent] = {"target": target, "rate": rate_key}
60+
61+
def add_spontaneous(self, source: str, target: str, rate: Optional[float], **rates):
62+
super(NetworkEpiModel, self).add_spontaneous(source, target, rate=rate, **rates)
4963

50-
def add_spontaneous(self, source: str, target: str, **rates):
51-
self.params.update(rates)
52-
rate = list(rates.keys())[0]
64+
if rate is not None:
65+
count = len(self.params) + 1
66+
rate_key = "rate" + str(count)
67+
else:
68+
self.params.define_parameters(**rates)
69+
rates = list(rates.keys())
70+
rate_key = rates[0]
5371

54-
super(NetworkEpiModel, self).add_spontaneous(source, target, rate=rate)
5572
if source not in self.spontaneous:
5673
self.spontaneous[source] = {}
5774

5875
if target not in self.spontaneous[source]:
5976
self.spontaneous[source] = {}
6077

61-
self.spontaneous[source][target] = rate
78+
self.spontaneous[source][target] = rate_key
6279

6380
def simulate(self, timesteps: int, seeds, **kwargs) -> None:
6481
"""Stochastically simulate the epidemic model"""

0 commit comments

Comments
 (0)