Skip to content

Commit b3fcb77

Browse files
committed
fix drne
1 parent 5255f84 commit b3fcb77

6 files changed

Lines changed: 168 additions & 293 deletions

File tree

.DS_Store

0 Bytes
Binary file not shown.

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,3 +151,5 @@ cython_debug/
151151
semb/methods/struc2vec/pickles/
152152
semb/methods/struc2vec/random_walks.txt
153153

154+
# drne intermediate files
155+
semb/methods/drne/eni/

semb/.DS_Store

0 Bytes
Binary file not shown.

semb/methods/drne/eni.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def train(self):
127127
n += 1
128128
end = time.time()
129129
#process = psutil.Process(os.getpid())
130-
print(("epoch: {}/{}, batch: {}/{}, loss: {:.6f}, structure_loss: {:.6f}, orth_loss: {:.6f}, guilded_loss: {:.6f}, time: {:.4f}s").format(epoch, self.params.epochs_to_train, n-1, total_num, total_loss, structure_loss, orth_loss, guilded_loss, end-begin))
130+
print(("epoch: {}/{}, batch: {}/{}, loss: {:.6f}, structure_loss: {:.6f}, orth_loss: {:.6f}, guilded_loss: {:.6f}, time: {:.4f}s").format(epoch, self.params['epochs_to_train'], n-1, total_num, total_loss, structure_loss, orth_loss, guilded_loss, end-begin))
131131
if num % 5 == 0:
132132
summary_str = self.sess.run(self.merged_summary, feed_dict=batch_data)
133133
self.summary_writer.add_summary(summary_str, num)

semb/methods/drne/method.py

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,25 +5,66 @@
55
import time
66
import os
77
import tensorflow as tf
8+
import copy
9+
import networkx as nx
810

11+
folder_eni = os.path.join(os.path.dirname(__file__), "eni/")
912

1013
class Method(BaseMethod):
1114

12-
__PARAMS__ = dict(embedding_path=16, epochs_to_train=20, batch_size=16, learning_rate=0.0025,
15+
__PARAMS__ = dict(embedding_size=128, epochs_to_train=20, batch_size=256, learning_rate=0.0025,
1316
alpha=0.0, lamb=0.5, grad_clip=5.0, k=1, sampling_size=100,
14-
seed=1, index_from_0=True, train_device='cpu', save_path=os.getcwd())
17+
seed=1, index_from_0=True, train_device='cpu', save_path=folder_eni, save_suffix='eni')
1518

1619
def get_id(self):
1720
return "drne"
1821

1922
def train(self):
2023
np.random.seed(int(time.time())
2124
if self.params['seed'] == -1 else self.params['seed'])
25+
26+
self.old_graph = copy.deepcopy(self.graph)
27+
old_edges = [e for e in self.old_graph.edges()]
28+
dict_node_o2n = dict()
29+
list_edges = list()
30+
31+
count = 1
32+
for edge in old_edges:
33+
src = int(edge[0])
34+
dst = int(edge[1])
35+
36+
if src not in dict_node_o2n:
37+
dict_node_o2n[src] = count
38+
count += 1
39+
if dst not in dict_node_o2n:
40+
dict_node_o2n[dst] = count
41+
count += 1
42+
list_edges += [(dict_node_o2n[src], dict_node_o2n[dst])]
43+
44+
graph = [[]]
45+
for _ in range(len(dict_node_o2n)):
46+
graph += [[]]
47+
48+
for cur_edge in list_edges:
49+
src = cur_edge[0]
50+
dst = cur_edge[1]
51+
if dst not in graph[src]:
52+
graph[src] += [dst]
53+
if src not in graph[dst]:
54+
graph[dst] += [src]
55+
56+
self.graph = graph
57+
2258
network.sort_graph_by_degree(self.graph)
2359
config = tf.ConfigProto(allow_soft_placement=True)
2460
config.gpu_options.allow_growth = True
2561
with tf.Graph().as_default(), tf.Session(config=config) as sess, tf.device(self.params['train_device']):
2662
alg = eni(self.graph, self.params, sess)
27-
print("max degree: {}".format(alg.degree_max))
63+
# print("max degree: {}".format(alg.degree_max))
2864
alg.train()
29-
self.embeddings = alg.get_embeddings()
65+
reps = alg.get_embeddings()
66+
67+
dict_n2o = {val: key for key, val in dict_node_o2n.items()}
68+
self.embeddings = dict()
69+
for i in range(0, reps.shape[0]):
70+
self.embeddings[dict_n2o[i+1]] = reps[i].tolist()

0 commit comments

Comments (0)