11import math
22from os import PathLike
3- from typing import Union , Sequence , Callable , Dict
3+ from typing import Union , Sequence , Callable , Dict , List
44
55import matplotlib as mpl
66import numpy as np
1111mpl .rcParams ['pdf.fonttype' ] = 42
1212mpl .rcParams ["font.sans-serif" ] = "Arial"
1313
14+ CMAPS_PRESET = ('Reds' , 'Blues' , 'Purples' , 'Oranges' , 'Greens' , 'Greys' )
15+ COLORS_PRESET = ('r' , 'b' , '#BA55D3' , '#FFA500' , 'g' , '#C0C0C0' )
16+
1417
1518class DotPlot (object ):
1619 DEFAULT_ITEM_HEIGHT = 0.3
@@ -24,6 +27,7 @@ def __init__(self, df_size: pd.DataFrame,
2427 df_circle : Union [pd .DataFrame , None ] = None ,
2528 row_colors : Union [pd .DataFrame , None ] = None ,
2629 col_colors : Union [pd .DataFrame , None ] = None ,
30+ mask_frames : Union [pd .DataFrame , None ] = None
2731 ):
2832 """
2933 Construction a `DotPlot` object from `df_size` and `df_color`
@@ -33,7 +37,7 @@ def __init__(self, df_size: pd.DataFrame,
3337 """
3438 __slots__ = ['size_data' , 'resized_size_data' ,
3539 'color_data' , 'height_item' , 'width_item' ,
36- 'circle_data' , 'resized_circle_data' , 'row_colors' , 'col_colors'
40+ 'circle_data' , 'resized_circle_data' , 'row_colors' , 'col_colors' , 'mask_frames'
3741 ]
3842 if df_color is not None and df_size .shape != df_color .shape :
3943 raise ValueError ('df_size and df_color should have the same dimension' )
@@ -43,6 +47,8 @@ def __init__(self, df_size: pd.DataFrame,
4347 raise ValueError ('row_colors has the wrong shape' )
4448 if col_colors is not None and df_size .shape [1 ] != len (col_colors ):
4549 raise ValueError ('col_colors has the wrong shape' )
50+ if mask_frames is not None and df_size .shape != mask_frames .shape :
51+ raise ValueError ('df_size and mask_frames should have the same dimension' )
4652
4753 self .size_data = df_size
4854 self .color_data = df_color
@@ -52,6 +58,7 @@ def __init__(self, df_size: pd.DataFrame,
5258 self .col_colors = col_colors
5359 self .resized_size_data : Union [pd .DataFrame , None ] = None
5460 self .resized_circle_data : Union [pd .DataFrame , None ] = None
61+ self .mask_frames = mask_frames
5562 if (self .row_colors is None ) and (self .col_colors is None ):
5663 self .col_colors = pd .DataFrame ({'' : ['#FFFFFF' ] * df_size .shape [1 ]},
5764 index = df_size .columns .tolist ())
@@ -101,6 +108,7 @@ def __get_figure(self):
101108 ax_abandon .axis ('off' )
102109 return ax , ax_cbar , ax_sizes , ax_circles , ax_row_bands , ax_col_bands , fig
103110
111+ # TODO update with the newest version of __init__
104112 @classmethod
105113 def parse_from_tidy_data (cls , data_frame : pd .DataFrame , item_key : str , group_key : str , sizes_key : str ,
106114 color_key : Union [None , str ] = None , circle_key : Union [None , str ] = None ,
@@ -158,19 +166,41 @@ def __get_coordinates(self):
158166 Y = sorted (list (range (1 , self .height_item + 1 )) * self .width_item )
159167 return X , Y
160168
161- def __draw_dotplot (self , ax , size_factor , cmap , vmin , vmax , ** kws ):
169+ def __draw_dotplot (self , ax , cmap , vmin , vmax , ** kws ):
162170 dot_color = kws .get ('dot_color' , '#58000C' )
163171 circle_color = kws .get ('circle_color' , '#000000' )
164172 kws = kws .copy ()
165173 for _value in ['dot_title' , 'circle_title' , 'colorbar_title' , 'dot_color' , 'circle_color' ]:
166174 _ = kws .pop (_value , None )
167-
168175 X , Y = self .__get_coordinates ()
169- if self .color_data is None :
170- sct = ax .scatter (X , Y , c = dot_color , s = self .resized_size_data .values .flatten (),
171- edgecolors = 'none' , linewidths = 0 , vmin = vmin , vmax = vmax , cmap = cmap , ** kws )
176+ resized_size_data_array = self .resized_size_data .values .flatten ()
177+ color_data_array_or_str = dot_color if self .color_data is None else self .color_data .values .flatten ()
178+ sct : Union [List [mpl .collections .PathCollection ], mpl .collections .PathCollection ] = []
179+ if self .mask_frames is not None :
180+ masks , n_masks = self .resolve_mask (self .mask_frames )
181+ if isinstance (color_data_array_or_str , np .ndarray ):
182+ if isinstance (cmap , (str , mpl .colors .Colormap )):
183+ cmap = CMAPS_PRESET
184+ if len (cmap ) < n_masks :
185+ raise ValueError ('too many groups to draw with limited color map' )
186+ for _ , (mask , _cmap ) in enumerate (zip (masks , cmap )):
187+ masked_resized_size_data_array = np .ma .masked_array (resized_size_data_array , mask = mask )
188+ _sct = ax .scatter (X , Y , c = color_data_array_or_str , s = masked_resized_size_data_array ,
189+ edgecolors = 'none' , linewidths = 0 , vmin = vmin , vmax = vmax , cmap = _cmap , ** kws )
190+ sct .append (_sct )
191+ else :
192+ if isinstance (dot_color , str ):
193+ dot_color = COLORS_PRESET
194+ if len (dot_color ) < n_masks :
195+ raise ValueError ('too many groups to draw with limited color' )
196+ for _ , (mask , _dot_color ) in enumerate (zip (masks , dot_color )):
197+ masked_resized_size_data_array = np .ma .masked_array (resized_size_data_array , mask = mask )
198+ _sct = ax .scatter (X , Y , c = _dot_color , s = masked_resized_size_data_array ,
199+ edgecolors = 'none' , linewidths = 0 , vmin = vmin , vmax = vmax , ** kws )
200+ sct .append (_sct )
201+
172202 else :
173- sct = ax .scatter (X , Y , c = self . color_data . values . flatten () , s = self . resized_size_data . values . flatten () ,
203+ sct = ax .scatter (X , Y , c = color_data_array_or_str , s = resized_size_data_array ,
174204 edgecolors = 'none' , linewidths = 0 , vmin = vmin , vmax = vmax , cmap = cmap , ** kws )
175205 sct_circle = None
176206 if self .circle_data is not None :
@@ -250,7 +280,6 @@ def __cluster_matrix(self, axis=0, **kwargs):
250280 setattr (self , _obj_attr , _obj )
251281
252282 def __preprocess_data (self , size_factor , cluster_row = False , cluster_col = False , ** kwargs ):
253-
254283 if cluster_row or cluster_col :
255284 if cluster_row :
256285 self .__cluster_matrix (axis = 0 , ** kwargs )
@@ -263,7 +292,8 @@ def __preprocess_data(self, size_factor, cluster_row=False, cluster_col=False, *
263292 def plot (self , size_factor : float = 15 ,
264293 vmin : float = 0 , vmax : float = None ,
265294 path : Union [PathLike , None ] = None ,
266- cmap : Union [str , mpl .colors .Colormap ] = 'Reds' ,
295+ cmap : Union [str , mpl .colors .Colormap ,
296+ Sequence [Union [str , mpl .colors .Colormap ]]] = 'Reds' ,
267297 cluster_row : bool = False , cluster_col : bool = False ,
268298 cluster_kws : Union [Dict , None ] = None ,
269299 color_band_kws : Union [Dict , None ] = None ,
@@ -275,31 +305,37 @@ def plot(self, size_factor: float = 15,
275305 :param vmin: `vmin` in `matplotlib.pyplot.scatter`
276306 :param vmax: `vmax` in `matplotlib.pyplot.scatter`
277307 :param path: path to save the figure
278- :param cmap: color map supported by matplotlib
308+ :param cmap: color map supported by matplotlib, can be sequence of cmap when drawing grouped dotplots
279309 :param cluster_row, whether to cluster the row
280310 :param cluster_col, whether to cluster the col
281311 :param cluster_kws, key args for cluster, including `cluster_method`, `cluster_metric`, 'cluster_n'
282312 :param kwargs: dot_title, circle_title, colorbar_title, dot_color, circle_color
283- other kwargs are passed to `matplotlib.Axes.scatter`
313+ other kwargs are passed to `matplotlib.Axes.scatter`. Notably, dot_color can be a
314+ color sequence when drawing grouped dotplots
284315 :param color_band_kws: this kwargs was passed to `matplotlib.axes.Axes.pcolormesh`
285316 :return:
286317 """
287318 self .__preprocess_data (size_factor , cluster_row = cluster_row , cluster_col = cluster_col ,
288319 ** cluster_kws if cluster_kws is not None else {}
289320 )
290321 ax , ax_cbar , ax_sizes , ax_circles , ax_row_bands , ax_col_bands , fig = self .__get_figure ()
291- scatter , sct_circle = self .__draw_dotplot (ax , size_factor , cmap , vmin , vmax )
292- self .__draw_legend (ax_sizes , scatter , size_factor ,
293- color = kwargs .get ('dot_color' , '#58000C' ), # dot legend color
294- title = kwargs .get ('dot_title' , 'Sizes' ))
322+ scatter , sct_circle = self .__draw_dotplot (ax , cmap , vmin , vmax , ** kwargs )
323+ # todo
324+ if isinstance (scatter , Sequence ):
325+ pass
326+ else :
327+ self .__draw_legend (ax_sizes , scatter , size_factor ,
328+ color = kwargs .get ('dot_color' , '#58000C' ), # dot legend color
329+ title = kwargs .get ('dot_title' , 'Sizes' ))
330+ # todo
331+ if self .color_data is not None :
332+ self .__draw_color_bar (ax_cbar , scatter , cmap , vmin , vmax ,
333+ ylabel = kwargs .get ('colorbar_title' , '-log10(pvalue)' ))
295334 if sct_circle is not None :
296335 self .__draw_legend (ax_circles , sct_circle , size_factor ,
297336 color = kwargs .get ('circle_color' , '#000000' ),
298337 title = kwargs .get ('circle_title' , 'Circles' ),
299338 circle = True )
300- if self .color_data is not None :
301- self .__draw_color_bar (ax_cbar , scatter , cmap , vmin , vmax ,
302- ylabel = kwargs .get ('colorbar_title' , '-log10(pvalue)' ))
303339
304340 if self .col_colors is not None :
305341 from .annotation_bands import draw_heatmap
@@ -311,10 +347,28 @@ def plot(self, size_factor: float = 15,
311347 from .annotation_bands import draw_heatmap
312348 draw_heatmap (self .row_colors , axes = ax_row_bands ,
313349 index_order = self .size_data .index .tolist (), axis = 0 , ** color_band_kws )
314-
315350 if path :
316351 fig .savefig (path , dpi = 300 , bbox_inches = 'tight' )
317- return scatter
352+ return fig
353+
354+ @staticmethod
355+ def resolve_mask (mask_dataframe : pd .DataFrame ):
356+ mask_dataframe = mask_dataframe .applymap (func = lambda x : str (x ))
357+ groups = np .unique (mask_dataframe .to_numpy ())
358+ mappings = dict (zip (groups , [0 ] * len (groups )))
359+ group_masks = []
360+ n_group = 0
361+ for group in groups :
362+ if group == 'nan' :
363+ continue
364+ n_group += 1
365+ _mappings = mappings .copy ()
366+ _mappings .update ({group : 1 })
367+ group_mask = mask_dataframe .applymap (func = lambda x : _mappings [x ])
368+ group_masks .append (group_mask )
369+ if n_group < 2 :
370+ raise ValueError ('group number<2' )
371+ return group_masks , n_group
318372
319373 def __str__ (self ):
320374 return 'DotPlot object with data point in shape %s' % str (self .size_data .shape )
0 commit comments