|
9 | 9 | }, |
10 | 10 | { |
11 | 11 | "cell_type": "code", |
12 | | - "execution_count": 3, |
| 12 | + "execution_count": 18, |
13 | 13 | "metadata": {}, |
14 | 14 | "outputs": [ |
15 | 15 | { |
|
40 | 40 | " load_method(mid)" |
41 | 41 | ] |
42 | 42 | }, |
| 43 | + { |
| 44 | + "cell_type": "markdown", |
| 45 | + "metadata": {}, |
| 46 | + "source": [ |
| 47 | + "These are the method_id values for the existing embedding methods." |
| 48 | + ] |
| 49 | + }, |
43 | 50 | { |
44 | 51 | "cell_type": "code", |
45 | | - "execution_count": 4, |
| 52 | + "execution_count": 19, |
46 | 53 | "metadata": {}, |
47 | 54 | "outputs": [ |
48 | 55 | { |
49 | 56 | "name": "stdout", |
50 | 57 | "output_type": "stream", |
51 | 58 | "text": [ |
| 59 | + "BlogCatalog\n", |
| 60 | + "ICEWS\n", |
| 61 | + "Facebook\n", |
| 62 | + "DD6\n", |
| 63 | + "PPI\n", |
52 | 64 | "airports\n" |
53 | 65 | ] |
54 | 66 | } |
|
61 | 73 | " load_dataset(did)" |
62 | 74 | ] |
63 | 75 | }, |
| 76 | + { |
| 77 | + "cell_type": "markdown", |
| 78 | + "metadata": {}, |
| 79 | + "source": [ |
| 80 | + "These are the dataset_id values for the existing datasets." |
| 81 | + ] |
| 82 | + }, |
| 83 | + { |
| 84 | + "cell_type": "markdown", |
| 85 | + "metadata": {}, |
| 86 | + "source": [ |
| 87 | + "## Load Dataset" |
| 88 | + ] |
| 89 | + }, |
| 90 | + { |
| 91 | + "cell_type": "code", |
| 92 | + "execution_count": 20, |
| 93 | + "metadata": {}, |
| 94 | + "outputs": [], |
| 95 | + "source": [ |
| 96 | + "# Get airports datasets\n", |
| 97 | + "DataProvider = load_dataset(\"airports\")\n", |
| 98 | + "Datasets = DataProvider().get_datasets()\n", |
| 99 | + "dataset_graph = DataProvider().load_dataset(Datasets[0])" |
| 100 | + ] |
| 101 | + }, |
| 102 | + { |
| 103 | + "cell_type": "markdown", |
| 104 | + "metadata": {}, |
| 105 | + "source": [ |
| 106 | + "Note that the airports data provider contains three datasets.\n", |
| 107 | + "\n", |
| 108 | + "Datasets\\[0\\] represents the BR-air traffic Dataset\n", |
| 109 | + "\n", |
| 110 | + "Datasets\\[1\\] represents the EU-air traffic Dataset\n", |
| 111 | + "\n", |
| 112 | + "Datasets\\[2\\] represents the US-air traffic Dataset" |
| 113 | + ] |
| 114 | + }, |
| 115 | + { |
| 116 | + "cell_type": "code", |
| 117 | + "execution_count": 21, |
| 118 | + "metadata": {}, |
| 119 | + "outputs": [], |
| 120 | + "source": [ |
| 121 | + "# Example code for getting the other datasets\n", |
| 122 | + "DataProvider = load_dataset(\"Facebook\")\n", |
| 123 | + "Facebook_dataset = DataProvider().get_datasets()\n", |
| 124 | + "Facebook_graph = DataProvider().load_dataset(Facebook_dataset[0])" |
| 125 | + ] |
| 126 | + }, |
64 | 127 | { |
65 | 128 | "cell_type": "markdown", |
66 | 129 | "metadata": {}, |
|
85 | 148 | "# Define a hyper-class to load the embedding method\n", |
86 | 149 | "EmbMethodClass = load_method(\"struc2vec\")\n", |
87 | 150 | "\n", |
88 | | - "# Get airports datasets\n", |
89 | | - "AirportDataProvider = load_dataset(\"airports\")\n", |
90 | | - "airport_datasets = AirportDataProvider().get_datasets()\n", |
91 | | - "brazil_airport_graph = AirportDataProvider().load_dataset(airport_datasets[0])\n", |
92 | | - "\n", |
93 | 151 | "# Call the embedding method with the graph for initialization\n", |
94 | | - "struc2vec = EmbMethodClass(brazil_airport_graph, \n", |
| 152 | + "struc2vec = EmbMethodClass(dataset_graph, \n", |
95 | 153 | " num_walks=10, \n", |
96 | 154 | " walk_length=80, \n", |
97 | 155 | " window_size=10, \n", |
|
106 | 164 | }, |
107 | 165 | { |
108 | 166 | "cell_type": "code", |
109 | | - "execution_count": 27, |
| 167 | + "execution_count": 29, |
110 | 168 | "metadata": {}, |
111 | 169 | "outputs": [ |
112 | 170 | { |
|
126 | 184 | " 'opt3': False}" |
127 | 185 | ] |
128 | 186 | }, |
129 | | - "execution_count": 27, |
| 187 | + "execution_count": 29, |
130 | 188 | "metadata": {}, |
131 | 189 | "output_type": "execute_result" |
132 | 190 | } |
133 | 191 | ], |
134 | 192 | "source": [ |
135 | | - "# This shows the tunable parameters for the certain embedding method\n", |
| 193 | + "# Show the tunable hyper-parameters of the selected embedding method\n", |
| 194 | + "# (here, as an example, the tunable hyper-parameters of struc2vec)\n", |
136 | 195 | "EmbMethodClass.__PARAMS__" |
137 | 196 | ] |
138 | 197 | }, |
|
145 | 204 | }, |
146 | 205 | { |
147 | 206 | "cell_type": "code", |
148 | | - "execution_count": 28, |
| 207 | + "execution_count": 24, |
149 | 208 | "metadata": {}, |
150 | 209 | "outputs": [], |
151 | 210 | "source": [ |
|
163 | 222 | }, |
164 | 223 | { |
165 | 224 | "cell_type": "code", |
166 | | - "execution_count": 29, |
| 225 | + "execution_count": 25, |
167 | 226 | "metadata": {}, |
168 | 227 | "outputs": [ |
169 | 228 | { |
|
185 | 244 | }, |
186 | 245 | { |
187 | 246 | "cell_type": "code", |
188 | | - "execution_count": 31, |
| 247 | + "execution_count": 26, |
189 | 248 | "metadata": {}, |
190 | 249 | "outputs": [ |
191 | 250 | { |
192 | 251 | "data": { |
193 | 252 | "text/plain": [ |
194 | | - "{'overall': {'accuracy': {'mean': 0.7633, 'std': 0.0787},\n", |
195 | | - " 'f1_macro': {'mean': 0.7548, 'std': 0.0765},\n", |
196 | | - " 'f1_micro': {'mean': 0.7633, 'std': 0.0787},\n", |
197 | | - " 'auc_micro': {'mean': 0.9182, 'std': 0.0327},\n", |
198 | | - " 'auc_macro': {'mean': 0.9224, 'std': 0.0301}},\n", |
199 | | - " 'detailed': {0: {'accuracy': 0.7778,\n", |
200 | | - " 'f1_macro': 0.7515,\n", |
201 | | - " 'f1_micro': 0.7778,\n", |
202 | | - " 'auc_micro': 0.9204,\n", |
203 | | - " 'auc_macro': 0.9298},\n", |
204 | | - " 1: {'accuracy': 0.6154,\n", |
205 | | - " 'f1_macro': 0.6209,\n", |
206 | | - " 'f1_micro': 0.6154,\n", |
207 | | - " 'auc_micro': 0.858,\n", |
208 | | - " 'auc_macro': 0.866},\n", |
| 253 | + "{'overall': {'accuracy': {'mean': 0.786, 'std': 0.0759},\n", |
| 254 | + " 'f1_macro': {'mean': 0.7791, 'std': 0.0752},\n", |
| 255 | + " 'f1_micro': {'mean': 0.786, 'std': 0.0759},\n", |
| 256 | + " 'auc_micro': {'mean': 0.9288, 'std': 0.0255},\n", |
| 257 | + " 'auc_macro': {'mean': 0.9413, 'std': 0.0182}},\n", |
| 258 | + " 'detailed': {0: {'accuracy': 0.8148,\n", |
| 259 | + " 'f1_macro': 0.805,\n", |
| 260 | + " 'f1_micro': 0.8148,\n", |
| 261 | + " 'auc_micro': 0.9374,\n", |
| 262 | + " 'auc_macro': 0.9418},\n", |
| 263 | + " 1: {'accuracy': 0.6538,\n", |
| 264 | + " 'f1_macro': 0.6542,\n", |
| 265 | + " 'f1_micro': 0.6538,\n", |
| 266 | + " 'auc_micro': 0.8817,\n", |
| 267 | + " 'auc_macro': 0.9083},\n", |
209 | 268 | " 2: {'accuracy': 0.7692,\n", |
210 | 269 | " 'f1_macro': 0.7448,\n", |
211 | 270 | " 'f1_micro': 0.7692,\n", |
212 | | - " 'auc_micro': 0.9413,\n", |
213 | | - " 'auc_macro': 0.926},\n", |
214 | | - " 3: {'accuracy': 0.8462,\n", |
215 | | - " 'f1_macro': 0.8421,\n", |
216 | | - " 'f1_micro': 0.8462,\n", |
217 | | - " 'auc_micro': 0.9527,\n", |
218 | | - " 'auc_macro': 0.9561},\n", |
| 271 | + " 'auc_micro': 0.9438,\n", |
| 272 | + " 'auc_macro': 0.9578},\n", |
| 273 | + " 3: {'accuracy': 0.8846,\n", |
| 274 | + " 'f1_macro': 0.8769,\n", |
| 275 | + " 'f1_micro': 0.8846,\n", |
| 276 | + " 'auc_micro': 0.9556,\n", |
| 277 | + " 'auc_macro': 0.9585},\n", |
219 | 278 | " 4: {'accuracy': 0.8077,\n", |
220 | 279 | " 'f1_macro': 0.8148,\n", |
221 | 280 | " 'f1_micro': 0.8077,\n", |
222 | | - " 'auc_micro': 0.9186,\n", |
223 | | - " 'auc_macro': 0.9339}}}" |
| 281 | + " 'auc_micro': 0.9255,\n", |
| 282 | + " 'auc_macro': 0.9401}}}" |
224 | 283 | ] |
225 | 284 | }, |
226 | | - "execution_count": 31, |
| 285 | + "execution_count": 26, |
227 | 286 | "metadata": {}, |
228 | 287 | "output_type": "execute_result" |
229 | 288 | } |
|
241 | 300 | }, |
242 | 301 | { |
243 | 302 | "cell_type": "code", |
244 | | - "execution_count": 32, |
| 303 | + "execution_count": 27, |
245 | 304 | "metadata": {}, |
246 | 305 | "outputs": [ |
247 | 306 | { |
|
255 | 314 | { |
256 | 315 | "data": { |
257 | 316 | "text/plain": [ |
258 | | - "{'overall': {'purity': [0.6412213740458015], 'nmi': [0.4771373196787525]}}" |
| 317 | + "{'overall': {'purity': [0.6793893129770993], 'nmi': [0.4854751062047489]}}" |
259 | 318 | ] |
260 | 319 | }, |
261 | | - "execution_count": 32, |
| 320 | + "execution_count": 27, |
262 | 321 | "metadata": {}, |
263 | 322 | "output_type": "execute_result" |
264 | 323 | } |
|
276 | 335 | }, |
277 | 336 | { |
278 | 337 | "cell_type": "code", |
279 | | - "execution_count": null, |
| 338 | + "execution_count": 28, |
280 | 339 | "metadata": {}, |
281 | | - "outputs": [], |
| 340 | + "outputs": [ |
| 341 | + { |
| 342 | + "data": { |
| 343 | + "text/plain": [ |
| 344 | + "0.9379255572546902" |
| 345 | + ] |
| 346 | + }, |
| 347 | + "execution_count": 28, |
| 348 | + "metadata": {}, |
| 349 | + "output_type": "execute_result" |
| 350 | + } |
| 351 | + ], |
282 | 352 | "source": [ |
283 | 353 | "from semb.evaluations.centrality_correlation import *\n", |
284 | | - "centrality_correlation(brazil_airport_graph, \n", |
| 354 | + "centrality_correlation(dataset_graph, \n", |
285 | 355 | " dict_struc2vec_emb, \n", |
286 | 356 | " centrality='clustering_coeff', \n", |
287 | 357 | " similarity='euclidean')" |
|
0 commit comments