Skip to content

Commit 100c317

Browse files
committed
update calc of figure size
1 parent 55deb46 commit 100c317

1 file changed

Lines changed: 33 additions & 28 deletions

File tree

dotplot/__init__.py

Lines changed: 33 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,11 @@
1010

1111

1212
class DotPlot(object):
13+
DEFAULT_ITEM_HEIGHT = 0.3
14+
DEFAULT_ITEM_WIDTH = 0.35
15+
DEFAULT_LEGENDS_WIDTH = .5
16+
MIN_FIGURE_HEIGHT = 3
17+
1318
def __init__(self, df_size: pd.DataFrame,
1419
df_color: Union[pd.DataFrame, None] = None,
1520
):
@@ -19,14 +24,31 @@ def __init__(self, df_size: pd.DataFrame,
1924
:param df_size: the DataFrame object represents the scatter size in dotplot
2025
:param df_color: the DataFrame object represents the color in dotplot
2126
"""
22-
__slots__ = ['size_data', 'color_data', 'height', 'width', 'resized_size_data']
27+
__slots__ = ['size_data', 'color_data', 'height_item', 'width_item', 'resized_size_data']
2328
if (df_color is not None) & (df_size.shape != df_color.shape):
2429
raise ValueError('df_size and df_color should have the same dimension')
2530
self.size_data = df_size
2631
self.color_data = df_color
27-
self.height, self.width = df_size.shape
32+
self.height_item, self.width_item = df_size.shape
2833
self.resized_size_data: pd.DataFrame
2934

35+
def __get_figure(self):
36+
_text_max = math.ceil(self.size_data.index.map(len).max() / 15)
37+
mainplot_height = self.height_item * self.DEFAULT_ITEM_HEIGHT
38+
mainplot_width = (
39+
(_text_max + self.width_item) * self.DEFAULT_ITEM_WIDTH
40+
)
41+
figure_height = max([self.MIN_FIGURE_HEIGHT, mainplot_height])
42+
figure_width = mainplot_width + self.DEFAULT_LEGENDS_WIDTH
43+
plt.style.use('seaborn-white')
44+
fig = plt.figure(figsize=(figure_width, figure_height))
45+
gs = gridspec.GridSpec(nrows=2, ncols=2, wspace=0.15, hspace=0.15,
46+
width_ratios=[mainplot_width, self.DEFAULT_LEGENDS_WIDTH])
47+
ax = fig.add_subplot(gs[:, 0])
48+
ax_cbar = fig.add_subplot(gs[1, 1])
49+
ax_legend = fig.add_subplot(gs[0, 1])
50+
return ax, ax_cbar, ax_legend, fig
51+
3052
@classmethod
3153
def parse_from_tidy_data(cls, data_frame: pd.DataFrame, item_key: str, group_key: str, sizes_key: str,
3254
color_key: str, selected_item: Union[None, Sequence] = None,
@@ -50,11 +72,14 @@ class method for conveniently constructing DotPlot from tidy data
5072
:return:
5173
"""
5274
data_frame = data_frame[[item_key, group_key, sizes_key, color_key]]
75+
_original_item_order = data_frame[item_key].tolist()
76+
_original_item_order = _original_item_order[::-1]
5377
if sizes_func is not None:
5478
data_frame[sizes_key] = data_frame[sizes_key].map(sizes_func)
5579
if color_func is not None:
5680
data_frame[color_key] = data_frame[color_key].map(color_func)
5781
data_frame = data_frame.pivot(index=item_key, columns=group_key, values=[color_key, sizes_key])
82+
data_frame = data_frame.loc[_original_item_order, :]
5883
if selected_item is not None:
5984
data_frame = data_frame.loc[selected_item, :]
6085
if selected_group is not None:
@@ -68,26 +93,9 @@ class method for conveniently constructing DotPlot from tidy data
6893
sizes_df.columns = sizes_df.columns.map(lambda x: '_'.join(x.split('_')[1:]))
6994
return cls(color_df, sizes_df)
7095

71-
def __determine_figsize(self, **kwargs):
72-
width_factor = kwargs.get('width_factor', 4)
73-
height_factor = kwargs.get('height_factor', 0.45)
74-
fig_width, fig_height = width_factor * self.width, height_factor * self.height
75-
fig_width = fig_width / 9 * 10
76-
return fig_width, fig_height
77-
78-
def __get_figure_layout(self, **kwargs):
79-
fig_width, fig_height = self.__determine_figsize(**kwargs)
80-
plt.style.use('seaborn-white')
81-
fig = plt.figure(figsize=(fig_width, fig_height))
82-
gs = gridspec.GridSpec(nrows=2, ncols=10, wspace=0.4, hspace=0.1)
83-
ax = fig.add_subplot(gs[:, :-4])
84-
ax_cbar = fig.add_subplot(gs[1, -4:-3])
85-
ax_legend = fig.add_subplot(gs[0, -4:])
86-
return ax, ax_cbar, ax_legend, fig
87-
8896
def __get_coordinates(self, size_factor):
89-
X = list(range(1, self.width + 1)) * self.height
90-
Y = sorted(list(range(1, self.height + 1)) * self.width)
97+
X = list(range(1, self.width_item + 1)) * self.height_item
98+
Y = sorted(list(range(1, self.height_item + 1)) * self.width_item)
9199
self.resized_size_data = self.size_data.applymap(func=lambda x: x * size_factor)
92100
return X, Y
93101

@@ -99,7 +107,7 @@ def __draw_dotplot(self, ax, size_factor, cmap, vmin, vmax):
99107
else:
100108
sct = ax.scatter(X, Y, c=self.color_data.values.flatten(), s=self.resized_size_data.values.flatten(),
101109
edgecolors='none', linewidths=0, vmin=vmin, vmax=vmax, cmap=cmap)
102-
width, height = self.width, self.height
110+
width, height = self.width_item, self.height_item
103111
ax.set_xlim([0.5, width + 0.5])
104112
ax.set_ylim([0.6, height + 0.6])
105113
ax.set_xticks(range(1, width + 1))
@@ -147,25 +155,22 @@ def __draw_legend(ax, sct: mpl.collections.PathCollection, size_factor):
147155
def plot(self, size_factor: float = 15,
148156
vmin: float = 0, vmax: float = None,
149157
path: Union[PathLike, None] = None,
150-
cmap: Union[str, mpl.colors.Colormap] = 'Reds',
151-
**kwargs):
158+
cmap: Union[str, mpl.colors.Colormap] = 'Reds'):
152159
"""
153160
154161
:param size_factor: `size factor` * `value` for the actually representation of scatter size in the final figure
155162
:param vmin: `vmin` in `matplotlib.pyplot.scatter`
156163
:param vmax: `vmax` in `matplotlib.pyplot.scatter`
157164
:param path: path to save the figure
158165
:param cmap: color map supported by matplotlib
159-
:param kwargs:
160166
:return:
161167
"""
162-
ax, ax_cbar, ax_legend, fig = self.__get_figure_layout(**kwargs)
168+
ax, ax_cbar, ax_legend, fig = self.__get_figure()
163169
scatter = self.__draw_dotplot(ax, size_factor, cmap, vmin, vmax)
164170
self.__draw_legend(ax_legend, scatter, size_factor)
165171
self.__draw_color_bar(ax_cbar, scatter, cmap, vmin, vmax)
166-
plt.subplots_adjust(left=0.75)
167172
if path:
168-
fig.savefig(path, dpi=300)
173+
fig.savefig(path, dpi=300, bbox_inches='tight') #
169174
return scatter
170175

171176
def __str__(self):

0 commit comments

Comments
 (0)