Skip to content

Commit a984113

Browse files
committed
fix example
1 parent 261a4d2 commit a984113

6 files changed

Lines changed: 1622 additions & 13236 deletions

File tree

Example.ipynb

Lines changed: 312 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,312 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"## Load Embedding Methods and Datasets"
8+
]
9+
},
10+
{
11+
"cell_type": "code",
12+
"execution_count": 3,
13+
"metadata": {},
14+
"outputs": [
15+
{
16+
"name": "stdout",
17+
"output_type": "stream",
18+
"text": [
19+
"graphwave\n",
20+
"degree2\n",
21+
"drne\n",
22+
"node2vec\n",
23+
"degree\n",
24+
"role2vec\n",
25+
"line\n",
26+
"degree1\n",
27+
"struc2vec\n",
28+
"xnetmf\n",
29+
"multilens\n",
30+
"segk\n",
31+
"riwalk\n"
32+
]
33+
}
34+
],
35+
"source": [
36+
"from semb.methods import load as load_method\n",
37+
"from semb.methods import get_method_ids\n",
38+
"for mid in get_method_ids():\n",
39+
" print(mid)\n",
40+
" load_method(mid)"
41+
]
42+
},
43+
{
44+
"cell_type": "code",
45+
"execution_count": 4,
46+
"metadata": {},
47+
"outputs": [
48+
{
49+
"name": "stdout",
50+
"output_type": "stream",
51+
"text": [
52+
"airports\n"
53+
]
54+
}
55+
],
56+
"source": [
57+
"from semb.datasets import load as load_dataset\n",
58+
"from semb.datasets import get_dataset_ids\n",
59+
"for did in get_dataset_ids():\n",
60+
" print(did)\n",
61+
" load_dataset(did)"
62+
]
63+
},
64+
{
65+
"cell_type": "markdown",
66+
"metadata": {},
67+
"source": [
68+
"## Get Embedding Result Using struc2vec"
69+
]
70+
},
71+
{
72+
"cell_type": "code",
73+
"execution_count": 22,
74+
"metadata": {},
75+
"outputs": [
76+
{
77+
"name": "stdout",
78+
"output_type": "stream",
79+
"text": [
80+
"rm /Users/mark/GoogleDrive/UM/S4/GEMS/Git/StrucEmbeddingLibrary/semb/methods/struc2vec/pickles/weights_distances-layer-*.pickle\n"
81+
]
82+
}
83+
],
84+
"source": [
85+
"# Load the class that implements the struc2vec embedding method\n",
86+
"EmbMethodClass = load_method(\"struc2vec\")\n",
87+
"\n",
88+
"# List the available airports datasets and load the first one\n",
89+
"AirportDataProvider = load_dataset(\"airports\")\n",
90+
"airport_datasets = AirportDataProvider().get_datasets()\n",
91+
"brazil_airport_graph = AirportDataProvider().load_dataset(airport_datasets[0])\n",
92+
"\n",
93+
"# Instantiate the embedding method with the graph and its hyperparameters, then train it\n",
94+
"struc2vec = EmbMethodClass(brazil_airport_graph, \n",
95+
" num_walks=10, \n",
96+
" walk_length=80, \n",
97+
" window_size=10, \n",
98+
" dim=128, \n",
99+
" opt1=True, opt2=True, opt3=True, until_layer=2)\n",
100+
"struc2vec.train()\n",
101+
"\n",
102+
"# Retrieve the embedding result with the get_embeddings() method.\n",
103+
"# It returns a dictionary that maps each node_id to that node's embedding\n",
104+
"dict_struc2vec_emb = struc2vec.get_embeddings()"
105+
]
106+
},
107+
{
108+
"cell_type": "code",
109+
"execution_count": 27,
110+
"metadata": {},
111+
"outputs": [
112+
{
113+
"data": {
114+
"text/plain": [
115+
"{'dim': 128,\n",
116+
" 'walk_length': 80,\n",
117+
" 'num_walks': 10,\n",
118+
" 'window_size': 10,\n",
119+
" 'until_layer': None,\n",
120+
" 'iter': 5,\n",
121+
" 'workers': 1,\n",
122+
" 'weighted': False,\n",
123+
" 'directed': False,\n",
124+
" 'opt1': False,\n",
125+
" 'opt2': False,\n",
126+
" 'opt3': False}"
127+
]
128+
},
129+
"execution_count": 27,
130+
"metadata": {},
131+
"output_type": "execute_result"
132+
}
133+
],
134+
"source": [
135+
"# Show the tunable parameters of the selected embedding method\n",
136+
"EmbMethodClass.__PARAMS__"
137+
]
138+
},
139+
{
140+
"cell_type": "markdown",
141+
"metadata": {},
142+
"source": [
143+
"## Load Evaluation Library and Perform Evaluation"
144+
]
145+
},
146+
{
147+
"cell_type": "code",
148+
"execution_count": 28,
149+
"metadata": {},
150+
"outputs": [],
151+
"source": [
152+
"from semb.evaluations.classification import *\n",
153+
"from semb.evaluations.clustering import *\n",
154+
"from semb.evaluations.utils import *"
155+
]
156+
},
157+
{
158+
"cell_type": "markdown",
159+
"metadata": {},
160+
"source": [
161+
"### Perform Classification"
162+
]
163+
},
164+
{
165+
"cell_type": "code",
166+
"execution_count": 29,
167+
"metadata": {},
168+
"outputs": [
169+
{
170+
"name": "stdout",
171+
"output_type": "stream",
172+
"text": [
173+
"Read in 131 node labels.\n",
174+
">>> Label 0 appears 32 times\n",
175+
">>> Label 1 appears 32 times\n",
176+
">>> Label 3 appears 35 times\n",
177+
">>> Label 2 appears 32 times\n"
178+
]
179+
}
180+
],
181+
"source": [
182+
"# Read the label file with the get_label(fn) function\n",
183+
"dict_labels = get_label(\"./sample-data/labels/airport_Brazil_label.txt\")"
184+
]
185+
},
186+
{
187+
"cell_type": "code",
188+
"execution_count": 31,
189+
"metadata": {},
190+
"outputs": [
191+
{
192+
"data": {
193+
"text/plain": [
194+
"{'overall': {'accuracy': {'mean': 0.7633, 'std': 0.0787},\n",
195+
" 'f1_macro': {'mean': 0.7548, 'std': 0.0765},\n",
196+
" 'f1_micro': {'mean': 0.7633, 'std': 0.0787},\n",
197+
" 'auc_micro': {'mean': 0.9182, 'std': 0.0327},\n",
198+
" 'auc_macro': {'mean': 0.9224, 'std': 0.0301}},\n",
199+
" 'detailed': {0: {'accuracy': 0.7778,\n",
200+
" 'f1_macro': 0.7515,\n",
201+
" 'f1_micro': 0.7778,\n",
202+
" 'auc_micro': 0.9204,\n",
203+
" 'auc_macro': 0.9298},\n",
204+
" 1: {'accuracy': 0.6154,\n",
205+
" 'f1_macro': 0.6209,\n",
206+
" 'f1_micro': 0.6154,\n",
207+
" 'auc_micro': 0.858,\n",
208+
" 'auc_macro': 0.866},\n",
209+
" 2: {'accuracy': 0.7692,\n",
210+
" 'f1_macro': 0.7448,\n",
211+
" 'f1_micro': 0.7692,\n",
212+
" 'auc_micro': 0.9413,\n",
213+
" 'auc_macro': 0.926},\n",
214+
" 3: {'accuracy': 0.8462,\n",
215+
" 'f1_macro': 0.8421,\n",
216+
" 'f1_micro': 0.8462,\n",
217+
" 'auc_micro': 0.9527,\n",
218+
" 'auc_macro': 0.9561},\n",
219+
" 4: {'accuracy': 0.8077,\n",
220+
" 'f1_macro': 0.8148,\n",
221+
" 'f1_micro': 0.8077,\n",
222+
" 'auc_micro': 0.9186,\n",
223+
" 'auc_macro': 0.9339}}}"
224+
]
225+
},
226+
"execution_count": 31,
227+
"metadata": {},
228+
"output_type": "execute_result"
229+
}
230+
],
231+
"source": [
232+
"perform_classification(dict_struc2vec_emb, dict_labels)"
233+
]
234+
},
235+
{
236+
"cell_type": "markdown",
237+
"metadata": {},
238+
"source": [
239+
"### Perform Clustering"
240+
]
241+
},
242+
{
243+
"cell_type": "code",
244+
"execution_count": 32,
245+
"metadata": {},
246+
"outputs": [
247+
{
248+
"name": "stderr",
249+
"output_type": "stream",
250+
"text": [
251+
"/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/sklearn/metrics/cluster/supervised.py:859: FutureWarning: The behavior of NMI will change in version 0.22. To match the behavior of 'v_measure_score', NMI will use average_method='arithmetic' by default.\n",
252+
" FutureWarning)\n"
253+
]
254+
},
255+
{
256+
"data": {
257+
"text/plain": [
258+
"{'overall': {'purity': [0.6412213740458015], 'nmi': [0.4771373196787525]}}"
259+
]
260+
},
261+
"execution_count": 32,
262+
"metadata": {},
263+
"output_type": "execute_result"
264+
}
265+
],
266+
"source": [
267+
"perform_clustering(dict_struc2vec_emb, dict_labels)"
268+
]
269+
},
270+
{
271+
"cell_type": "markdown",
272+
"metadata": {},
273+
"source": [
274+
"## Perform Centrality Correlation"
275+
]
276+
},
277+
{
278+
"cell_type": "code",
279+
"execution_count": null,
280+
"metadata": {},
281+
"outputs": [],
282+
"source": [
283+
"from semb.evaluations.centrality_correlation import *\n",
284+
"centrality_correlation(brazil_airport_graph, \n",
285+
" dict_struc2vec_emb, \n",
286+
" centrality='clustering_coeff', \n",
287+
" similarity='euclidean')"
288+
]
289+
}
290+
],
291+
"metadata": {
292+
"kernelspec": {
293+
"display_name": "SEMB",
294+
"language": "python",
295+
"name": "semb"
296+
},
297+
"language_info": {
298+
"codemirror_mode": {
299+
"name": "ipython",
300+
"version": 3
301+
},
302+
"file_extension": ".py",
303+
"mimetype": "text/x-python",
304+
"name": "python",
305+
"nbconvert_exporter": "python",
306+
"pygments_lexer": "ipython3",
307+
"version": "3.6.2"
308+
}
309+
},
310+
"nbformat": 4,
311+
"nbformat_minor": 4
312+
}

0 commit comments

Comments
 (0)