11import math
2+ import copy
23import numpy as np
34import pandas as pd
45import networkx as nx
@@ -18,17 +19,45 @@ class Method(BaseMethod):
1819
1920 __PARAMS__ = dict (window_size = 5 , walk_number = 10 , walk_length = 80 , sampling = "first" , p = 1.0 , q = 1.0 , \
2021 dim = 128 , down_sampling = 0.001 , alpha = 0.025 , min_alpha = 0.025 , min_count = 1 , workers = 1 , \
21- epochs = 10 , features = 'wl ' , label_iterations = 2 , log_base = 1.5 , graphlet_size = 4 , \
22+ epochs = 10 , features = 'degree ' , label_iterations = 2 , log_base = 1.5 , graphlet_size = 4 , \
2223 quantiles = 5 , motif_compression = 'string' , seed = 42 , factors = 8 , clusters = 50 , beta = 0.01 )
2324
2425 def get_id (self ):
2526 return "role2vec"
2627
2728 def train (self ):
29+ self .old_graph = copy .deepcopy (self .graph )
30+ old_edges = [e for e in self .old_graph .edges ()]
31+ dict_node_o2n = dict ()
32+ list_edges = list ()
33+
34+ count = 0
35+ for edge in old_edges :
36+ src = int (edge [0 ])
37+ dst = int (edge [1 ])
38+
39+ if src not in dict_node_o2n :
40+ dict_node_o2n [src ] = count
41+ count += 1
42+ if dst not in dict_node_o2n :
43+ dict_node_o2n [dst ] = count
44+ count += 1
45+ list_edges += [(dict_node_o2n [src ], dict_node_o2n [dst ])]
46+
47+ G = nx .Graph ()
48+ G .add_edges_from (list_edges )
49+ self .graph = G
50+
2851 self .do_walks ()
2952 self .create_structural_features ()
3053 self .pooled_features = self .create_pooled_features ()
31- self .embedding = self .create_embedding ()
54+ reps = self .create_embedding ()
55+
56+ dict_n2o = {val : key for key , val in dict_node_o2n .items ()}
57+ self .embeddings = dict ()
58+ for i , node in enumerate (self .graph .nodes ()):
59+ self .embeddings [dict_n2o [node ]] = reps [i ].tolist ()
60+
3261
3362 def do_walks (self ):
3463 """
0 commit comments