11import math
22from os import PathLike
3- from typing import Union , Sequence , Callable
3+ from typing import Union , Sequence , Callable , Dict
44
55import matplotlib as mpl
66import numpy as np
@@ -128,9 +128,6 @@ class method for conveniently constructing DotPlot from tidy data
128128 def __get_coordinates (self , size_factor ):
129129 X = list (range (1 , self .width_item + 1 )) * self .height_item
130130 Y = sorted (list (range (1 , self .height_item + 1 )) * self .width_item )
131- self .resized_size_data = self .size_data .applymap (func = lambda x : x * size_factor )
132- if self .circle_data is not None :
133- self .resized_circle_data = self .circle_data .applymap (func = lambda x : x * size_factor )
134131 return X , Y
135132
136133 def __draw_dotplot (self , ax , size_factor , cmap , vmin , vmax , ** kws ):
@@ -204,11 +201,39 @@ def __draw_legend(ax, sct: mpl.collections.PathCollection, size_factor, title, c
204201 ax .spines ['left' ].set_visible (False )
205202 ax .spines ['right' ].set_visible (False )
206203
204+ def __preprocess_data (self , size_factor , cluster_row = False , cluster_col = False , ** kwargs ):
205+
206+ method = kwargs .get ('cluster_method' , 'ward' )
207+ metric = kwargs .get ('cluster_metric' , 'eulidean' )
208+ n_clusters = kwargs .get ('cluster_n' , None )
209+
210+ if cluster_row or cluster_col :
211+ from .hierarchical import cluster_hierarchy
212+ if cluster_row :
213+ _index = cluster_hierarchy (self .size_data , axis = 0 , method = method ,
214+ metric = metric , n_clusters = n_clusters )
215+ else :
216+ _index = cluster_hierarchy (self .size_data , axis = 1 , method = method ,
217+ metric = metric , n_clusters = n_clusters )
218+ for item in self .__slots__ :
219+ if hasattr (self , item ):
220+ obj_attr = getattr (self , item )
221+ if isinstance (obj_attr , pd .DataFrame ):
222+ if cluster_row :
223+ obj_attr = obj_attr .loc [_index , :]
224+ if cluster_col :
225+ obj_attr = obj_attr .loc [:, _index ]
226+ setattr (self , item , obj_attr )
227+ self .resized_size_data = self .size_data .applymap (func = lambda x : x * size_factor )
228+ if self .circle_data is not None :
229+ self .resized_circle_data = self .circle_data .applymap (func = lambda x : x * size_factor )
230+
207231 def plot (self , size_factor : float = 15 ,
208232 vmin : float = 0 , vmax : float = None ,
209233 path : Union [PathLike , None ] = None ,
210234 cmap : Union [str , mpl .colors .Colormap ] = 'Reds' ,
211- ** kwargs
235+ cluster_row : bool = False , cluster_col : bool = False ,
236+ cluster_kws : Union [Dict , None ] = None , ** kwargs
212237 ):
213238 """
214239
@@ -219,8 +244,14 @@ def plot(self, size_factor: float = 15,
219244 :param cmap: color map supported by matplotlib
220245 :param kwargs: dot_title, circle_title, colorbar_title, dot_color, circle_color
221246 other kwargs are passed to `matplotlib.Axes.scatter`
247+ :param cluster_row, whether to cluster the row
248+ :param cluster_col, whether to cluster the row
249+ :param cluster_kws, key args for cluster, including `cluster_method`, `cluster_metric`, 'cluster_n'
222250 :return:
223251 """
252+ self .__preprocess_data (size_factor , cluster_row = cluster_row , cluster_col = cluster_col ,
253+ ** cluster_kws if cluster_kws is not None else {}
254+ )
224255 ax , ax_cbar , ax_sizes , ax_circles , fig = self .__get_figure ()
225256 scatter , sct_circle = self .__draw_dotplot (ax , size_factor , cmap , vmin , vmax )
226257 self .__draw_legend (ax_sizes , scatter , size_factor ,
0 commit comments