@@ -17,13 +17,13 @@ class DotPlot(object):
1717 DEFAULT_ITEM_WIDTH = 0.3
1818 DEFAULT_LEGENDS_WIDTH = .45
1919 MIN_FIGURE_HEIGHT = 3
20- DEFAULT_BAND_ITEM_LENGTH = DEFAULT_ITEM_HEIGHT
20+ DEFAULT_BAND_ITEM_LENGTH = .2
2121
22- # TODO implement annotation band
2322 def __init__ (self , df_size : pd .DataFrame ,
2423 df_color : Union [pd .DataFrame , None ] = None ,
2524 df_circle : Union [pd .DataFrame , None ] = None ,
26- df_annotation : Union [pd .DataFrame , None ] = None ,
25+ row_colors : Union [pd .DataFrame , None ] = None ,
26+ col_colors : Union [pd .DataFrame , None ] = None ,
2727 ):
2828 """
2929 Construction a `DotPlot` object from `df_size` and `df_color`
@@ -33,46 +33,71 @@ def __init__(self, df_size: pd.DataFrame,
3333 """
3434 __slots__ = ['size_data' , 'resized_size_data' ,
3535 'color_data' , 'height_item' , 'width_item' ,
36- 'circle_data' , 'resized_circle_data' , 'annotation_data '
36+ 'circle_data' , 'resized_circle_data' , 'row_colors' , 'col_colors '
3737 ]
3838 if df_color is not None and df_size .shape != df_color .shape :
3939 raise ValueError ('df_size and df_color should have the same dimension' )
4040 if df_circle is not None and df_size .shape != df_circle .shape :
4141 raise ValueError ('df_size and df_circle should have the same dimension' )
42- if df_annotation is not None and df_size .shape != df_annotation .shape :
43- raise ValueError ('df_size and df_annotation should have the same row number' )
42+ if row_colors is not None and df_size .shape [0 ] != len (row_colors ):
43+ raise ValueError ('row_colors has the wrong shape' )
44+ if col_colors is not None and df_size .shape [1 ] != len (col_colors ):
45+ raise ValueError ('col_colors has the wrong shape' )
4446
4547 self .size_data = df_size
4648 self .color_data = df_color
4749 self .circle_data = df_circle
4850 self .height_item , self .width_item = df_size .shape
49- self .annotation_data = df_annotation
51+ # TODO code logic need to argument
52+ self .row_colors = row_colors
53+ self .col_colors = col_colors
5054 self .resized_size_data : Union [pd .DataFrame , None ] = None
5155 self .resized_circle_data : Union [pd .DataFrame , None ] = None
5256
5357 def __get_figure (self ):
58+ """
59+ Figure layout
60+ :return:
61+ """
5462 _text_max = math .ceil (self .size_data .index .map (len ).max () / 15 )
5563 mainplot_height = self .height_item * self .DEFAULT_ITEM_HEIGHT
5664 mainplot_width = (
5765 (_text_max + self .width_item ) * self .DEFAULT_ITEM_WIDTH
5866 )
5967 figure_height = max ([self .MIN_FIGURE_HEIGHT , mainplot_height ])
6068 figure_width = mainplot_width + self .DEFAULT_LEGENDS_WIDTH
61- if self .annotation_data is not None :
62- # figure_width = figure_width + self.DEFAULT_BAND_ITEM_LENGTH * self.annotation_data.shape[1]
63- ...
69+ band_width , band_height = 0. , 0.
70+ if self .row_colors is not None :
71+ band_width = self .DEFAULT_BAND_ITEM_LENGTH * self .row_colors .shape [1 ]
72+ if self .col_colors is not None :
73+ band_height = self .DEFAULT_BAND_ITEM_LENGTH * self .col_colors .shape [1 ]
74+ figure_width = figure_width + band_width
75+ figure_height = figure_height + band_height
76+
6477 plt .style .use ('seaborn-white' )
6578 fig = plt .figure (figsize = (figure_width , figure_height ))
66- gs = gridspec .GridSpec (nrows = 3 , ncols = 2 , wspace = 0.15 , hspace = 0.15 ,
67- width_ratios = [mainplot_width , self .DEFAULT_LEGENDS_WIDTH ])
68- ax = fig .add_subplot (gs [:, 0 ])
69- ax_cbar = fig .add_subplot (gs [2 , 1 ])
70- ax_sizes = fig .add_subplot (gs [0 , 1 ])
71- ax_circles = fig .add_subplot (gs [1 , 1 ])
79+ gs = gridspec .GridSpec (nrows = 2 , ncols = 3 , wspace = 0.05 , hspace = 0.02 ,
80+ width_ratios = [mainplot_width , band_width , self .DEFAULT_LEGENDS_WIDTH ],
81+ height_ratios = [band_height , mainplot_height ]
82+ )
83+ ax = fig .add_subplot (gs [1 , 0 ])
84+ ax_row_bands = fig .add_subplot (gs [1 , 1 ])
85+ ax_col_bands = fig .add_subplot (gs [0 , 0 ])
86+ ax_abandon = fig .add_subplot (gs [0 , 1 ])
87+ legend_gs = gridspec .GridSpecFromSubplotSpec (3 , 1 , hspace = .1 , subplot_spec = gs [1 , 2 ])
88+ ax_sizes = fig .add_subplot (legend_gs [0 , 0 ])
89+ ax_circles = fig .add_subplot (legend_gs [1 , 0 ])
90+ ax_cbar = fig .add_subplot (legend_gs [2 , 0 ])
91+
92+ _ , _ = ax_sizes .axis ('off' ), ax_circles .axis ('off' )
93+ if self .col_colors is None :
94+ ax_col_bands .axis ('off' )
95+ if self .row_colors is None :
96+ ax_row_bands .axis ('off' )
7297 if self .color_data is None :
7398 ax_cbar .axis ('off' )
74- ax_circles .axis ('off' )
75- return ax , ax_cbar , ax_sizes , ax_circles , fig
99+ ax_abandon .axis ('off' )
100+ return ax , ax_cbar , ax_sizes , ax_circles , ax_row_bands , ax_col_bands , fig
76101
77102 @classmethod
78103 def parse_from_tidy_data (cls , data_frame : pd .DataFrame , item_key : str , group_key : str , sizes_key : str ,
@@ -126,7 +151,7 @@ class method for conveniently constructing DotPlot from tidy data
126151 circle_df = data_frame .loc [:, data_frame .columns .str .startswith (circle_key )]
127152 return cls (sizes_df , color_df , circle_df )
128153
129- def __get_coordinates (self , size_factor ):
154+ def __get_coordinates (self ):
130155 X = list (range (1 , self .width_item + 1 )) * self .height_item
131156 Y = sorted (list (range (1 , self .height_item + 1 )) * self .width_item )
132157 return X , Y
@@ -138,7 +163,7 @@ def __draw_dotplot(self, ax, size_factor, cmap, vmin, vmax, **kws):
138163 for _value in ['dot_title' , 'circle_title' , 'colorbar_title' , 'dot_color' , 'circle_color' ]:
139164 _ = kws .pop (_value , None )
140165
141- X , Y = self .__get_coordinates (size_factor )
166+ X , Y = self .__get_coordinates ()
142167 if self .color_data is None :
143168 sct = ax .scatter (X , Y , c = dot_color , s = self .resized_size_data .values .flatten (),
144169 edgecolors = 'none' , linewidths = 0 , vmin = vmin , vmax = vmax , cmap = cmap , ** kws )
@@ -234,7 +259,9 @@ def plot(self, size_factor: float = 15,
234259 path : Union [PathLike , None ] = None ,
235260 cmap : Union [str , mpl .colors .Colormap ] = 'Reds' ,
236261 cluster_row : bool = False , cluster_col : bool = False ,
237- cluster_kws : Union [Dict , None ] = None , ** kwargs
262+ cluster_kws : Union [Dict , None ] = None ,
263+ color_band_kws : Union [Dict , None ] = None ,
264+ ** kwargs
238265 ):
239266 """
240267
@@ -248,12 +275,13 @@ def plot(self, size_factor: float = 15,
248275 :param cluster_kws, key args for cluster, including `cluster_method`, `cluster_metric`, 'cluster_n'
249276 :param kwargs: dot_title, circle_title, colorbar_title, dot_color, circle_color
250277 other kwargs are passed to `matplotlib.Axes.scatter`
278+ :param color_band_kws: this kwargs was passed to `matplotlib.axes.Axes.pcolormesh`
251279 :return:
252280 """
253281 self .__preprocess_data (size_factor , cluster_row = cluster_row , cluster_col = cluster_col ,
254282 ** cluster_kws if cluster_kws is not None else {}
255283 )
256- ax , ax_cbar , ax_sizes , ax_circles , fig = self .__get_figure ()
284+ ax , ax_cbar , ax_sizes , ax_circles , ax_row_bands , ax_col_bands , fig = self .__get_figure ()
257285 scatter , sct_circle = self .__draw_dotplot (ax , size_factor , cmap , vmin , vmax )
258286 self .__draw_legend (ax_sizes , scatter , size_factor ,
259287 color = kwargs .get ('dot_color' , '#58000C' ), # dot legend color
@@ -266,8 +294,20 @@ def plot(self, size_factor: float = 15,
266294 if self .color_data is not None :
267295 self .__draw_color_bar (ax_cbar , scatter , cmap , vmin , vmax ,
268296 ylabel = kwargs .get ('colorbar_title' , '-log10(pvalue)' ))
297+
298+ if self .col_colors is not None :
299+ from .annotation_bands import draw_heatmap
300+ color_band_kws = {} if color_band_kws is None else color_band_kws
301+ draw_heatmap (self .col_colors , axes = ax_col_bands ,
302+ index_order = self .size_data .columns .tolist (), axis = 1 , ** color_band_kws )
303+ if self .row_colors is not None :
304+ color_band_kws = {} if color_band_kws is None else color_band_kws
305+ from .annotation_bands import draw_heatmap
306+ draw_heatmap (self .row_colors , axes = ax_row_bands ,
307+ index_order = self .size_data .index .tolist (), axis = 0 , ** color_band_kws )
308+
269309 if path :
270- fig .savefig (path , dpi = 300 , bbox_inches = 'tight' ) #
310+ fig .savefig (path , dpi = 300 , bbox_inches = 'tight' )
271311 return scatter
272312
273313 def __str__ (self ):
0 commit comments