88from matplotlib import gridspec
99from matplotlib import pyplot as plt
1010
11+ mpl .rcParams ['pdf.fonttype' ] = 42
12+ mpl .rcParams ["font.sans-serif" ] = "Arial"
13+
1114
1215class DotPlot (object ):
1316 DEFAULT_ITEM_HEIGHT = 0.3
@@ -17,20 +20,26 @@ class DotPlot(object):
1720
1821 def __init__ (self , df_size : pd .DataFrame ,
1922 df_color : Union [pd .DataFrame , None ] = None ,
23+ df_circle : Union [pd .DataFrame , None ] = None ,
2024 ):
2125 """
2226 Construction a `DotPlot` object from `df_size` and `df_color`
2327
2428 :param df_size: the DataFrame object represents the scatter size in dotplot
2529 :param df_color: the DataFrame object represents the color in dotplot
2630 """
27- __slots__ = ['size_data' , 'color_data' , 'height_item' , 'width_item' , 'resized_size_data' ]
31+ __slots__ = ['size_data' , 'resized_size_data' ,
32+ 'color_data' , 'height_item' , 'width_item' ,
33+ 'circle_data' , 'resized_circle_data'
34+ ]
2835 if (df_color is not None ) & (df_size .shape != df_color .shape ):
2936 raise ValueError ('df_size and df_color should have the same dimension' )
3037 self .size_data = df_size
3138 self .color_data = df_color
39+ self .circle_data = df_circle
3240 self .height_item , self .width_item = df_size .shape
3341 self .resized_size_data : pd .DataFrame
42+ self .resized_circle_data : pd .DataFrame
3443
3544 def __get_figure (self ):
3645 _text_max = math .ceil (self .size_data .index .map (len ).max () / 15 )
@@ -46,12 +55,14 @@ def __get_figure(self):
4655 width_ratios = [mainplot_width , self .DEFAULT_LEGENDS_WIDTH ])
4756 ax = fig .add_subplot (gs [:, 0 ])
4857 ax_cbar = fig .add_subplot (gs [2 , 1 ])
49- ax_legend = fig .add_subplot (gs [0 :2 , 1 ])
50- return ax , ax_cbar , ax_legend , fig
58+ ax_sizes = fig .add_subplot (gs [0 , 1 ])
59+ ax_circles = fig .add_subplot (gs [1 , 1 ])
60+ return ax , ax_cbar , ax_sizes , ax_circles , fig
5161
5262 @classmethod
5363 def parse_from_tidy_data (cls , data_frame : pd .DataFrame , item_key : str , group_key : str , sizes_key : str ,
54- color_key : str , selected_item : Union [None , Sequence ] = None ,
64+ color_key : Union [None , str ] = None , circle_key : Union [None , str ] = None ,
65+ selected_item : Union [None , Sequence ] = None ,
5566 selected_group : Union [None , Sequence ] = None , * ,
5667 sizes_func : Union [None , Callable ] = None , color_func : Union [None , Callable ] = None
5768 ):
@@ -69,34 +80,41 @@ class method for conveniently constructing DotPlot from tidy data
6980 :param selected_group: Same as `selected_item`, for group order and subset groups
7081 :param sizes_func:
7182 :param color_func:
83+ :param circle_key:
7284 :return:
7385 """
74- data_frame = data_frame [[item_key , group_key , sizes_key , color_key ]]
86+ keys = [v for v in [item_key , group_key , sizes_key , color_key , circle_key ] if v is not None ]
87+ data_frame = data_frame [keys ]
7588 _original_item_order = data_frame [item_key ].tolist ()
7689 _original_item_order = _original_item_order [::- 1 ]
7790 if sizes_func is not None :
7891 data_frame [sizes_key ] = data_frame [sizes_key ].map (sizes_func )
7992 if color_func is not None :
8093 data_frame [color_key ] = data_frame [color_key ].map (color_func )
81- data_frame = data_frame .pivot (index = item_key , columns = group_key , values = [color_key , sizes_key ])
94+ keys .remove (item_key )
95+ keys .remove (group_key )
96+ data_frame = data_frame .pivot (index = item_key , columns = group_key , values = keys )
8297 data_frame = data_frame .loc [_original_item_order , :]
8398 if selected_item is not None :
8499 data_frame = data_frame .loc [selected_item , :]
85100 if selected_group is not None :
86101 data_frame = data_frame .loc [:, selected_group ]
87-
88102 data_frame .columns = data_frame .columns .map (lambda x : '_' .join (x ))
89103 data_frame = data_frame .fillna (0 )
90- color_df = data_frame .loc [:, data_frame .columns .str .startswith (color_key )]
104+
105+ sizes_df , color_df , circle_df = (None , None , None )
91106 sizes_df = data_frame .loc [:, data_frame .columns .str .startswith (sizes_key )]
92- color_df .columns = color_df .columns .map (lambda x : '_' .join (x .split ('_' )[1 :]))
93- sizes_df .columns = sizes_df .columns .map (lambda x : '_' .join (x .split ('_' )[1 :]))
94- return cls (color_df , sizes_df )
107+ if color_key is not None :
108+ color_df = data_frame .loc [:, data_frame .columns .str .startswith (color_key )]
109+ if circle_key is not None :
110+ circle_df = data_frame .loc [:, data_frame .columns .str .startswith (circle_key )]
111+ return cls (sizes_df , color_df , circle_df )
95112
96113 def __get_coordinates (self , size_factor ):
97114 X = list (range (1 , self .width_item + 1 )) * self .height_item
98115 Y = sorted (list (range (1 , self .height_item + 1 )) * self .width_item )
99116 self .resized_size_data = self .size_data .applymap (func = lambda x : x * size_factor )
117+ self .resized_circle_data = self .circle_data .applymap (func = lambda x : x * size_factor )
100118 return X , Y
101119
102120 def __draw_dotplot (self , ax , size_factor , cmap , vmin , vmax ):
@@ -107,6 +125,10 @@ def __draw_dotplot(self, ax, size_factor, cmap, vmin, vmax):
107125 else :
108126 sct = ax .scatter (X , Y , c = self .color_data .values .flatten (), s = self .resized_size_data .values .flatten (),
109127 edgecolors = 'none' , linewidths = 0 , vmin = vmin , vmax = vmax , cmap = cmap )
128+ sct_circle = None
129+ if self .circle_data is not None :
130+ sct_circle = ax .scatter (X , Y , c = '' , edgecolors = 'k' , marker = 'o' , linestyle = '--' ,
131+ s = self .resized_circle_data .values .flatten ())
110132 width , height = self .width_item , self .height_item
111133 ax .set_xlim ([0.5 , width + 0.5 ])
112134 ax .set_ylim ([0.6 , height + 0.6 ])
@@ -116,13 +138,13 @@ def __draw_dotplot(self, ax, size_factor, cmap, vmin, vmax):
116138 ax .set_yticklabels (self .size_data .index .tolist ())
117139 ax .tick_params (axis = 'y' , length = 5 , labelsize = 15 , direction = 'out' )
118140 ax .tick_params (axis = 'x' , length = 5 , labelsize = 15 , direction = 'out' )
119- return sct
141+ return sct , sct_circle
120142
121143 @staticmethod
122144 def __draw_color_bar (ax , sct : mpl .collections .PathCollection , cmap , vmin , vmax ):
123145 gradient = np .linspace (1 , 0 , 500 )
124146 gradient = gradient [:, np .newaxis ]
125- im = ax .imshow (gradient , aspect = 'auto' , cmap = cmap , origin = 'upper' , extent = [.2 , 0.3 , 0.5 , - 0.5 ])
147+ _ = ax .imshow (gradient , aspect = 'auto' , cmap = cmap , origin = 'upper' , extent = [.2 , 0.3 , 0.5 , - 0.5 ])
126148 ax .set_xticks ([])
127149 ax .set_yticks ([])
128150 ax_cbar2 = ax .twinx ()
@@ -135,16 +157,25 @@ def __draw_color_bar(ax, sct: mpl.collections.PathCollection, cmap, vmin, vmax):
135157 _ = ax_cbar2 .set_ylabel ('-log10(pvalue)' )
136158
137159 @staticmethod
138- def __draw_legend (ax , sct : mpl .collections .PathCollection , size_factor ):
160+ def __draw_legend (ax , sct : mpl .collections .PathCollection , size_factor , title , circle = False , color = None ):
161+ print (id (sct ))
139162 handles , labels = sct .legend_elements (prop = "sizes" , alpha = 1 ,
140163 func = lambda x : x / size_factor ,
141- color = '#58000C' )
164+ color = color
165+ )
142166 if len (handles ) > 3 :
143167 handles = np .asarray (handles )
144168 labels = np .asarray (labels )
145169 handles = handles [[0 , math .ceil (len (handles ) / 2 ), - 1 ]]
146170 labels = labels [[0 , math .ceil (len (labels ) / 2 ), - 1 ]]
147- _ = ax .legend (handles , labels , title = "Sizes" , loc = 'center left' ) # bbox_to_anchor=(0.9, 0., 0.4, 0.4)
171+ if circle :
172+ from matplotlib .lines import Line2D
173+ for i , _item in enumerate (handles ):
174+ xdata , ydata = _item .get_data ()
175+ marker_size = _item .get_markersize ()
176+ handles [i ] = Line2D (xdata , ydata , color = 'white' , marker = '$\u25CC $' ,
177+ markeredgecolor = color , markersize = marker_size )
178+ _ = ax .legend (handles , labels , title = title , loc = 'center left' ) # bbox_to_anchor=(0.9, 0., 0.4, 0.4)
148179 ax .set_xticks ([])
149180 ax .set_yticks ([])
150181 ax .spines ['top' ].set_visible (False )
@@ -165,9 +196,11 @@ def plot(self, size_factor: float = 15,
165196 :param cmap: color map supported by matplotlib
166197 :return:
167198 """
168- ax , ax_cbar , ax_legend , fig = self .__get_figure ()
169- scatter = self .__draw_dotplot (ax , size_factor , cmap , vmin , vmax )
170- self .__draw_legend (ax_legend , scatter , size_factor )
199+ ax , ax_cbar , ax_sizes , ax_circles , fig = self .__get_figure ()
200+ scatter , sct_circle = self .__draw_dotplot (ax , size_factor , cmap , vmin , vmax )
201+ self .__draw_legend (ax_sizes , scatter , size_factor , title = 'Sizes' , color = '#58000C' )
202+ if sct_circle is not None :
203+ self .__draw_legend (ax_circles , sct_circle , size_factor , title = 'Circles' , circle = True , color = 'k' )
171204 self .__draw_color_bar (ax_cbar , scatter , cmap , vmin , vmax )
172205 if path :
173206 fig .savefig (path , dpi = 300 , bbox_inches = 'tight' ) #
0 commit comments