Skip to content

Commit d6ae3f3

Browse files
committed
add datasets and fix Example.ipynb
1 parent a984113 commit d6ae3f3

9 files changed

Lines changed: 1605 additions & 1355 deletions

File tree

.DS_Store

0 Bytes
Binary file not shown.

Example.ipynb

Lines changed: 115 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
},
1010
{
1111
"cell_type": "code",
12-
"execution_count": 3,
12+
"execution_count": 18,
1313
"metadata": {},
1414
"outputs": [
1515
{
@@ -40,15 +40,27 @@
4040
" load_method(mid)"
4141
]
4242
},
43+
{
44+
"cell_type": "markdown",
45+
"metadata": {},
46+
"source": [
47+
"These are the method_id for the existing datasets."
48+
]
49+
},
4350
{
4451
"cell_type": "code",
45-
"execution_count": 4,
52+
"execution_count": 19,
4653
"metadata": {},
4754
"outputs": [
4855
{
4956
"name": "stdout",
5057
"output_type": "stream",
5158
"text": [
59+
"BlogCatalog\n",
60+
"ICEWS\n",
61+
"Facebook\n",
62+
"DD6\n",
63+
"PPI\n",
5264
"airports\n"
5365
]
5466
}
@@ -61,6 +73,57 @@
6173
" load_dataset(did)"
6274
]
6375
},
76+
{
77+
"cell_type": "markdown",
78+
"metadata": {},
79+
"source": [
80+
"These are the dataset_id for the existing datasets."
81+
]
82+
},
83+
{
84+
"cell_type": "markdown",
85+
"metadata": {},
86+
"source": [
87+
"## Load Dataset"
88+
]
89+
},
90+
{
91+
"cell_type": "code",
92+
"execution_count": 20,
93+
"metadata": {},
94+
"outputs": [],
95+
"source": [
96+
"# Get airports datasets\n",
97+
"DataProvider = load_dataset(\"airports\")\n",
98+
"Datasets = DataProvider().get_datasets()\n",
99+
"dataset_graph = DataProvider().load_dataset(Datasets[0])"
100+
]
101+
},
102+
{
103+
"cell_type": "markdown",
104+
"metadata": {},
105+
"source": [
106+
"Note that there are three datasets in the airports dataset.\n",
107+
"\n",
108+
"Datasets\\[0\\] represents the BR-air traffic Dataset\n",
109+
"\n",
110+
"Datasets\\[1\\] represents the EU-air traffic Dataset\n",
111+
"\n",
112+
"Datasets\\[2\\] represents the US-air traffic Dataset"
113+
]
114+
},
115+
{
116+
"cell_type": "code",
117+
"execution_count": 21,
118+
"metadata": {},
119+
"outputs": [],
120+
"source": [
121+
"# Example code for getting the other datasets\n",
122+
"DataProvider = load_dataset(\"Facebook\")\n",
123+
"Facebook_dataset = DataProvider().get_datasets()\n",
124+
"Facebook_graph = DataProvider().load_dataset(Facebook_dataset[0])"
125+
]
126+
},
64127
{
65128
"cell_type": "markdown",
66129
"metadata": {},
@@ -85,13 +148,8 @@
85148
"# Define a hyper-class to load the embedding method\n",
86149
"EmbMethodClass = load_method(\"struc2vec\")\n",
87150
"\n",
88-
"# Get airports datasets\n",
89-
"AirportDataProvider = load_dataset(\"airports\")\n",
90-
"airport_datasets = AirportDataProvider().get_datasets()\n",
91-
"brazil_airport_graph = AirportDataProvider().load_dataset(airport_datasets[0])\n",
92-
"\n",
93151
"# Call the embedding method with the graph for initialization\n",
94-
"struc2vec = EmbMethodClass(brazil_airport_graph, \n",
152+
"struc2vec = EmbMethodClass(dataset_graph, \n",
95153
" num_walks=10, \n",
96154
" walk_length=80, \n",
97155
" window_size=10, \n",
@@ -106,7 +164,7 @@
106164
},
107165
{
108166
"cell_type": "code",
109-
"execution_count": 27,
167+
"execution_count": 29,
110168
"metadata": {},
111169
"outputs": [
112170
{
@@ -126,13 +184,14 @@
126184
" 'opt3': False}"
127185
]
128186
},
129-
"execution_count": 27,
187+
"execution_count": 29,
130188
"metadata": {},
131189
"output_type": "execute_result"
132190
}
133191
],
134192
"source": [
135-
"# This shows the tunable parameters for the certain embedding method\n",
193+
"# This shows the tunable hyper-parameters for the certain embedding method\n",
194+
"# Here, for example, list the tunable hyper-parameters for struc2vec\n",
136195
"EmbMethodClass.__PARAMS__"
137196
]
138197
},
@@ -145,7 +204,7 @@
145204
},
146205
{
147206
"cell_type": "code",
148-
"execution_count": 28,
207+
"execution_count": 24,
149208
"metadata": {},
150209
"outputs": [],
151210
"source": [
@@ -163,7 +222,7 @@
163222
},
164223
{
165224
"cell_type": "code",
166-
"execution_count": 29,
225+
"execution_count": 25,
167226
"metadata": {},
168227
"outputs": [
169228
{
@@ -185,45 +244,45 @@
185244
},
186245
{
187246
"cell_type": "code",
188-
"execution_count": 31,
247+
"execution_count": 26,
189248
"metadata": {},
190249
"outputs": [
191250
{
192251
"data": {
193252
"text/plain": [
194-
"{'overall': {'accuracy': {'mean': 0.7633, 'std': 0.0787},\n",
195-
" 'f1_macro': {'mean': 0.7548, 'std': 0.0765},\n",
196-
" 'f1_micro': {'mean': 0.7633, 'std': 0.0787},\n",
197-
" 'auc_micro': {'mean': 0.9182, 'std': 0.0327},\n",
198-
" 'auc_macro': {'mean': 0.9224, 'std': 0.0301}},\n",
199-
" 'detailed': {0: {'accuracy': 0.7778,\n",
200-
" 'f1_macro': 0.7515,\n",
201-
" 'f1_micro': 0.7778,\n",
202-
" 'auc_micro': 0.9204,\n",
203-
" 'auc_macro': 0.9298},\n",
204-
" 1: {'accuracy': 0.6154,\n",
205-
" 'f1_macro': 0.6209,\n",
206-
" 'f1_micro': 0.6154,\n",
207-
" 'auc_micro': 0.858,\n",
208-
" 'auc_macro': 0.866},\n",
253+
"{'overall': {'accuracy': {'mean': 0.786, 'std': 0.0759},\n",
254+
" 'f1_macro': {'mean': 0.7791, 'std': 0.0752},\n",
255+
" 'f1_micro': {'mean': 0.786, 'std': 0.0759},\n",
256+
" 'auc_micro': {'mean': 0.9288, 'std': 0.0255},\n",
257+
" 'auc_macro': {'mean': 0.9413, 'std': 0.0182}},\n",
258+
" 'detailed': {0: {'accuracy': 0.8148,\n",
259+
" 'f1_macro': 0.805,\n",
260+
" 'f1_micro': 0.8148,\n",
261+
" 'auc_micro': 0.9374,\n",
262+
" 'auc_macro': 0.9418},\n",
263+
" 1: {'accuracy': 0.6538,\n",
264+
" 'f1_macro': 0.6542,\n",
265+
" 'f1_micro': 0.6538,\n",
266+
" 'auc_micro': 0.8817,\n",
267+
" 'auc_macro': 0.9083},\n",
209268
" 2: {'accuracy': 0.7692,\n",
210269
" 'f1_macro': 0.7448,\n",
211270
" 'f1_micro': 0.7692,\n",
212-
" 'auc_micro': 0.9413,\n",
213-
" 'auc_macro': 0.926},\n",
214-
" 3: {'accuracy': 0.8462,\n",
215-
" 'f1_macro': 0.8421,\n",
216-
" 'f1_micro': 0.8462,\n",
217-
" 'auc_micro': 0.9527,\n",
218-
" 'auc_macro': 0.9561},\n",
271+
" 'auc_micro': 0.9438,\n",
272+
" 'auc_macro': 0.9578},\n",
273+
" 3: {'accuracy': 0.8846,\n",
274+
" 'f1_macro': 0.8769,\n",
275+
" 'f1_micro': 0.8846,\n",
276+
" 'auc_micro': 0.9556,\n",
277+
" 'auc_macro': 0.9585},\n",
219278
" 4: {'accuracy': 0.8077,\n",
220279
" 'f1_macro': 0.8148,\n",
221280
" 'f1_micro': 0.8077,\n",
222-
" 'auc_micro': 0.9186,\n",
223-
" 'auc_macro': 0.9339}}}"
281+
" 'auc_micro': 0.9255,\n",
282+
" 'auc_macro': 0.9401}}}"
224283
]
225284
},
226-
"execution_count": 31,
285+
"execution_count": 26,
227286
"metadata": {},
228287
"output_type": "execute_result"
229288
}
@@ -241,7 +300,7 @@
241300
},
242301
{
243302
"cell_type": "code",
244-
"execution_count": 32,
303+
"execution_count": 27,
245304
"metadata": {},
246305
"outputs": [
247306
{
@@ -255,10 +314,10 @@
255314
{
256315
"data": {
257316
"text/plain": [
258-
"{'overall': {'purity': [0.6412213740458015], 'nmi': [0.4771373196787525]}}"
317+
"{'overall': {'purity': [0.6793893129770993], 'nmi': [0.4854751062047489]}}"
259318
]
260319
},
261-
"execution_count": 32,
320+
"execution_count": 27,
262321
"metadata": {},
263322
"output_type": "execute_result"
264323
}
@@ -276,12 +335,23 @@
276335
},
277336
{
278337
"cell_type": "code",
279-
"execution_count": null,
338+
"execution_count": 28,
280339
"metadata": {},
281-
"outputs": [],
340+
"outputs": [
341+
{
342+
"data": {
343+
"text/plain": [
344+
"0.9379255572546902"
345+
]
346+
},
347+
"execution_count": 28,
348+
"metadata": {},
349+
"output_type": "execute_result"
350+
}
351+
],
282352
"source": [
283353
"from semb.evaluations.centrality_correlation import *\n",
284-
"centrality_correlation(brazil_airport_graph, \n",
354+
"centrality_correlation(dataset_graph, \n",
285355
" dict_struc2vec_emb, \n",
286356
" centrality='clustering_coeff', \n",
287357
" similarity='euclidean')"

semb/.DS_Store

0 Bytes
Binary file not shown.
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from semb.datasets import BaseDataset, DatasetInfo

import os
from typing import List

import networkx as nx

# TODO: Make this a remote URL in the future
SAMPLE_DATA_DIR = os.path.join(os.path.dirname(__file__), "../../../sample-data/BlogCatalog")


class Dataset(BaseDataset):
    """Dataset provider for the BlogCatalog social-network graph."""

    def get_id(self) -> str:
        """Return the unique identifier of this dataset provider."""
        return 'BlogCatalog'

    def get_datasets(self) -> List[DatasetInfo]:
        """Return the list of datasets offered by this provider (one edgelist)."""
        return [
            DatasetInfo(name="BlogCatalog", description="BlogCatalog data",
                        src_url=f'{SAMPLE_DATA_DIR}/BlogCatalog.edgelist')]

    def load_dataset(self, dataset: DatasetInfo, directed=False, weighted=False) -> nx.Graph:
        """Load the edgelist referenced by *dataset* into a networkx graph.

        Args:
            dataset: DatasetInfo whose ``src_url`` is the edgelist path.
            directed: build an ``nx.DiGraph`` instead of an ``nx.Graph``.
            weighted: read a third column from the edgelist as edge weight.

        Returns:
            The loaded graph; every edge carries a ``'weight'`` attribute
            (parsed from the file when *weighted*, otherwise defaulted to 1).
        """
        create_using = nx.DiGraph() if directed else nx.Graph()
        if weighted:
            # BUG FIX: the original passed data=(('weight', 'data')) — the outer
            # parens are just grouping, so read_edgelist received a single
            # ('weight', 'data') tuple instead of a list of (name, type) pairs,
            # and 'data' is not a type callable. Use the documented form.
            graph = nx.read_edgelist(
                dataset.src_url,
                nodetype=int,
                data=[('weight', float)],
                create_using=create_using)
        else:
            graph = nx.read_edgelist(
                dataset.src_url,
                nodetype=int,
                create_using=create_using)
            # Unweighted input: default every edge weight to 1 so downstream
            # code can always rely on a 'weight' attribute. Kept inside the
            # else-branch so weights read from a weighted file are not clobbered.
            for u, v in graph.edges():
                graph[u][v]['weight'] = 1

        return graph
36+

semb/datasets/DD6/dataset.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from semb.datasets import BaseDataset, DatasetInfo

import os
from typing import List

import networkx as nx

# TODO: Make this a remote URL in the future
SAMPLE_DATA_DIR = os.path.join(os.path.dirname(__file__), "../../../sample-data/DD6")


class Dataset(BaseDataset):
    """Dataset provider for the DD6 protein graph."""

    def get_id(self) -> str:
        """Return the unique identifier of this dataset provider."""
        return 'DD6'

    def get_datasets(self) -> List[DatasetInfo]:
        """Return the list of datasets offered by this provider (one edgelist)."""
        return [
            DatasetInfo(name="DD6", description="DD6 dataset",
                        src_url=f'{SAMPLE_DATA_DIR}/DD6.edgelist')]

    def load_dataset(self, dataset: DatasetInfo, directed=False, weighted=False) -> nx.Graph:
        """Load the edgelist referenced by *dataset* into a networkx graph.

        Args:
            dataset: DatasetInfo whose ``src_url`` is the edgelist path.
            directed: build an ``nx.DiGraph`` instead of an ``nx.Graph``.
            weighted: read a third column from the edgelist as edge weight.

        Returns:
            The loaded graph; every edge carries a ``'weight'`` attribute
            (parsed from the file when *weighted*, otherwise defaulted to 1).
        """
        create_using = nx.DiGraph() if directed else nx.Graph()
        if weighted:
            # BUG FIX: the original passed data=(('weight', 'data')) — the outer
            # parens are just grouping, so read_edgelist received a single
            # ('weight', 'data') tuple instead of a list of (name, type) pairs,
            # and 'data' is not a type callable. Use the documented form.
            graph = nx.read_edgelist(
                dataset.src_url,
                nodetype=int,
                data=[('weight', float)],
                create_using=create_using)
        else:
            graph = nx.read_edgelist(
                dataset.src_url,
                nodetype=int,
                create_using=create_using)
            # Unweighted input: default every edge weight to 1 so downstream
            # code can always rely on a 'weight' attribute. Kept inside the
            # else-branch so weights read from a weighted file are not clobbered.
            for u, v in graph.edges():
                graph[u][v]['weight'] = 1

        return graph
36+

semb/datasets/Facebook/dataset.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from semb.datasets import BaseDataset, DatasetInfo

import os
from typing import List

import networkx as nx

# TODO: Make this a remote URL in the future
SAMPLE_DATA_DIR = os.path.join(os.path.dirname(__file__), "../../../sample-data/Facebook")


class Dataset(BaseDataset):
    """Dataset provider for the Facebook social-network graph."""

    def get_id(self) -> str:
        """Return the unique identifier of this dataset provider."""
        return 'Facebook'

    def get_datasets(self) -> List[DatasetInfo]:
        """Return the list of datasets offered by this provider (one edgelist)."""
        return [
            DatasetInfo(name="Facebook", description="Facebook dataset",
                        src_url=f'{SAMPLE_DATA_DIR}/Facebook.edgelist')]

    def load_dataset(self, dataset: DatasetInfo, directed=False, weighted=False) -> nx.Graph:
        """Load the edgelist referenced by *dataset* into a networkx graph.

        Args:
            dataset: DatasetInfo whose ``src_url`` is the edgelist path.
            directed: build an ``nx.DiGraph`` instead of an ``nx.Graph``.
            weighted: read a third column from the edgelist as edge weight.

        Returns:
            The loaded graph; every edge carries a ``'weight'`` attribute
            (parsed from the file when *weighted*, otherwise defaulted to 1).
        """
        create_using = nx.DiGraph() if directed else nx.Graph()
        if weighted:
            # BUG FIX: the original passed data=(('weight', 'data')) — the outer
            # parens are just grouping, so read_edgelist received a single
            # ('weight', 'data') tuple instead of a list of (name, type) pairs,
            # and 'data' is not a type callable. Use the documented form.
            graph = nx.read_edgelist(
                dataset.src_url,
                nodetype=int,
                data=[('weight', float)],
                create_using=create_using)
        else:
            graph = nx.read_edgelist(
                dataset.src_url,
                nodetype=int,
                create_using=create_using)
            # Unweighted input: default every edge weight to 1 so downstream
            # code can always rely on a 'weight' attribute. Kept inside the
            # else-branch so weights read from a weighted file are not clobbered.
            for u, v in graph.edges():
                graph[u][v]['weight'] = 1

        return graph
36+

0 commit comments

Comments
 (0)