1010
1111
1212class DotPlot (object ):
13+ DEFAULT_ITEM_HEIGHT = 0.3
14+ DEFAULT_ITEM_WIDTH = 0.35
15+ DEFAULT_LEGENDS_WIDTH = .5
16+ MIN_FIGURE_HEIGHT = 3
17+
1318 def __init__ (self , df_size : pd .DataFrame ,
1419 df_color : Union [pd .DataFrame , None ] = None ,
1520 ):
@@ -19,14 +24,31 @@ def __init__(self, df_size: pd.DataFrame,
1924 :param df_size: the DataFrame object represents the scatter size in dotplot
2025 :param df_color: the DataFrame object represents the color in dotplot
2126 """
22- __slots__ = ['size_data' , 'color_data' , 'height ' , 'width ' , 'resized_size_data' ]
27+ __slots__ = ['size_data' , 'color_data' , 'height_item ' , 'width_item ' , 'resized_size_data' ]
2328 if (df_color is not None ) & (df_size .shape != df_color .shape ):
2429 raise ValueError ('df_size and df_color should have the same dimension' )
2530 self .size_data = df_size
2631 self .color_data = df_color
27- self .height , self .width = df_size .shape
32+ self .height_item , self .width_item = df_size .shape
2833 self .resized_size_data : pd .DataFrame
2934
35+ def __get_figure (self ):
36+ _text_max = math .ceil (self .size_data .index .map (len ).max () / 15 )
37+ mainplot_height = self .height_item * self .DEFAULT_ITEM_HEIGHT
38+ mainplot_width = (
39+ (_text_max + self .width_item ) * self .DEFAULT_ITEM_WIDTH
40+ )
41+ figure_height = max ([self .MIN_FIGURE_HEIGHT , mainplot_height ])
42+ figure_width = mainplot_width + self .DEFAULT_LEGENDS_WIDTH
43+ plt .style .use ('seaborn-white' )
44+ fig = plt .figure (figsize = (figure_width , figure_height ))
45+ gs = gridspec .GridSpec (nrows = 2 , ncols = 2 , wspace = 0.15 , hspace = 0.15 ,
46+ width_ratios = [mainplot_width , self .DEFAULT_LEGENDS_WIDTH ])
47+ ax = fig .add_subplot (gs [:, 0 ])
48+ ax_cbar = fig .add_subplot (gs [1 , 1 ])
49+ ax_legend = fig .add_subplot (gs [0 , 1 ])
50+ return ax , ax_cbar , ax_legend , fig
51+
3052 @classmethod
3153 def parse_from_tidy_data (cls , data_frame : pd .DataFrame , item_key : str , group_key : str , sizes_key : str ,
3254 color_key : str , selected_item : Union [None , Sequence ] = None ,
@@ -50,11 +72,14 @@ class method for conveniently constructing DotPlot from tidy data
5072 :return:
5173 """
5274 data_frame = data_frame [[item_key , group_key , sizes_key , color_key ]]
75+ _original_item_order = data_frame [item_key ].tolist ()
76+ _original_item_order = _original_item_order [::- 1 ]
5377 if sizes_func is not None :
5478 data_frame [sizes_key ] = data_frame [sizes_key ].map (sizes_func )
5579 if color_func is not None :
5680 data_frame [color_key ] = data_frame [color_key ].map (color_func )
5781 data_frame = data_frame .pivot (index = item_key , columns = group_key , values = [color_key , sizes_key ])
82+ data_frame = data_frame .loc [_original_item_order , :]
5883 if selected_item is not None :
5984 data_frame = data_frame .loc [selected_item , :]
6085 if selected_group is not None :
@@ -68,26 +93,9 @@ class method for conveniently constructing DotPlot from tidy data
6893 sizes_df .columns = sizes_df .columns .map (lambda x : '_' .join (x .split ('_' )[1 :]))
6994 return cls (color_df , sizes_df )
7095
71- def __determine_figsize (self , ** kwargs ):
72- width_factor = kwargs .get ('width_factor' , 4 )
73- height_factor = kwargs .get ('height_factor' , 0.45 )
74- fig_width , fig_height = width_factor * self .width , height_factor * self .height
75- fig_width = fig_width / 9 * 10
76- return fig_width , fig_height
77-
78- def __get_figure_layout (self , ** kwargs ):
79- fig_width , fig_height = self .__determine_figsize (** kwargs )
80- plt .style .use ('seaborn-white' )
81- fig = plt .figure (figsize = (fig_width , fig_height ))
82- gs = gridspec .GridSpec (nrows = 2 , ncols = 10 , wspace = 0.4 , hspace = 0.1 )
83- ax = fig .add_subplot (gs [:, :- 4 ])
84- ax_cbar = fig .add_subplot (gs [1 , - 4 :- 3 ])
85- ax_legend = fig .add_subplot (gs [0 , - 4 :])
86- return ax , ax_cbar , ax_legend , fig
87-
8896 def __get_coordinates (self , size_factor ):
89- X = list (range (1 , self .width + 1 )) * self .height
90- Y = sorted (list (range (1 , self .height + 1 )) * self .width )
97+ X = list (range (1 , self .width_item + 1 )) * self .height_item
98+ Y = sorted (list (range (1 , self .height_item + 1 )) * self .width_item )
9199 self .resized_size_data = self .size_data .applymap (func = lambda x : x * size_factor )
92100 return X , Y
93101
@@ -99,7 +107,7 @@ def __draw_dotplot(self, ax, size_factor, cmap, vmin, vmax):
99107 else :
100108 sct = ax .scatter (X , Y , c = self .color_data .values .flatten (), s = self .resized_size_data .values .flatten (),
101109 edgecolors = 'none' , linewidths = 0 , vmin = vmin , vmax = vmax , cmap = cmap )
102- width , height = self .width , self .height
110+ width , height = self .width_item , self .height_item
103111 ax .set_xlim ([0.5 , width + 0.5 ])
104112 ax .set_ylim ([0.6 , height + 0.6 ])
105113 ax .set_xticks (range (1 , width + 1 ))
@@ -147,25 +155,22 @@ def __draw_legend(ax, sct: mpl.collections.PathCollection, size_factor):
147155 def plot (self , size_factor : float = 15 ,
148156 vmin : float = 0 , vmax : float = None ,
149157 path : Union [PathLike , None ] = None ,
150- cmap : Union [str , mpl .colors .Colormap ] = 'Reds' ,
151- ** kwargs ):
158+ cmap : Union [str , mpl .colors .Colormap ] = 'Reds' ):
152159 """
153160
154161 :param size_factor: `size factor` * `value` for the actually representation of scatter size in the final figure
155162 :param vmin: `vmin` in `matplotlib.pyplot.scatter`
156163 :param vmax: `vmax` in `matplotlib.pyplot.scatter`
157164 :param path: path to save the figure
158165 :param cmap: color map supported by matplotlib
159- :param kwargs:
160166 :return:
161167 """
162- ax , ax_cbar , ax_legend , fig = self .__get_figure_layout ( ** kwargs )
168+ ax , ax_cbar , ax_legend , fig = self .__get_figure ( )
163169 scatter = self .__draw_dotplot (ax , size_factor , cmap , vmin , vmax )
164170 self .__draw_legend (ax_legend , scatter , size_factor )
165171 self .__draw_color_bar (ax_cbar , scatter , cmap , vmin , vmax )
166- plt .subplots_adjust (left = 0.75 )
167172 if path :
168- fig .savefig (path , dpi = 300 )
173+ fig .savefig (path , dpi = 300 , bbox_inches = 'tight' ) #
169174 return scatter
170175
171176 def __str__ (self ):
0 commit comments