@@ -48,7 +48,6 @@ def __init__(self, df_size: pd.DataFrame,
4848 self .color_data = df_color
4949 self .circle_data = df_circle
5050 self .height_item , self .width_item = df_size .shape
51- # TODO code logic need to argument
5251 self .row_colors = row_colors
5352 self .col_colors = col_colors
5453 self .resized_size_data : Union [pd .DataFrame , None ] = None
@@ -227,29 +226,33 @@ def __draw_legend(ax, sct: mpl.collections.PathCollection, size_factor, title, c
227226 ax .spines ['left' ].set_visible (False )
228227 ax .spines ['right' ].set_visible (False )
229228
230- def __preprocess_data (self , size_factor , cluster_row = False , cluster_col = False , ** kwargs ):
231-
229+ def __cluster_matrix (self , axis = 0 , ** kwargs ):
230+ from . hierarchical import cluster_hierarchy
232231 method = kwargs .get ('cluster_method' , 'ward' )
233232 metric = kwargs .get ('cluster_metric' , 'euclidean' )
234233 n_clusters = kwargs .get ('cluster_n' , None )
234+ _index = cluster_hierarchy (self .size_data , axis = axis , method = method ,
235+ metric = metric , n_clusters = n_clusters )
236+ obj_data = self .__dict__ .copy ()
237+ for _obj_attr , _obj in obj_data .items ():
238+ if (not _obj_attr .startswith ('__' )) and isinstance (_obj , (pd .DataFrame , pd .Series )):
239+ if _obj_attr in ('row_colors' , 'col_colors' ): # TODO may change the action in the future
240+ continue
241+ if axis == 0 :
242+ _obj = _obj .loc [_index , :]
243+ elif axis == 1 :
244+ _obj = _obj .loc [:, _index ]
245+ else :
246+ raise ValueError ('axis should be 0 or 1.' )
247+ setattr (self , _obj_attr , _obj )
248+
249+ def __preprocess_data (self , size_factor , cluster_row = False , cluster_col = False , ** kwargs ):
235250
236251 if cluster_row or cluster_col :
237- from .hierarchical import cluster_hierarchy
238252 if cluster_row :
239- _index = cluster_hierarchy (self .size_data , axis = 0 , method = method ,
240- metric = metric , n_clusters = n_clusters )
241- else :
242- _index = cluster_hierarchy (self .size_data , axis = 1 , method = method ,
243- metric = metric , n_clusters = n_clusters )
244- obj_data = self .__dict__ .copy ()
245- for _obj_attr , _obj in obj_data .items ():
246- if not _obj_attr .startswith ('__' ):
247- if isinstance (_obj , pd .DataFrame ):
248- if cluster_row :
249- _obj = _obj .loc [_index , :]
250- if cluster_col :
251- _obj = _obj .loc [:, _index ]
252- setattr (self , _obj_attr , _obj )
253+ self .__cluster_matrix (axis = 0 , ** kwargs )
254+ if cluster_col :
255+ self .__cluster_matrix (axis = 1 , ** kwargs )
253256 self .resized_size_data = self .size_data .applymap (func = lambda x : x * size_factor )
254257 if self .circle_data is not None :
255258 self .resized_circle_data = self .circle_data .applymap (func = lambda x : x * size_factor )
0 commit comments