Skip to content

Commit fb22fc8

Browse files
committed
add visualization for classification and clustering
1 parent 9702a21 commit fb22fc8

5 files changed

Lines changed: 1721 additions & 1615 deletions

File tree

.DS_Store

0 Bytes
Binary file not shown.

semb/.DS_Store

0 Bytes
Binary file not shown.

semb/evaluations/visualization.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
from ..exceptions import *
2+
import networkx as nx
3+
import pandas as pd
4+
import matplotlib.pyplot as plt
5+
6+
def visualize_classification(list_results, metric="f1_macro", error_bar=True, rotation_deg=45):
    """
    Visualize the result for classification using bar plot.

    Arguments:
        list_results {list of tuples} -- [("method_name", dict_result)], where the dict_result is the returned dict
            from the perform_classification() function
        metric -- The metric to use for making the plot. Please choose from ["accuracy", "f1_macro", "f1_micro", "auc_micro", "auc_macro"]
        error_bar -- bool. Whether to include the error bar in the plot.
        rotation_deg -- int. The rotation degree for the xtick labels

    Raises:
        MethodKeywordUnAllowedException -- if metric is not one of the supported names.
        InputFormatErrorException -- if list_results is empty, an entry is not a
            ("method_name", dict_result) pair, or a dict_result does not provide
            ['overall'][metric]['mean'] (and ['std'] when error_bar is True).
    """
    allowed_metrics = ("accuracy", "f1_macro", "f1_micro", "auc_micro", "auc_macro")
    pair_format_msg = "Please input the results as list of tuples, i.e. [(\"method_name\", dict_result)]"
    source_msg = ("Invalid input. Please make sure that the input result is generated "
                  "from perform_classification() or perform_clustering()")

    # Check that metric belongs to the classification results
    if metric not in allowed_metrics:
        raise MethodKeywordUnAllowedException(
            "Please choose metric from [" + ", ".join(allowed_metrics) + "].")

    if len(list_results) == 0:
        raise InputFormatErrorException("Input length 0!")

    list_methods = []
    list_evaluation = []
    list_error = []
    for cur_item in list_results:
        if len(cur_item) != 2:
            raise InputFormatErrorException(pair_format_msg)

        method_name, dict_result = cur_item[0], cur_item[1]
        if (not isinstance(method_name, str)) or (not isinstance(dict_result, dict)):
            raise InputFormatErrorException(pair_format_msg)

        if "overall" not in dict_result:
            raise InputFormatErrorException(source_msg)

        # Report a malformed result dict through the module's own exception
        # instead of leaking a bare KeyError/TypeError from the lookups below.
        try:
            metric_stats = dict_result["overall"][metric]
            mean_value = metric_stats["mean"]
            # The spread is only needed when error bars are actually drawn.
            std_value = metric_stats["std"] if error_bar else None
        except (KeyError, IndexError, TypeError):
            raise InputFormatErrorException(source_msg)

        list_methods.append(method_name)
        list_evaluation.append(mean_value)
        list_error.append(std_value)

    # Discard the current pyplot figure (if any) and draw on a fresh one.
    plt.close()
    plt.figure(figsize=(8, 2.5), dpi=300)
    plt.style.use('ggplot')
    # yerr=None is plt.bar's default and draws no error bars; capsize only
    # affects the error-bar caps.
    plt.bar(list_methods,
            list_evaluation,
            yerr=list_error if error_bar else None,
            capsize=10)
    plt.ylim(0, 1)
    plt.ylabel(metric)
    plt.xlabel("Methods")
    plt.xticks(rotation=rotation_deg)
    plt.show()
56+
57+
58+
def visualize_clustering(list_results, metric="purity", rotation_deg=45):
    """
    Visualize the result for clustering using bar plot.

    Arguments:
        list_results {list of tuples} -- [("method_name", dict_result)], where the dict_result is the returned dict
            from the perform_clustering() function
        metric -- The metric to use for making the plot. Please choose from ["purity", "nmi"]
        rotation_deg -- int. The rotation degree for the xtick labels

    Raises:
        MethodKeywordUnAllowedException -- if metric is not one of the supported names.
        InputFormatErrorException -- if list_results is empty, an entry is not a
            ("method_name", dict_result) pair, or a dict_result does not provide
            ['overall'][metric].
    """
    pair_format_msg = "Please input the results as list of tuples, i.e. [(\"method_name\", dict_result)]"
    source_msg = ("Invalid input. Please make sure that the input result is generated "
                  "from perform_classification() or perform_clustering()")

    # Check that metric belongs to the clustering results
    if metric not in ("purity", "nmi"):
        raise MethodKeywordUnAllowedException("Please choose metric from [purity, nmi].")

    if len(list_results) == 0:
        raise InputFormatErrorException("Input length 0!")

    list_methods = []
    list_evaluation = []
    for cur_item in list_results:
        if len(cur_item) != 2:
            raise InputFormatErrorException(pair_format_msg)

        method_name, dict_result = cur_item[0], cur_item[1]
        if (not isinstance(method_name, str)) or (not isinstance(dict_result, dict)):
            raise InputFormatErrorException(pair_format_msg)

        if "overall" not in dict_result:
            raise InputFormatErrorException(source_msg)

        # Report a missing metric through the module's own exception instead
        # of leaking a bare KeyError/TypeError from the lookup below.
        try:
            metric_value = dict_result["overall"][metric]
        except (KeyError, IndexError, TypeError):
            raise InputFormatErrorException(source_msg)

        list_methods.append(method_name)
        list_evaluation.append(metric_value)

    # Discard the current pyplot figure (if any) and draw on a fresh one.
    plt.close()
    plt.figure(figsize=(8, 2.5), dpi=300)
    plt.style.use('ggplot')

    # No error bars are drawn here, so no capsize is passed.
    plt.bar(list_methods, list_evaluation)
    plt.ylim(0, 1)
    plt.ylabel(metric)
    plt.xlabel("Methods")
    plt.xticks(rotation=rotation_deg)
    plt.show()
101+

0 commit comments

Comments
 (0)