@@ -12,30 +12,94 @@ def get_id(self):
1212 return "multilens"
1313
1414 def train (self ):
15+ directed = True
16+ base_features = ['row' , 'col' , 'row_col' ]
1517 dim = self .params ['dim' ]
1618 L = self .params ['L' ]
1719 num_buckets = self .params ['base' ]
1820 op = self .params ['operators' ]
21+
22+ dict_id_idx = dict ()
23+ dict_idx_id = dict ()
24+
25+ raw_ = list ()
26+ cur_count_ = 0
27+ for cur_edge in self .graph .edges ():
28+ src = cur_edge [0 ]
29+ dst = cur_edge [1 ]
30+ if src not in dict_id_idx :
31+ dict_id_idx [src ] = cur_count_
32+ dict_idx_id [cur_count_ ] = src
33+ cur_count_ += 1
34+ if dst not in dict_id_idx :
35+ dict_id_idx [dst ] = cur_count_
36+ dict_idx_id [cur_count_ ] = dst
37+ cur_count_ += 1
38+ raw_ += [[dict_id_idx [src ], dict_id_idx [dst ]]]
39+
40+ raw = np .array (raw_ )
41+ COL = raw .shape [1 ]
42+
43+ if COL < 2 :
44+ sys .exit ('[Input format error.]' )
45+ elif COL == 2 :
46+ print ('[unweighted graph detected.]' )
47+ rows = raw [:,0 ]
48+ cols = raw [:,1 ]
49+ weis = np .ones (len (rows ))
50+
51+ elif COL == 3 :
52+ print ('[weighted graph detected.]' )
53+ rows = raw [:,0 ]
54+ cols = raw [:,1 ]
55+ weis = raw [:,2 ]
56+
57+ check_eq = True
58+ max_id = int (max (max (rows ), max (cols )))
59+ num_nodes = max_id + 1
60+
61+ nodes_to_embed = range (int (max_id )+ 1 )
62+
63+ if max (rows ) != max (cols ):
64+ rows = np .append (rows ,max (max (rows ), max (cols )))
65+ cols = np .append (cols ,max (max (rows ), max (cols )))
66+ weis = np .append (weis , 0 )
67+ check_eq = False
68+
69+ adj_matrix = sps .lil_matrix ( sps .csc_matrix ((weis , (rows , cols ))))
70+
71+ CAT_DICT = defaultdict (set )
72+ ID_CAT_DICT = dict ()
73+ for i in range (num_nodes ):
74+ CAT_DICT [1 ].add (i )
75+ ID_CAT_DICT [i ] = 1
76+ unique_cat = [1 ]
77+
1978 ######################################################
2079 # Multi-Lens starts.
2180 ######################################################
81+
2282 g_sums = []
23- rep_method = RepMethod (method = "hetero" , bucket_max_value = 30 ,
24- num_buckets = num_buckets , operators = op , use_total = len (op ))
83+
84+ neighbor_list = construct_neighbor_list (adj_matrix , nodes_to_embed )
85+ neighbor_list_r = construct_neighbor_list (adj_matrix .T , nodes_to_embed )
86+
87+ graph = Graph (adj_matrix = adj_matrix , max_id = max_id , num_nodes = num_nodes , base_features = base_features ,
88+ neighbor_list = neighbor_list , directed = directed , cat_dict = CAT_DICT , id_cat_dict = ID_CAT_DICT , unique_cat = unique_cat , check_eq = check_eq )
89+
90+ rep_method = RepMethod (method = "hetero" , bucket_max_value = 30 , num_buckets = num_buckets , operators = op , use_total = len (op ))
2591
2692 ########################################
2793 # Step 1: get base features
2894 ########################################
29- init_feature_matrix = get_init_features (
30- self .graph , base_features , nodes_to_embed )
31- init_feature_matrix_seq = get_seq_features (
32- self .graph , rep_method , input_dense_matrix = init_feature_matrix , nodes_to_embed = nodes_to_embed )
95+ init_feature_matrix = get_init_features (graph , base_features , nodes_to_embed )
96+ init_feature_matrix_seq = get_seq_features (graph , rep_method , input_dense_matrix = init_feature_matrix , nodes_to_embed = nodes_to_embed )
3397
34- Kis = get_Kis (init_feature_matrix_seq , dim , L )
35- print (Kis )
3698
37- feature_matrix_emb , g_sum = feature_layer_evaluation_embedding (
38- self .graph , rep_method , feature_matrix = init_feature_matrix_seq , k = Kis [0 ])
99+ Kis = get_Kis (init_feature_matrix_seq , dim , L )
100+ # print Kis
101+
102+ feature_matrix_emb , g_sum = feature_layer_evaluation_embedding (graph , rep_method , feature_matrix = init_feature_matrix_seq , k = Kis [0 ])
39103
40104 g_sums .append (g_sum )
41105
@@ -49,20 +113,20 @@ def train(self):
49113 feature_matrix = init_feature_matrix
50114
51115 for i in range (L ):
52- print ('[Current layer] ' + str (i ))
53- print ('[feature_matrix shape] ' + str (feature_matrix .shape ))
116+ print ('[Current layer]' , str (i ))
117+ print ('[feature_matrix shape]' , str (feature_matrix .shape ))
54118
55- feature_matrix_new = search_feature_layer (
56- self .graph , rep_method , base_feature_matrix = feature_matrix )
57- feature_matrix_new_seq = get_seq_features (
58- self .graph , rep_method , input_dense_matrix = feature_matrix_new , nodes_to_embed = nodes_to_embed )
59- feature_matrix_new_emb , g_new_sum = feature_layer_evaluation_embedding (
60- self .graph , rep_method , feature_matrix = feature_matrix_new_seq , k = Kis [i + 1 ])
119+ feature_matrix_new = search_feature_layer (graph , rep_method , base_feature_matrix = feature_matrix )
120+ feature_matrix_new_seq = get_seq_features (graph , rep_method , input_dense_matrix = feature_matrix_new , nodes_to_embed = nodes_to_embed )
121+ feature_matrix_new_emb , g_new_sum = feature_layer_evaluation_embedding (graph , rep_method , feature_matrix = feature_matrix_new_seq , k = Kis [i + 1 ])
61122
62123 feature_matrix = feature_matrix_new
63124 rep_new = feature_matrix_new_emb
64125 rep = np .concatenate ((rep , rep_new ), axis = 1 )
65126
66127 g_sums .append (g_new_sum )
67128
68- self .embeddings = g_sum
129+ N , K = rep .shape
130+ self .embeddings = dict ()
131+ for i in range (N ):
132+ self .embeddings [dict_idx_id [i ]] = rep [i , :].tolist ()
0 commit comments