Skip to content

Commit e35eaf5

Browse files
committed
add table display
1 parent 136e39f commit e35eaf5

10 files changed

Lines changed: 12366 additions & 1541 deletions

File tree

.DS_Store

0 Bytes
Binary file not shown.

Example_new.ipynb

Lines changed: 10290 additions & 219 deletions
Large diffs are not rendered by default.

semb/.DS_Store

0 Bytes
Binary file not shown.

semb/evaluations/centrality_correlation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from ..exceptions import UnimplementedException, MethodKeywordUnAllowedException
1+
from ..exceptions import *
22
import networkx as nx
33
import numpy as np
44
import sklearn
@@ -33,7 +33,7 @@ def get_centrality(graph, centrality='degree', **kwargs):
3333
"""
3434

3535
if not isinstance(graph, nx.classes.graph.Graph):
36-
raise InputFormatError("Please input graph as NetworkX.graph object")
36+
raise InputFormatErrorException("Please input graph as NetworkX.graph object")
3737

3838
if centrality == 'degree':
3939
dict_centrality = dict(graph.degree())
@@ -67,7 +67,7 @@ def centrality_correlation(graph, dict_embeddings, centrality='degree', similari
6767
"""
6868

6969
if not isinstance(graph, nx.classes.graph.Graph):
70-
raise InputFormatError("Please input graph as NetworkX.graph object")
70+
raise InputFormatErrorException("Please input graph as NetworkX.graph object")
7171

7272
if centrality == 'degree':
7373
dict_centrality = dict(graph.degree())

semb/evaluations/classification.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from ..exceptions import UnimplementedException, MethodKeywordUnAllowedException
1+
from ..exceptions import *
22

33
import networkx as nx
44
import numpy as np

semb/evaluations/clustering.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from ..exceptions import UnimplementedException, MethodKeywordUnAllowedException
1+
from ..exceptions import *
22

33
import networkx as nx
44
import numpy as np
@@ -19,12 +19,8 @@ def purity_score(y_true, y_pred):
1919
def kmeans_best_result(X, y, n_clusters):
2020
list_purity = list()
2121
list_nmi = list()
22-
2322
kmeans = KMeans(n_clusters=n_clusters, n_init=min(len(y), 1000), init='k-means++').fit(X)
24-
25-
list_purity += [purity_score(y, kmeans.labels_)]
26-
list_nmi += [normalized_mutual_info_score(y, kmeans.labels_)]
27-
return {'purity': list_purity, 'nmi': list_nmi}
23+
return {'purity': purity_score(y, kmeans.labels_), 'nmi': normalized_mutual_info_score(y, kmeans.labels_)}
2824

2925

3026
def perform_clustering(dict_embeddings, dict_labels, **kwargs):

semb/evaluations/utils.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from ..exceptions import UnimplementedException, MethodKeywordUnAllowedException
1+
from ..exceptions import *
22
import networkx as nx
3+
import pandas as pd
34

45
def get_label(input_dir, delimeter = ' ' ,**kwargs):
56
"""
@@ -33,3 +34,56 @@ def get_label(input_dir, delimeter = ' ' ,**kwargs):
3334
for key, val in dict_counter.items():
3435
print(">>> Label", key, 'appears', val, 'times')
3536
return dict_labels
37+
38+
def concatenate_result_pd(list_results):
    """
    Concatenate the results from the clustering / classification tests into one table.

    Arguments:
        list_results {list of tuples} -- [("method_name", dict_result)], where dict_result is the dict
            returned from the perform_clustering() or perform_classification() functions.
            Every dict_result must contain an "overall" entry. The first entry decides the mode:
            if its "overall" contains 'accuracy' the input is treated as classification results
            (each metric holding a 'mean' and a 'std' value), otherwise as clustering results
            (each metric holding a single value).

    Return:
        pd_results -- a pandas table with one row per method. The first column is 'methods';
            classification adds '<metric>_mean' / '<metric>_std' columns for accuracy, f1_macro,
            f1_micro, auc_micro and auc_macro; clustering adds 'purity' and 'nmi' columns.

    Raises:
        InputFormatErrorException -- if the input is empty, an entry is not a
            ("method_name", dict_result) pair, "overall" is missing, or an entry lacks one of
            the metrics required by the detected mode.
    """
    format_message = "Please input the results as list of tuples, i.e. [(\"method_name\", dict_result)]"

    # Perform input checking on the list_results
    if len(list_results) == 0:
        raise InputFormatErrorException("Input length 0!")

    for cur_item in list_results:
        # An entry without any length (e.g. None or a number) is malformed input too,
        # so report it like every other format problem instead of leaking a TypeError.
        try:
            is_pair = len(cur_item) == 2
        except TypeError:
            is_pair = False
        if not is_pair:
            raise InputFormatErrorException(format_message)

        if (not isinstance(cur_item[0], str)) or (not isinstance(cur_item[1], dict)):
            raise InputFormatErrorException(format_message)

        if "overall" not in cur_item[1]:
            raise InputFormatErrorException("Invalid input. Please make sure that the input result is generated from perform_classification() or perform_clustering()")

    list_overall = [cur_item[1]['overall'] for cur_item in list_results]

    pd_results = pd.DataFrame()
    pd_results['methods'] = [cur_item[0] for cur_item in list_results]

    # Perform checking on whether classification or clustering is tested
    if 'accuracy' in list_overall[0]:
        # Classification
        metrics = ['accuracy', 'f1_macro', 'f1_micro', 'auc_micro', 'auc_macro']

        # Validate every metric up front so an incomplete result is reported as a
        # format error rather than as a bare KeyError while the columns are built.
        for overall in list_overall:
            if any(metric not in overall for metric in metrics):
                raise InputFormatErrorException("Invalid input. Please make sure that the input result is generated from perform_classification()")

        for metric in metrics:
            for value in ['mean', 'std']:
                pd_results[metric + '_' + value] = [overall[metric][value] for overall in list_overall]
    else:
        # Clustering
        metrics = ['purity', 'nmi']

        for overall in list_overall:
            if any(metric not in overall for metric in metrics):
                raise InputFormatErrorException("Invalid input. Please make sure that the input result is generated from perform_clustering()")

        for metric in metrics:
            pd_results[metric] = [overall[metric] for overall in list_overall]

    return pd_results
89+

semb/exceptions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,5 @@ class MethodNotExistException(Exception):
1010
class MethodKeywordUnAllowedException(Exception):
1111
pass
1212

13-
class InputFormatError(Exception):
13+
class InputFormatErrorException(Exception):
1414
pass

0 commit comments

Comments
 (0)