|
1 | | -from ..exceptions import UnimplementedException, MethodKeywordUnAllowedException |
| 1 | +from ..exceptions import * |
2 | 2 | import networkx as nx |
| 3 | +import pandas as pd |
3 | 4 |
|
4 | 5 | def get_label(input_dir, delimeter = ' ' ,**kwargs): |
5 | 6 | """ |
@@ -33,3 +34,56 @@ def get_label(input_dir, delimeter = ' ' ,**kwargs): |
33 | 34 | for key, val in dict_counter.items(): |
34 | 35 | print(">>> Label", key, 'appears', val, 'times') |
35 | 36 | return dict_labels |
| 37 | + |
def concatenate_result_pd(list_results):
    """
    Concatenate the results from the clustering / classification test into one table.

    Arguments:
        list_results {list of tuples} -- [("method_name", dict_result)], where the dict_result is the returned dict
                                         from the perform_clustering() and perform_classification() functions.
                                         Each dict_result must contain an "overall" entry. For classification,
                                         dict_result["overall"][metric] must provide "mean" and "std" for each of
                                         accuracy, f1_macro, f1_micro, auc_micro and auc_macro. For clustering,
                                         dict_result["overall"] must provide "purity" and "nmi".

    Return:
        pd_results -- a pandas table showing the results, one row per method. The first column is "methods";
                      the remaining columns are "<metric>_mean" / "<metric>_std" (classification) or
                      "purity" / "nmi" (clustering).

    Raises:
        InputFormatErrorException -- if list_results is empty, an entry is not a ("method_name", dict_result)
                                     pair, "overall" is missing, or any expected metric cannot be read.
    """
    pair_msg = "Please input the results as list of tuples, i.e. [(\"method_name\", dict_result)]"
    source_msg = "Invalid input. Please make sure that the input result is generated from "

    # Perform input checking on the list_results
    if len(list_results) == 0:
        raise InputFormatErrorException("Input length 0!")

    pairs = []
    for cur_item in list_results:
        # Unpack instead of calling len() / indexing, so that entries which are
        # not two-element sequences are reported as a format error rather than
        # leaking a TypeError / KeyError out of the validation itself.
        try:
            method_name, dict_result = cur_item
        except (TypeError, ValueError):
            raise InputFormatErrorException(pair_msg)

        if (not isinstance(method_name, str)) or (not isinstance(dict_result, dict)):
            raise InputFormatErrorException(pair_msg)

        if "overall" not in dict_result:
            raise InputFormatErrorException(source_msg + "perform_classification() or perform_clustering()")

        pairs.append((method_name, dict_result))

    # The first entry decides whether classification or clustering was tested;
    # every other entry must then provide the same set of metrics.
    try:
        is_classification = 'accuracy' in pairs[0][1]['overall']
    except TypeError:
        # "overall" does not support membership tests, so it cannot hold metrics.
        raise InputFormatErrorException(source_msg + "perform_classification() or perform_clustering()")

    if is_classification:
        mismatch_msg = source_msg + "perform_classification()"
        column_paths = [
            (metric + '_' + value, (metric, value))
            for metric in ['accuracy', 'f1_macro', 'f1_micro', 'auc_micro', 'auc_macro']
            for value in ['mean', 'std']
        ]
    else:
        mismatch_msg = source_msg + "perform_clustering()"
        column_paths = [(metric, (metric,)) for metric in ['purity', 'nmi']]

    columns = {'methods': [method_name for method_name, _ in pairs]}
    for column_name, path in column_paths:
        column_values = []
        for _, dict_result in pairs:
            cur_value = dict_result['overall']
            try:
                for key in path:
                    cur_value = cur_value[key]
            except (KeyError, TypeError, IndexError):
                # A missing or malformed metric means the result did not come
                # from the expected perform_*() function.
                raise InputFormatErrorException(mismatch_msg)
            column_values.append(cur_value)
        columns[column_name] = column_values

    # Dict insertion order fixes the column order: "methods" first, then metrics.
    return pd.DataFrame(columns)
| 89 | + |
0 commit comments