Skip to content

Commit 2318599

Browse files
committed
fix SEGK
1 parent 41b0abb commit 2318599

5 files changed

Lines changed: 197 additions & 16850 deletions

File tree

.DS_Store

0 Bytes
Binary file not shown.

Untitled1.ipynb

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 1,
6+
"metadata": {},
7+
"outputs": [
8+
{
9+
"name": "stdout",
10+
"output_type": "stream",
11+
"text": [
12+
"graphwave\n",
13+
"degree2\n",
14+
"drne\n"
15+
]
16+
},
17+
{
18+
"name": "stderr",
19+
"output_type": "stream",
20+
"text": [
21+
"/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:493: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
22+
" _np_qint8 = np.dtype([(\"qint8\", np.int8, 1)])\n",
23+
"/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:494: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
24+
" _np_quint8 = np.dtype([(\"quint8\", np.uint8, 1)])\n",
25+
"/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:495: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
26+
" _np_qint16 = np.dtype([(\"qint16\", np.int16, 1)])\n",
27+
"/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:496: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
28+
" _np_quint16 = np.dtype([(\"quint16\", np.uint16, 1)])\n",
29+
"/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:497: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
30+
" _np_qint32 = np.dtype([(\"qint32\", np.int32, 1)])\n",
31+
"/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:502: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
32+
" np_resource = np.dtype([(\"resource\", np.ubyte, 1)])\n"
33+
]
34+
},
35+
{
36+
"name": "stdout",
37+
"output_type": "stream",
38+
"text": [
39+
"node2vec\n",
40+
"degree\n",
41+
"role2vec\n",
42+
"line\n",
43+
"degree1\n",
44+
"struc2vec\n",
45+
"xnetmf\n",
46+
"multilens\n",
47+
"segk\n"
48+
]
49+
},
50+
{
51+
"name": "stderr",
52+
"output_type": "stream",
53+
"text": [
54+
"/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/sklearn/externals/joblib/__init__.py:15: DeprecationWarning: sklearn.externals.joblib is deprecated in 0.21 and will be removed in 0.23. Please import this functionality directly from joblib, which can be installed with: pip install joblib. If this warning is raised when loading pickled models, you may need to re-serialize those models with scikit-learn 0.21+.\n",
55+
" warnings.warn(msg, category=DeprecationWarning)\n"
56+
]
57+
},
58+
{
59+
"name": "stdout",
60+
"output_type": "stream",
61+
"text": [
62+
"riwalk\n"
63+
]
64+
}
65+
],
66+
"source": [
67+
"from semb.methods import load as load_method\n",
68+
"from semb.methods import get_method_ids\n",
69+
"for mid in get_method_ids():\n",
70+
" print(mid)\n",
71+
" load_method(mid)"
72+
]
73+
},
74+
{
75+
"cell_type": "code",
76+
"execution_count": 2,
77+
"metadata": {},
78+
"outputs": [
79+
{
80+
"name": "stdout",
81+
"output_type": "stream",
82+
"text": [
83+
"airports\n"
84+
]
85+
}
86+
],
87+
"source": [
88+
"from semb.datasets import load as load_dataset\n",
89+
"from semb.datasets import get_dataset_ids\n",
90+
"for did in get_dataset_ids():\n",
91+
" print(did)\n",
92+
" load_dataset(did)"
93+
]
94+
},
95+
{
96+
"cell_type": "code",
97+
"execution_count": 3,
98+
"metadata": {},
99+
"outputs": [],
100+
"source": [
101+
"from semb.methods import load as load_method\n",
102+
"from semb.datasets import load as load_dataset\n",
103+
"\n",
104+
"Node2VecMethod = load_method(\"segk\")\n",
105+
"AirportDataProvider = load_dataset(\"airports\")\n",
106+
"airport_datasets = AirportDataProvider().get_datasets()\n",
107+
"brazil_airport_graph = AirportDataProvider().load_dataset(airport_datasets[0])\n",
108+
"\n",
109+
"node2vec = Node2VecMethod(brazil_airport_graph)\n",
110+
"node2vec.train()\n",
111+
"emb = node2vec.get_embeddings()"
112+
]
113+
},
114+
{
115+
"cell_type": "code",
116+
"execution_count": null,
117+
"metadata": {},
118+
"outputs": [],
119+
"source": []
120+
}
121+
],
122+
"metadata": {
123+
"kernelspec": {
124+
"display_name": "SEMB",
125+
"language": "python",
126+
"name": "semb"
127+
},
128+
"language_info": {
129+
"codemirror_mode": {
130+
"name": "ipython",
131+
"version": 3
132+
},
133+
"file_extension": ".py",
134+
"mimetype": "text/x-python",
135+
"name": "python",
136+
"nbconvert_exporter": "python",
137+
"pygments_lexer": "ipython3",
138+
"version": "3.6.2"
139+
}
140+
},
141+
"nbformat": 4,
142+
"nbformat_minor": 4
143+
}

semb/.DS_Store

0 Bytes
Binary file not shown.

semb/methods/segk/method.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,20 @@
1010

1111
class Method(BaseMethod):
1212

13-
__PARAMS__ = dict(radius=2, dim=10, kernel='shortest_path')
13+
__PARAMS__ = dict(radius=2, dim=128, kernel='weisfeiler_lehman')
1414

1515
def get_id(self):
1616
return "segk"
1717

1818
def train(self):
19-
self.embeddings = self.segk(
20-
self.graph.nodes, self.graph.edges,
21-
self.params['radius'], self.params['dim'], self.params['kernel'])
19+
nodes = [i for i in self.graph.nodes()]
20+
edges = [e for e in self.graph.edges() if e[0] != e[1]]
21+
reps = self.segk(nodes, edges, self.params['radius'], self.params['dim'], self.params['kernel'])
22+
self.embeddings = dict()
23+
for i, cur_node in enumerate(self.graph.nodes):
24+
self.embeddings[cur_node] = reps[i, :].tolist()
25+
26+
2227

2328
def segk(self, nodes, edgelist, radius, dim, kernel):
2429
n = len(nodes)

0 commit comments

Comments
 (0)