Skip to content

Commit 41b0abb

Browse files
committed
add xnetmf
1 parent bd4fe6b commit 41b0abb

5 files changed

Lines changed: 40 additions & 22 deletions

File tree

.DS_Store

0 Bytes
Binary file not shown.

semb/.DS_Store

0 Bytes
Binary file not shown.

semb/methods/xnetmf/internal/xnetmf.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -204,20 +204,20 @@ def get_representations(graph, rep_method, verbose = True):
204204
reprsn = reprsn / np.linalg.norm(reprsn, axis = 1).reshape((reprsn.shape[0],1))
205205
return reprsn
206206

207-
if __name__ == "__main__":
208-
if len(sys.argv) < 2:
209-
#####PUT IN YOUR GRAPH AS AN EDGELIST HERE (or pass as cmd line argument)#####
210-
#(see networkx read_edgelist() method...if networkx can read your file as an edgelist you're good!)
211-
graph_file = "data/arenas_combined_edges.txt"
212-
else:
213-
graph_file = sys.argv[1]
214-
nx_graph = nx.read_edgelist(graph_file, nodetype = int, comments="%")
215-
adj_matrix = nx.adjacency_matrix(nx_graph).todense()
207+
# if __name__ == "__main__":
208+
# if len(sys.argv) < 2:
209+
# #####PUT IN YOUR GRAPH AS AN EDGELIST HERE (or pass as cmd line argument)#####
210+
# #(see networkx read_edgelist() method...if networkx can read your file as an edgelist you're good!)
211+
# graph_file = "data/arenas_combined_edges.txt"
212+
# else:
213+
# graph_file = sys.argv[1]
214+
# nx_graph = nx.read_edgelist(graph_file, nodetype = int, comments="%")
215+
# adj_matrix = nx.adjacency_matrix(nx_graph).todense()
216216

217-
graph = Graph(adj_matrix)
218-
rep_method = RepMethod(max_layer = 2) #Learn representations with xNetMF. Can adjust parameters (e.g. as in REGAL)
219-
representations = get_representations(graph, rep_method)
220-
print(representations.shape)
217+
# graph = Graph(adj_matrix)
218+
# rep_method = RepMethod(max_layer = 2) #Learn representations with xNetMF. Can adjust parameters (e.g. as in REGAL)
219+
# representations = get_representations(graph, rep_method)
220+
# print(representations.shape)
221221

222222

223223

semb/methods/xnetmf/method.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,11 @@ def train(self):
1313
# learn representations with xNetMF. Can adjust parameters (e.g. as in REGAL)
1414
rep_method = RepMethod(max_layer=self.params['max_layer'], p=self.params['dim'],
1515
alpha=self.params['discount'], gammastruc=self.params['gamma'])
16-
# FIXME: this doesnt look like a standard embeddings format
17-
representations = get_representations(self.graph, rep_method)
18-
list_nodes = self.graph.node()
16+
graph = Graph(nx.adjacency_matrix(self.graph))
17+
representations = get_representations(graph, rep_method)
18+
1919
self.embeddings = dict()
20-
for i in range(0, len(list_nodes)):
21-
self.embeddings[i] = representations[i].tolist()
20+
list_nodes = list(self.graph.nodes())
2221

22+
for i in range(0, representations.shape[0]):
23+
self.embeddings[list_nodes[i]] = representations[i].tolist()

test.ipynb

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@
106106
},
107107
{
108108
"cell_type": "code",
109-
"execution_count": 3,
109+
"execution_count": 4,
110110
"metadata": {
111111
"tags": []
112112
},
@@ -115,20 +115,37 @@
115115
"name": "stdout",
116116
"output_type": "stream",
117117
"text": [
118-
"rm /Users/mark/GoogleDrive/UM/S4/GEMS/Git/StrucEmbeddingLibrary/semb/methods/struc2vec/pickles/weights_distances-layer-*.pickle\n"
118+
"max degree: 80\n",
119+
"got k hop neighbors in time: 0.0337069034576416\n",
120+
"got degree sequences in time: 0.013570070266723633\n",
121+
"computed representation in time: 0.010637044906616211\n",
122+
"[[-1.37432577e-04 1.30331318e-03 -2.69511558e-04 ... -2.44950932e-22\n",
123+
" 1.63093975e-22 1.42838586e-22]\n",
124+
" [-9.78507622e-05 1.49474244e-04 -2.42289473e-04 ... -2.88146479e-22\n",
125+
" 2.12115482e-22 1.59210219e-22]\n",
126+
" [-1.10454466e-05 -3.63693553e-05 1.05660366e-04 ... -2.95114433e-22\n",
127+
" 2.13269839e-22 1.64978283e-22]\n",
128+
" ...\n",
129+
" [-5.87339184e-05 -1.25072095e-04 -8.59305162e-06 ... -5.85895127e-22\n",
130+
" 4.11111734e-22 1.68655414e-22]\n",
131+
" [-5.15959810e-05 -1.13438423e-04 -1.00442930e-06 ... -5.89756433e-22\n",
132+
" 4.14443476e-22 1.70109439e-22]\n",
133+
" [ 2.61171369e-04 -9.51229394e-05 -4.38703782e-05 ... -5.83886813e-22\n",
134+
" 4.11219006e-22 1.66800475e-22]]\n",
135+
"<semb.methods.xnetmf.internal.config.Graph object at 0x7fa67c892fd0>\n"
119136
]
120137
}
121138
],
122139
"source": [
123140
"from semb.methods import load as load_method\n",
124141
"from semb.datasets import load as load_dataset\n",
125142
"\n",
126-
"Node2VecMethod = load_method(\"struc2vec\")\n",
143+
"Node2VecMethod = load_method(\"xnetmf\")\n",
127144
"AirportDataProvider = load_dataset(\"airports\")\n",
128145
"airport_datasets = AirportDataProvider().get_datasets()\n",
129146
"brazil_airport_graph = AirportDataProvider().load_dataset(airport_datasets[0])\n",
130147
"\n",
131-
"node2vec = Node2VecMethod(brazil_airport_graph, opt1=True, opt2=True, opt3=True)\n",
148+
"node2vec = Node2VecMethod(brazil_airport_graph)\n",
132149
"node2vec.train()\n",
133150
"emb = node2vec.get_embeddings()"
134151
]

0 commit comments

Comments
 (0)