1818class DotPlot (object ):
1919 DEFAULT_ITEM_HEIGHT = 0.3
2020 DEFAULT_ITEM_WIDTH = 0.3
21- DEFAULT_LEGENDS_WIDTH = .45
22- MIN_FIGURE_HEIGHT = 3
21+ DEFAULT_LEGENDS_WIDTH = .6
22+ MIN_FIGURE_HEIGHT = 3.5
2323 DEFAULT_BAND_ITEM_LENGTH = .2
2424
2525 def __init__ (self , df_size : pd .DataFrame ,
@@ -35,9 +35,9 @@ def __init__(self, df_size: pd.DataFrame,
3535 :param df_size: the DataFrame object represents the scatter size in dotplot
3636 :param df_color: the DataFrame object represents the color in dotplot
3737 """
38- __slots__ = ['size_data' , 'resized_size_data' ,
39- 'color_data ' , 'height_item ' , 'width_item ' ,
40- 'circle_data' , 'resized_circle_data' , 'row_colors' , 'col_colors' , 'mask_frames '
38+ __slots__ = ['size_data' , 'resized_size_data' , 'color_data' , 'height_item' , 'width_item' ,
39+ 'circle_data ' , 'resized_circle_data ' , 'row_colors' , 'col_colors' , 'mask_frames ' ,
40+ 'figure '
4141 ]
4242 if df_color is not None and df_size .shape != df_color .shape :
4343 raise ValueError ('df_size and df_color should have the same dimension' )
@@ -59,6 +59,7 @@ def __init__(self, df_size: pd.DataFrame,
5959 self .resized_size_data : Union [pd .DataFrame , None ] = None
6060 self .resized_circle_data : Union [pd .DataFrame , None ] = None
6161 self .mask_frames = mask_frames
62+ self .figure = None
6263 if (self .row_colors is None ) and (self .col_colors is None ):
6364 self .col_colors = pd .DataFrame ({'' : ['#FFFFFF' ] * df_size .shape [1 ]},
6465 index = df_size .columns .tolist ())
@@ -74,7 +75,8 @@ def __get_figure(self):
7475 (_text_max + self .width_item ) * self .DEFAULT_ITEM_WIDTH
7576 )
7677 figure_height = max ([self .MIN_FIGURE_HEIGHT , mainplot_height ])
77- figure_width = mainplot_width + self .DEFAULT_LEGENDS_WIDTH
78+ n_group = len (np .unique (self .mask_frames .to_numpy ())) if self .mask_frames is not None else 1
79+ figure_width = mainplot_width + self .DEFAULT_LEGENDS_WIDTH * n_group
7880 band_width , band_height = 0. , 0.
7981 if self .row_colors is not None :
8082 band_width = self .DEFAULT_BAND_ITEM_LENGTH * self .row_colors .shape [1 ]
@@ -85,28 +87,26 @@ def __get_figure(self):
8587
8688 plt .style .use ('seaborn-white' )
8789 fig = plt .figure (figsize = (figure_width , figure_height ))
90+ self .figure = fig
8891 gs = gridspec .GridSpec (nrows = 2 , ncols = 3 , wspace = 0.05 , hspace = 0.02 ,
89- width_ratios = [mainplot_width , band_width , self .DEFAULT_LEGENDS_WIDTH ],
92+ width_ratios = [mainplot_width , band_width , self .DEFAULT_LEGENDS_WIDTH * n_group ],
9093 height_ratios = [band_height , mainplot_height ]
9194 )
9295 ax = fig .add_subplot (gs [1 , 0 ])
9396 ax_row_bands = fig .add_subplot (gs [1 , 1 ])
9497 ax_col_bands = fig .add_subplot (gs [0 , 0 ])
9598 ax_abandon = fig .add_subplot (gs [0 , 1 ])
9699 legend_gs = gridspec .GridSpecFromSubplotSpec (3 , 1 , hspace = .1 , subplot_spec = gs [1 , 2 ])
97- ax_sizes = fig . add_subplot ( legend_gs [0 , 0 ])
98- ax_circles = fig . add_subplot ( legend_gs [1 , 0 ])
99- ax_cbar = fig . add_subplot ( legend_gs [2 , 0 ])
100+ gs_sizes_legend = legend_gs [0 , 0 ]
101+ gs_cbar_legend = legend_gs [2 , 0 ]
102+ gs_circles_legend = legend_gs [1 , 0 ]
100103
101- _ , _ = ax_sizes .axis ('off' ), ax_circles .axis ('off' )
102104 if self .col_colors is None :
103105 ax_col_bands .axis ('off' )
104106 if self .row_colors is None :
105107 ax_row_bands .axis ('off' )
106- if self .color_data is None :
107- ax_cbar .axis ('off' )
108108 ax_abandon .axis ('off' )
109- return ax , ax_cbar , ax_sizes , ax_circles , ax_row_bands , ax_col_bands , fig
109+ return ax , gs_cbar_legend , gs_sizes_legend , gs_circles_legend , ax_row_bands , ax_col_bands , fig
110110
111111 # TODO update with the newest version of __init__
112112 @classmethod
@@ -166,19 +166,23 @@ def __get_coordinates(self):
166166 Y = sorted (list (range (1 , self .height_item + 1 )) * self .width_item )
167167 return X , Y
168168
169- def __draw_dotplot (self , ax , cmap , vmin , vmax , ** kws ):
170- dot_color = kws .get ('dot_color' , '#58000C' )
171- circle_color = kws .get ('circle_color' , '#000000' )
172- kws = kws .copy ()
173- for _value in ['dot_title' , 'circle_title' , 'colorbar_title' , 'dot_color' , 'circle_color' ]:
174- _ = kws .pop (_value , None )
169+ def __draw_dotplot (self , ax , cmap , vmin , vmax , size_factor , * , gs_cbar : mpl .gridspec .GridSpec ,
170+ gs_sizes : mpl .gridspec .GridSpec , gs_circles : mpl .gridspec .GridSpec , ** kws ):
175171 X , Y = self .__get_coordinates ()
172+ kws = kws .copy ()
173+ dot_color = kws .pop ('dot_color' , '#58000C' )
174+ dot_title = kws .pop ('dot_title' , 'Sizes' )
175+ circle_color = kws .pop ('circle_color' , '#000000' )
176+ circle_title = kws .pop ('circle_title' , 'Circles' )
177+ colorbar_title = kws .pop ('colorbar_title' , '-log10(Pvalue)' )
178+
176179 resized_size_data_array = self .resized_size_data .values .flatten ()
177180 color_data_array_or_str = dot_color if self .color_data is None else self .color_data .values .flatten ()
178181 sct : Union [List [mpl .collections .PathCollection ], mpl .collections .PathCollection ] = []
179182 if self .mask_frames is not None :
180- masks , n_masks = self .resolve_mask (self .mask_frames )
183+ masks , n_masks , mask_groups = self .resolve_mask (self .mask_frames )
181184 if isinstance (color_data_array_or_str , np .ndarray ):
185+ vmax = np .max (self .color_data .values .flatten ()) if vmax is None else vmax
182186 if isinstance (cmap , (str , mpl .colors .Colormap )):
183187 cmap = CMAPS_PRESET
184188 if len (cmap ) < n_masks :
@@ -188,6 +192,8 @@ def __draw_dotplot(self, ax, cmap, vmin, vmax, **kws):
188192 _sct = ax .scatter (X , Y , c = color_data_array_or_str , s = masked_resized_size_data_array ,
189193 edgecolors = 'none' , linewidths = 0 , vmin = vmin , vmax = vmax , cmap = _cmap , ** kws )
190194 sct .append (_sct )
195+ self .__draw_color_bar (gs_cbar , sct , cmap , vmin , vmax , ylabel = colorbar_title )
196+ self .__draw_legend (gs_sizes , sct [0 ], size_factor , color = dot_color , title = dot_title )
191197 else :
192198 if isinstance (dot_color , str ):
193199 dot_color = COLORS_PRESET
@@ -198,14 +204,19 @@ def __draw_dotplot(self, ax, cmap, vmin, vmax, **kws):
198204 _sct = ax .scatter (X , Y , c = _dot_color , s = masked_resized_size_data_array ,
199205 edgecolors = 'none' , linewidths = 0 , vmin = vmin , vmax = vmax , ** kws )
200206 sct .append (_sct )
201-
207+ self . __draw_legend ( gs_sizes , sct , size_factor , color = dot_color , title = mask_groups )
202208 else :
209+ vmax = np .max (self .color_data .values .flatten ()) if vmax is None else vmax
203210 sct = ax .scatter (X , Y , c = color_data_array_or_str , s = resized_size_data_array ,
204211 edgecolors = 'none' , linewidths = 0 , vmin = vmin , vmax = vmax , cmap = cmap , ** kws )
205- sct_circle = None
212+ if self .color_data is not None :
213+ self .__draw_color_bar (gs_cbar , sct , cmap , vmin , vmax , ylabel = colorbar_title )
214+ self .__draw_legend (gs_sizes , sct , size_factor , color = dot_color , title = dot_title )
206215 if self .circle_data is not None :
207216 sct_circle = ax .scatter (X , Y , c = 'none' , s = self .resized_circle_data .values .flatten (),
208217 edgecolors = circle_color , marker = 'o' , vmin = vmin , vmax = vmax , linestyle = '--' )
218+ self .__draw_legend (gs_circles , sct_circle , size_factor , color = circle_color , title = circle_title ,
219+ circle = True )
209220 width , height = self .width_item , self .height_item
210221 ax .set_xlim ([0.5 , width + 0.5 ])
211222 ax .set_ylim ([0.6 , height + 0.6 ])
@@ -215,49 +226,76 @@ def __draw_dotplot(self, ax, cmap, vmin, vmax, **kws):
215226 ax .set_yticklabels (self .size_data .index .tolist ())
216227 ax .tick_params (axis = 'y' , length = 5 , labelsize = 15 , direction = 'out' )
217228 ax .tick_params (axis = 'x' , length = 5 , labelsize = 15 , direction = 'out' )
218- return sct , sct_circle
219229
def __draw_color_bar(self, gs: mpl.gridspec.SubplotSpec,
                     sct: Union[mpl.collections.PathCollection, Sequence[mpl.collections.PathCollection]],
                     cmap, vmin, vmax, ylabel):
    """
    Draw the colorbar legend inside the subplot slot ``gs`` of ``self.figure``.

    :param gs: the subplot slot (a SubplotSpec, e.g. one cell of a grid spec) reserved for the colorbar(s)
    :param sct: a single scatter collection, or a sequence of them; for a sequence one colorbar per
        collection is drawn side by side
    :param cmap: the colormap of a single collection; for a sequence either one colormap shared by
        every bar or a sequence holding one colormap per collection
    :param vmin: label of the lower end; derived (floored) from the collection's array when None
    :param vmax: label of the upper end; derived (ceiled) from the collection's array when None
    :param ylabel: title of the colorbar; for a sequence only the last bar carries it
    :raises ValueError: if a sequence of colormaps is shorter than the sequence of collections
    """
    def _draw_color_bars_core(axes, path_collection, _cmap, _vmin, _vmax, _ylabel, _show_ticks=True):
        # A vertical gradient running from 1 (top) to 0 (bottom) stands in for the colorbar.
        gradient = np.linspace(1, 0, 500)[:, np.newaxis]
        axes.imshow(gradient, aspect='auto', cmap=_cmap, origin='upper')
        axes.set_xticks([])
        axes.set_yticks([])
        # The twin axis carries the min/max tick labels and the title.
        ax_cbar2 = axes.twinx()
        if _show_ticks:
            if _vmax is None:
                _vmax = math.ceil(path_collection.get_array().max())
            if _vmin is None:
                _vmin = math.floor(path_collection.get_array().min())
            ax_cbar2.set_yticks([0, 1000])
            ax_cbar2.set_yticklabels([_vmin, _vmax])
        else:
            ax_cbar2.set_yticks([])
        ax_cbar2.set_ylabel(_ylabel)

    fig = self.figure
    if isinstance(sct, mpl.collections.PathCollection):
        # Tick labels are always shown for a lone colorbar, even when its title is empty.
        _draw_color_bars_core(fig.add_subplot(gs), sct, cmap, vmin, vmax, ylabel)
        return

    n_bars = len(sct)
    if cmap is None or isinstance(cmap, (str, mpl.colors.Colormap)):
        # A single colormap is shared by every bar instead of being iterated element-wise.
        cmaps = [cmap] * n_bars
    else:
        cmaps = list(cmap)
        if len(cmaps) < n_bars:
            raise ValueError('%d colormaps were given for %d colorbars' % (len(cmaps), n_bars))

    new_gs = gridspec.GridSpecFromSubplotSpec(1, n_bars, wspace=.2, subplot_spec=gs)
    for i, (_sct, _cmap) in enumerate(zip(sct, cmaps)):
        # Only the last bar shows the title and the min/max tick labels.
        is_last = i == n_bars - 1
        _draw_color_bars_core(fig.add_subplot(new_gs[0, i]), _sct, _cmap, vmin, vmax,
                              ylabel if is_last else '', _show_ticks=is_last)

def __draw_legend(self, gs: mpl.gridspec.SubplotSpec,
                  sct: Union[mpl.collections.PathCollection, Sequence[mpl.collections.PathCollection]],
                  size_factor, title, circle=False, color=None):
    """
    Draw the marker-size legend inside the subplot slot ``gs`` of ``self.figure``.

    :param gs: the subplot slot (a SubplotSpec, e.g. one cell of a grid spec) reserved for the legend(s)
    :param sct: a single scatter collection, or a sequence of them; for a sequence one legend per
        collection is drawn side by side
    :param size_factor: marker sizes are divided by this factor to produce the legend labels
    :param title: the legend title; for a sequence either one title shared by all legends or one
        title per collection
    :param circle: if True, the legend handles are redrawn as dotted-circle markers
    :param color: the handle color; for a sequence either one color string shared by all legends
        or one color per collection
    """
    def _draw_legend_core(_ax, _sct, _title, _color):
        handles, labels = _sct.legend_elements(prop="sizes", alpha=1,
                                               func=lambda x: x / size_factor,
                                               color=_color)
        if len(handles) > 3:
            # Keep only the first, a middle and the last entry so the legend stays compact.
            picked = [0, math.ceil(len(handles) / 2), -1]
            handles = [handles[k] for k in picked]
            labels = [labels[k] for k in picked]
        if circle:
            from matplotlib.lines import Line2D
            # Swap each handle for a dotted-circle glyph of the same position and marker size.
            handles = [Line2D(*_item.get_data(), color='white', marker='$\u25CC $',
                              markeredgecolor=_color, markersize=_item.get_markersize())
                       for _item in handles]
        _ax.legend(handles, labels, title=_title, loc='center left', frameon=False)
        _ax.set_xticks([])
        _ax.set_yticks([])
        for side in ['top', 'bottom', 'left', 'right']:
            _ax.spines[side].set_visible(False)

    fig = self.figure
    if isinstance(sct, mpl.collections.PathCollection):
        _draw_legend_core(fig.add_subplot(gs), sct, title, color)
        return

    n_legends = len(sct)
    # A lone string (or None) applies to every legend instead of being zipped character by character.
    colors = [color] * n_legends if color is None or isinstance(color, str) else list(color)
    titles = [title] * n_legends if title is None or isinstance(title, str) else list(title)
    new_gs = gridspec.GridSpecFromSubplotSpec(1, n_legends, wspace=.5, subplot_spec=gs)
    for i, (_sct, _color, _title) in enumerate(zip(sct, colors, titles)):
        ax = fig.add_subplot(new_gs[0, i])
        # Transparent background so neighbouring legends do not mask each other.
        ax.set_facecolor((0, 0, 0, 0))
        _draw_legend_core(ax, _sct, _title, _color)
261299
262300 def __cluster_matrix (self , axis = 0 , ** kwargs ):
263301 from .hierarchical import cluster_hierarchy
@@ -318,25 +356,9 @@ def plot(self, size_factor: float = 15,
318356 self .__preprocess_data (size_factor , cluster_row = cluster_row , cluster_col = cluster_col ,
319357 ** cluster_kws if cluster_kws is not None else {}
320358 )
321- ax , ax_cbar , ax_sizes , ax_circles , ax_row_bands , ax_col_bands , fig = self .__get_figure ()
322- scatter , sct_circle = self .__draw_dotplot (ax , cmap , vmin , vmax , ** kwargs )
323- # todo
324- if isinstance (scatter , Sequence ):
325- pass
326- else :
327- self .__draw_legend (ax_sizes , scatter , size_factor ,
328- color = kwargs .get ('dot_color' , '#58000C' ), # dot legend color
329- title = kwargs .get ('dot_title' , 'Sizes' ))
330- # todo
331- if self .color_data is not None :
332- self .__draw_color_bar (ax_cbar , scatter , cmap , vmin , vmax ,
333- ylabel = kwargs .get ('colorbar_title' , '-log10(pvalue)' ))
334- if sct_circle is not None :
335- self .__draw_legend (ax_circles , sct_circle , size_factor ,
336- color = kwargs .get ('circle_color' , '#000000' ),
337- title = kwargs .get ('circle_title' , 'Circles' ),
338- circle = True )
339-
359+ ax , gs_cbar_legend , gs_sizes_legend , gs_circles_legend , ax_row_bands , ax_col_bands , fig = self .__get_figure ()
360+ self .__draw_dotplot (ax , cmap , vmin , vmax , size_factor , gs_cbar = gs_cbar_legend , gs_sizes = gs_sizes_legend ,
361+ gs_circles = gs_circles_legend , ** kwargs )
340362 if self .col_colors is not None :
341363 from .annotation_bands import draw_heatmap
342364 color_band_kws = {} if color_band_kws is None else color_band_kws
@@ -368,7 +390,7 @@ def resolve_mask(mask_dataframe: pd.DataFrame):
368390 group_masks .append (group_mask )
369391 if n_group < 2 :
370392 raise ValueError ('group number<2' )
371- return group_masks , n_group
393+ return group_masks , n_group , groups
372394
373395 def __str__ (self ):
374396 return 'DotPlot object with data point in shape %s' % str (self .size_data .shape )
0 commit comments