|
5 | 5 | import time |
6 | 6 | import os |
7 | 7 | import tensorflow as tf |
| 8 | +import copy |
| 9 | +import networkx as nx |
8 | 10 |
|
| 11 | +folder_eni = os.path.join(os.path.dirname(__file__), "eni/") |
9 | 12 |
|
10 | 13 | class Method(BaseMethod): |
11 | 14 |
|
12 | | - __PARAMS__ = dict(embedding_path=16, epochs_to_train=20, batch_size=16, learning_rate=0.0025, |
| 15 | + __PARAMS__ = dict(embedding_size=128, epochs_to_train=20, batch_size=256, learning_rate=0.0025, |
13 | 16 | alpha=0.0, lamb=0.5, grad_clip=5.0, k=1, sampling_size=100, |
14 | | - seed=1, index_from_0=True, train_device='cpu', save_path=os.getcwd()) |
| 17 | + seed=1, index_from_0=True, train_device='cpu', save_path=folder_eni, save_suffix='eni') |
15 | 18 |
|
16 | 19 | def get_id(self): |
17 | 20 | return "drne" |
18 | 21 |
|
19 | 22 | def train(self): |
20 | 23 | np.random.seed(int(time.time()) |
21 | 24 | if self.params['seed'] == -1 else self.params['seed']) |
| 25 | + |
| 26 | + self.old_graph = copy.deepcopy(self.graph) |
| 27 | + old_edges = [e for e in self.old_graph.edges()] |
| 28 | + dict_node_o2n = dict() |
| 29 | + list_edges = list() |
| 30 | + |
| 31 | + count = 1 |
| 32 | + for edge in old_edges: |
| 33 | + src = int(edge[0]) |
| 34 | + dst = int(edge[1]) |
| 35 | + |
| 36 | + if src not in dict_node_o2n: |
| 37 | + dict_node_o2n[src] = count |
| 38 | + count += 1 |
| 39 | + if dst not in dict_node_o2n: |
| 40 | + dict_node_o2n[dst] = count |
| 41 | + count += 1 |
| 42 | + list_edges += [(dict_node_o2n[src], dict_node_o2n[dst])] |
| 43 | + |
| 44 | + graph = [[]] |
| 45 | + for _ in range(len(dict_node_o2n)): |
| 46 | + graph += [[]] |
| 47 | + |
| 48 | + for cur_edge in list_edges: |
| 49 | + src = cur_edge[0] |
| 50 | + dst = cur_edge[1] |
| 51 | + if dst not in graph[src]: |
| 52 | + graph[src] += [dst] |
| 53 | + if src not in graph[dst]: |
| 54 | + graph[dst] += [src] |
| 55 | + |
| 56 | + self.graph = graph |
| 57 | + |
22 | 58 | network.sort_graph_by_degree(self.graph) |
23 | 59 | config = tf.ConfigProto(allow_soft_placement=True) |
24 | 60 | config.gpu_options.allow_growth = True |
25 | 61 | with tf.Graph().as_default(), tf.Session(config=config) as sess, tf.device(self.params['train_device']): |
26 | 62 | alg = eni(self.graph, self.params, sess) |
27 | | - print("max degree: {}".format(alg.degree_max)) |
| 63 | + # print("max degree: {}".format(alg.degree_max)) |
28 | 64 | alg.train() |
29 | | - self.embeddings = alg.get_embeddings() |
| 65 | + reps = alg.get_embeddings() |
| 66 | + |
| 67 | + dict_n2o = {val: key for key, val in dict_node_o2n.items()} |
| 68 | + self.embeddings = dict() |
| 69 | + for i in range(0, reps.shape[0]): |
| 70 | + self.embeddings[dict_n2o[i+1]] = reps[i].tolist() |
0 commit comments