Skip to content

Commit 80cb70a

Browse files
committed
fix the bugs after adding functionality for #9, and rearrange the codes, fixed #9
1 parent 9146fae commit 80cb70a

2 files changed

Lines changed: 217 additions & 170 deletions

File tree

dotplot/core.py

Lines changed: 106 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
class DotPlot(object):
1919
DEFAULT_ITEM_HEIGHT = 0.3
2020
DEFAULT_ITEM_WIDTH = 0.3
21-
DEFAULT_LEGENDS_WIDTH = .45
22-
MIN_FIGURE_HEIGHT = 3
21+
DEFAULT_LEGENDS_WIDTH = .6
22+
MIN_FIGURE_HEIGHT = 3.5
2323
DEFAULT_BAND_ITEM_LENGTH = .2
2424

2525
def __init__(self, df_size: pd.DataFrame,
@@ -35,9 +35,9 @@ def __init__(self, df_size: pd.DataFrame,
3535
:param df_size: the DataFrame object represents the scatter size in dotplot
3636
:param df_color: the DataFrame object represents the color in dotplot
3737
"""
38-
__slots__ = ['size_data', 'resized_size_data',
39-
'color_data', 'height_item', 'width_item',
40-
'circle_data', 'resized_circle_data', 'row_colors', 'col_colors', 'mask_frames'
38+
__slots__ = ['size_data', 'resized_size_data', 'color_data', 'height_item', 'width_item',
39+
'circle_data', 'resized_circle_data', 'row_colors', 'col_colors', 'mask_frames',
40+
'figure'
4141
]
4242
if df_color is not None and df_size.shape != df_color.shape:
4343
raise ValueError('df_size and df_color should have the same dimension')
@@ -59,6 +59,7 @@ def __init__(self, df_size: pd.DataFrame,
5959
self.resized_size_data: Union[pd.DataFrame, None] = None
6060
self.resized_circle_data: Union[pd.DataFrame, None] = None
6161
self.mask_frames = mask_frames
62+
self.figure = None
6263
if (self.row_colors is None) and (self.col_colors is None):
6364
self.col_colors = pd.DataFrame({'': ['#FFFFFF'] * df_size.shape[1]},
6465
index=df_size.columns.tolist())
@@ -74,7 +75,8 @@ def __get_figure(self):
7475
(_text_max + self.width_item) * self.DEFAULT_ITEM_WIDTH
7576
)
7677
figure_height = max([self.MIN_FIGURE_HEIGHT, mainplot_height])
77-
figure_width = mainplot_width + self.DEFAULT_LEGENDS_WIDTH
78+
n_group = len(np.unique(self.mask_frames.to_numpy())) if self.mask_frames is not None else 1
79+
figure_width = mainplot_width + self.DEFAULT_LEGENDS_WIDTH * n_group
7880
band_width, band_height = 0., 0.
7981
if self.row_colors is not None:
8082
band_width = self.DEFAULT_BAND_ITEM_LENGTH * self.row_colors.shape[1]
@@ -85,28 +87,26 @@ def __get_figure(self):
8587

8688
plt.style.use('seaborn-white')
8789
fig = plt.figure(figsize=(figure_width, figure_height))
90+
self.figure = fig
8891
gs = gridspec.GridSpec(nrows=2, ncols=3, wspace=0.05, hspace=0.02,
89-
width_ratios=[mainplot_width, band_width, self.DEFAULT_LEGENDS_WIDTH],
92+
width_ratios=[mainplot_width, band_width, self.DEFAULT_LEGENDS_WIDTH * n_group],
9093
height_ratios=[band_height, mainplot_height]
9194
)
9295
ax = fig.add_subplot(gs[1, 0])
9396
ax_row_bands = fig.add_subplot(gs[1, 1])
9497
ax_col_bands = fig.add_subplot(gs[0, 0])
9598
ax_abandon = fig.add_subplot(gs[0, 1])
9699
legend_gs = gridspec.GridSpecFromSubplotSpec(3, 1, hspace=.1, subplot_spec=gs[1, 2])
97-
ax_sizes = fig.add_subplot(legend_gs[0, 0])
98-
ax_circles = fig.add_subplot(legend_gs[1, 0])
99-
ax_cbar = fig.add_subplot(legend_gs[2, 0])
100+
gs_sizes_legend = legend_gs[0, 0]
101+
gs_cbar_legend = legend_gs[2, 0]
102+
gs_circles_legend = legend_gs[1, 0]
100103

101-
_, _ = ax_sizes.axis('off'), ax_circles.axis('off')
102104
if self.col_colors is None:
103105
ax_col_bands.axis('off')
104106
if self.row_colors is None:
105107
ax_row_bands.axis('off')
106-
if self.color_data is None:
107-
ax_cbar.axis('off')
108108
ax_abandon.axis('off')
109-
return ax, ax_cbar, ax_sizes, ax_circles, ax_row_bands, ax_col_bands, fig
109+
return ax, gs_cbar_legend, gs_sizes_legend, gs_circles_legend, ax_row_bands, ax_col_bands, fig
110110

111111
# TODO update with the newest version of __init__
112112
@classmethod
@@ -166,19 +166,23 @@ def __get_coordinates(self):
166166
Y = sorted(list(range(1, self.height_item + 1)) * self.width_item)
167167
return X, Y
168168

169-
def __draw_dotplot(self, ax, cmap, vmin, vmax, **kws):
170-
dot_color = kws.get('dot_color', '#58000C')
171-
circle_color = kws.get('circle_color', '#000000')
172-
kws = kws.copy()
173-
for _value in ['dot_title', 'circle_title', 'colorbar_title', 'dot_color', 'circle_color']:
174-
_ = kws.pop(_value, None)
169+
def __draw_dotplot(self, ax, cmap, vmin, vmax, size_factor, *, gs_cbar: mpl.gridspec.GridSpec,
170+
gs_sizes: mpl.gridspec.GridSpec, gs_circles: mpl.gridspec.GridSpec, **kws):
175171
X, Y = self.__get_coordinates()
172+
kws = kws.copy()
173+
dot_color = kws.pop('dot_color', '#58000C')
174+
dot_title = kws.pop('dot_title', 'Sizes')
175+
circle_color = kws.pop('circle_color', '#000000')
176+
circle_title = kws.pop('circle_title', 'Circles')
177+
colorbar_title = kws.pop('colorbar_title', '-log10(Pvalue)')
178+
176179
resized_size_data_array = self.resized_size_data.values.flatten()
177180
color_data_array_or_str = dot_color if self.color_data is None else self.color_data.values.flatten()
178181
sct: Union[List[mpl.collections.PathCollection], mpl.collections.PathCollection] = []
179182
if self.mask_frames is not None:
180-
masks, n_masks = self.resolve_mask(self.mask_frames)
183+
masks, n_masks, mask_groups = self.resolve_mask(self.mask_frames)
181184
if isinstance(color_data_array_or_str, np.ndarray):
185+
vmax = np.max(self.color_data.values.flatten()) if vmax is None else vmax
182186
if isinstance(cmap, (str, mpl.colors.Colormap)):
183187
cmap = CMAPS_PRESET
184188
if len(cmap) < n_masks:
@@ -188,6 +192,8 @@ def __draw_dotplot(self, ax, cmap, vmin, vmax, **kws):
188192
_sct = ax.scatter(X, Y, c=color_data_array_or_str, s=masked_resized_size_data_array,
189193
edgecolors='none', linewidths=0, vmin=vmin, vmax=vmax, cmap=_cmap, **kws)
190194
sct.append(_sct)
195+
self.__draw_color_bar(gs_cbar, sct, cmap, vmin, vmax, ylabel=colorbar_title)
196+
self.__draw_legend(gs_sizes, sct[0], size_factor, color=dot_color, title=dot_title)
191197
else:
192198
if isinstance(dot_color, str):
193199
dot_color = COLORS_PRESET
@@ -198,14 +204,19 @@ def __draw_dotplot(self, ax, cmap, vmin, vmax, **kws):
198204
_sct = ax.scatter(X, Y, c=_dot_color, s=masked_resized_size_data_array,
199205
edgecolors='none', linewidths=0, vmin=vmin, vmax=vmax, **kws)
200206
sct.append(_sct)
201-
207+
self.__draw_legend(gs_sizes, sct, size_factor, color=dot_color, title=mask_groups)
202208
else:
209+
vmax = np.max(self.color_data.values.flatten()) if vmax is None else vmax
203210
sct = ax.scatter(X, Y, c=color_data_array_or_str, s=resized_size_data_array,
204211
edgecolors='none', linewidths=0, vmin=vmin, vmax=vmax, cmap=cmap, **kws)
205-
sct_circle = None
212+
if self.color_data is not None:
213+
self.__draw_color_bar(gs_cbar, sct, cmap, vmin, vmax, ylabel=colorbar_title)
214+
self.__draw_legend(gs_sizes, sct, size_factor, color=dot_color, title=dot_title)
206215
if self.circle_data is not None:
207216
sct_circle = ax.scatter(X, Y, c='none', s=self.resized_circle_data.values.flatten(),
208217
edgecolors=circle_color, marker='o', vmin=vmin, vmax=vmax, linestyle='--')
218+
self.__draw_legend(gs_circles, sct_circle, size_factor, color=circle_color, title=circle_title,
219+
circle=True)
209220
width, height = self.width_item, self.height_item
210221
ax.set_xlim([0.5, width + 0.5])
211222
ax.set_ylim([0.6, height + 0.6])
@@ -215,49 +226,76 @@ def __draw_dotplot(self, ax, cmap, vmin, vmax, **kws):
215226
ax.set_yticklabels(self.size_data.index.tolist())
216227
ax.tick_params(axis='y', length=5, labelsize=15, direction='out')
217228
ax.tick_params(axis='x', length=5, labelsize=15, direction='out')
218-
return sct, sct_circle
219229

220-
@staticmethod
221-
def __draw_color_bar(ax, sct: mpl.collections.PathCollection, cmap, vmin, vmax, ylabel):
222-
gradient = np.linspace(1, 0, 500)
223-
gradient = gradient[:, np.newaxis]
224-
_ = ax.imshow(gradient, aspect='auto', cmap=cmap, origin='upper', extent=[.2, 0.3, 0.5, -0.5])
225-
ax.set_xticks([])
226-
ax.set_yticks([])
227-
ax_cbar2 = ax.twinx()
228-
_ = ax_cbar2.set_yticks([0, 1000])
229-
if vmax is None:
230-
vmax = math.ceil(sct.get_array().max())
231-
if vmin is None:
232-
vmin = math.floor(sct.get_array().min())
233-
_ = ax_cbar2.set_yticklabels([vmin, vmax])
234-
_ = ax_cbar2.set_ylabel(ylabel)
230+
def __draw_color_bar(self, gs: mpl.gridspec.GridSpec,
231+
sct: Union[mpl.collections.PathCollection, Sequence[mpl.collections.PathCollection]],
232+
cmap, vmin, vmax, ylabel):
233+
def _draw_color_bars_core(axes, path_collection: mpl.collections.PathCollection, _cmap, _vmin, _vmax, _ylabel):
234+
gradient = np.linspace(1, 0, 500)
235+
gradient = gradient[:, np.newaxis]
236+
_ = axes.imshow(gradient, aspect='auto', cmap=_cmap, origin='upper')
237+
axes.set_xticks([])
238+
axes.set_yticks([])
239+
ax_cbar2 = axes.twinx()
240+
if _vmax is None:
241+
_vmax = math.ceil(path_collection.get_array().max())
242+
if _vmin is None:
243+
_vmin = math.floor(path_collection.get_array().min())
244+
if _ylabel:
245+
_ = ax_cbar2.set_yticks([0, 1000])
246+
_ = ax_cbar2.set_yticklabels([_vmin, _vmax])
247+
else:
248+
_ = ax_cbar2.set_yticks([])
249+
_ = ax_cbar2.set_ylabel(_ylabel)
235250

236-
@staticmethod
237-
def __draw_legend(ax, sct: mpl.collections.PathCollection, size_factor, title, circle=False, color=None):
238-
handles, labels = sct.legend_elements(prop="sizes", alpha=1,
239-
func=lambda x: x / size_factor,
240-
color=color
241-
)
242-
if len(handles) > 3:
243-
handles = np.asarray(handles)
244-
labels = np.asarray(labels)
245-
handles = handles[[0, math.ceil(len(handles) / 2), -1]]
246-
labels = labels[[0, math.ceil(len(labels) / 2), -1]]
247-
if circle:
248-
from matplotlib.lines import Line2D
249-
for i, _item in enumerate(handles):
250-
xdata, ydata = _item.get_data()
251-
marker_size = _item.get_markersize()
252-
handles[i] = Line2D(xdata, ydata, color='white', marker='$\u25CC$',
253-
markeredgecolor=color, markersize=marker_size)
254-
_ = ax.legend(handles, labels, title=title, loc='center left') # bbox_to_anchor=(0.9, 0., 0.4, 0.4)
255-
ax.set_xticks([])
256-
ax.set_yticks([])
257-
ax.spines['top'].set_visible(False)
258-
ax.spines['bottom'].set_visible(False)
259-
ax.spines['left'].set_visible(False)
260-
ax.spines['right'].set_visible(False)
251+
fig = self.figure
252+
if isinstance(sct, mpl.collections.PathCollection):
253+
ax = fig.add_subplot(gs)
254+
_draw_color_bars_core(ax, sct, cmap, vmin, vmax, ylabel)
255+
else:
256+
new_gs = gridspec.GridSpecFromSubplotSpec(1, len(sct), wspace=.2, subplot_spec=gs)
257+
n_sct = len(sct) - 1
258+
for i, (_sct, _cmap) in enumerate(zip(sct, cmap)):
259+
ax = fig.add_subplot(new_gs[0, i])
260+
_ylabel = ylabel if n_sct == i else ''
261+
_draw_color_bars_core(ax, _sct, _cmap, vmin, vmax, _ylabel=_ylabel)
262+
263+
def __draw_legend(self, gs: mpl.gridspec.GridSpec,
264+
sct: mpl.collections.PathCollection,
265+
size_factor, title, circle=False, color=None):
266+
def __draw_legend_core(_ax, _sct, _size_factor, _title, _circle=False, _color=None):
267+
handles, labels = _sct.legend_elements(prop="sizes", alpha=1,
268+
func=lambda x: x / _size_factor,
269+
color=_color
270+
)
271+
if len(handles) > 3:
272+
handles = np.asarray(handles)
273+
labels = np.asarray(labels)
274+
handles = handles[[0, math.ceil(len(handles) / 2), -1]]
275+
labels = labels[[0, math.ceil(len(labels) / 2), -1]]
276+
if _circle:
277+
from matplotlib.lines import Line2D
278+
for j, _item in enumerate(handles):
279+
xdata, ydata = _item.get_data()
280+
marker_size = _item.get_markersize()
281+
handles[j] = Line2D(xdata, ydata, color='white', marker='$\u25CC$',
282+
markeredgecolor=_color, markersize=marker_size)
283+
_ = _ax.legend(handles, labels, title=_title, loc='center left', frameon=False)
284+
_ax.set_xticks([])
285+
_ax.set_yticks([])
286+
for item in ['top', 'bottom', 'left', 'right']:
287+
_ax.spines[item].set_visible(False)
288+
289+
fig = self.figure
290+
if isinstance(sct, mpl.collections.PathCollection):
291+
ax = fig.add_subplot(gs)
292+
__draw_legend_core(ax, sct, size_factor, title, circle, color)
293+
else:
294+
new_gs = gridspec.GridSpecFromSubplotSpec(1, len(sct), wspace=.5, subplot_spec=gs)
295+
for i, (_sct, _color, _title) in enumerate(zip(sct, color, title)):
296+
ax = fig.add_subplot(new_gs[0, i])
297+
ax.set_facecolor((0, 0, 0, 0))
298+
__draw_legend_core(ax, _sct, size_factor, _title, circle, _color)
261299

262300
def __cluster_matrix(self, axis=0, **kwargs):
263301
from .hierarchical import cluster_hierarchy
@@ -318,25 +356,9 @@ def plot(self, size_factor: float = 15,
318356
self.__preprocess_data(size_factor, cluster_row=cluster_row, cluster_col=cluster_col,
319357
**cluster_kws if cluster_kws is not None else {}
320358
)
321-
ax, ax_cbar, ax_sizes, ax_circles, ax_row_bands, ax_col_bands, fig = self.__get_figure()
322-
scatter, sct_circle = self.__draw_dotplot(ax, cmap, vmin, vmax, **kwargs)
323-
# todo
324-
if isinstance(scatter, Sequence):
325-
pass
326-
else:
327-
self.__draw_legend(ax_sizes, scatter, size_factor,
328-
color=kwargs.get('dot_color', '#58000C'), # dot legend color
329-
title=kwargs.get('dot_title', 'Sizes'))
330-
# todo
331-
if self.color_data is not None:
332-
self.__draw_color_bar(ax_cbar, scatter, cmap, vmin, vmax,
333-
ylabel=kwargs.get('colorbar_title', '-log10(pvalue)'))
334-
if sct_circle is not None:
335-
self.__draw_legend(ax_circles, sct_circle, size_factor,
336-
color=kwargs.get('circle_color', '#000000'),
337-
title=kwargs.get('circle_title', 'Circles'),
338-
circle=True)
339-
359+
ax, gs_cbar_legend, gs_sizes_legend, gs_circles_legend, ax_row_bands, ax_col_bands, fig = self.__get_figure()
360+
self.__draw_dotplot(ax, cmap, vmin, vmax, size_factor, gs_cbar=gs_cbar_legend, gs_sizes=gs_sizes_legend,
361+
gs_circles=gs_circles_legend, **kwargs)
340362
if self.col_colors is not None:
341363
from .annotation_bands import draw_heatmap
342364
color_band_kws = {} if color_band_kws is None else color_band_kws
@@ -368,7 +390,7 @@ def resolve_mask(mask_dataframe: pd.DataFrame):
368390
group_masks.append(group_mask)
369391
if n_group < 2:
370392
raise ValueError('group number<2')
371-
return group_masks, n_group
393+
return group_masks, n_group, groups
372394

373395
def __str__(self):
374396
return 'DotPlot object with data point in shape %s' % str(self.size_data.shape)

0 commit comments

Comments
 (0)