Skip to content

Commit 9146fae

Browse files
committed
add functionality for #9
1 parent 925a47b commit 9146fae

2 files changed

Lines changed: 76 additions & 21 deletions

File tree

dotplot/core.py

Lines changed: 75 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import math
22
from os import PathLike
3-
from typing import Union, Sequence, Callable, Dict
3+
from typing import Union, Sequence, Callable, Dict, List
44

55
import matplotlib as mpl
66
import numpy as np
@@ -11,6 +11,9 @@
1111
mpl.rcParams['pdf.fonttype'] = 42
1212
mpl.rcParams["font.sans-serif"] = "Arial"
1313

14+
CMAPS_PRESET = ('Reds', 'Blues', 'Purples', 'Oranges', 'Greens', 'Greys')
15+
COLORS_PRESET = ('r', 'b', '#BA55D3', '#FFA500', 'g', '#C0C0C0')
16+
1417

1518
class DotPlot(object):
1619
DEFAULT_ITEM_HEIGHT = 0.3
@@ -24,6 +27,7 @@ def __init__(self, df_size: pd.DataFrame,
2427
df_circle: Union[pd.DataFrame, None] = None,
2528
row_colors: Union[pd.DataFrame, None] = None,
2629
col_colors: Union[pd.DataFrame, None] = None,
30+
mask_frames: Union[pd.DataFrame, None] = None
2731
):
2832
"""
2933
Construction a `DotPlot` object from `df_size` and `df_color`
@@ -33,7 +37,7 @@ def __init__(self, df_size: pd.DataFrame,
3337
"""
3438
__slots__ = ['size_data', 'resized_size_data',
3539
'color_data', 'height_item', 'width_item',
36-
'circle_data', 'resized_circle_data', 'row_colors', 'col_colors'
40+
'circle_data', 'resized_circle_data', 'row_colors', 'col_colors', 'mask_frames'
3741
]
3842
if df_color is not None and df_size.shape != df_color.shape:
3943
raise ValueError('df_size and df_color should have the same dimension')
@@ -43,6 +47,8 @@ def __init__(self, df_size: pd.DataFrame,
4347
raise ValueError('row_colors has the wrong shape')
4448
if col_colors is not None and df_size.shape[1] != len(col_colors):
4549
raise ValueError('col_colors has the wrong shape')
50+
if mask_frames is not None and df_size.shape != mask_frames.shape:
51+
raise ValueError('df_size and mask_frames should have the same dimension')
4652

4753
self.size_data = df_size
4854
self.color_data = df_color
@@ -52,6 +58,7 @@ def __init__(self, df_size: pd.DataFrame,
5258
self.col_colors = col_colors
5359
self.resized_size_data: Union[pd.DataFrame, None] = None
5460
self.resized_circle_data: Union[pd.DataFrame, None] = None
61+
self.mask_frames = mask_frames
5562
if (self.row_colors is None) and (self.col_colors is None):
5663
self.col_colors = pd.DataFrame({'': ['#FFFFFF'] * df_size.shape[1]},
5764
index=df_size.columns.tolist())
@@ -101,6 +108,7 @@ def __get_figure(self):
101108
ax_abandon.axis('off')
102109
return ax, ax_cbar, ax_sizes, ax_circles, ax_row_bands, ax_col_bands, fig
103110

111+
# TODO update with the newest version of __init__
104112
@classmethod
105113
def parse_from_tidy_data(cls, data_frame: pd.DataFrame, item_key: str, group_key: str, sizes_key: str,
106114
color_key: Union[None, str] = None, circle_key: Union[None, str] = None,
@@ -158,19 +166,41 @@ def __get_coordinates(self):
158166
Y = sorted(list(range(1, self.height_item + 1)) * self.width_item)
159167
return X, Y
160168

161-
def __draw_dotplot(self, ax, size_factor, cmap, vmin, vmax, **kws):
169+
def __draw_dotplot(self, ax, cmap, vmin, vmax, **kws):
162170
dot_color = kws.get('dot_color', '#58000C')
163171
circle_color = kws.get('circle_color', '#000000')
164172
kws = kws.copy()
165173
for _value in ['dot_title', 'circle_title', 'colorbar_title', 'dot_color', 'circle_color']:
166174
_ = kws.pop(_value, None)
167-
168175
X, Y = self.__get_coordinates()
169-
if self.color_data is None:
170-
sct = ax.scatter(X, Y, c=dot_color, s=self.resized_size_data.values.flatten(),
171-
edgecolors='none', linewidths=0, vmin=vmin, vmax=vmax, cmap=cmap, **kws)
176+
resized_size_data_array = self.resized_size_data.values.flatten()
177+
color_data_array_or_str = dot_color if self.color_data is None else self.color_data.values.flatten()
178+
sct: Union[List[mpl.collections.PathCollection], mpl.collections.PathCollection] = []
179+
if self.mask_frames is not None:
180+
masks, n_masks = self.resolve_mask(self.mask_frames)
181+
if isinstance(color_data_array_or_str, np.ndarray):
182+
if isinstance(cmap, (str, mpl.colors.Colormap)):
183+
cmap = CMAPS_PRESET
184+
if len(cmap) < n_masks:
185+
raise ValueError('too many groups to draw with limited color map')
186+
for _, (mask, _cmap) in enumerate(zip(masks, cmap)):
187+
masked_resized_size_data_array = np.ma.masked_array(resized_size_data_array, mask=mask)
188+
_sct = ax.scatter(X, Y, c=color_data_array_or_str, s=masked_resized_size_data_array,
189+
edgecolors='none', linewidths=0, vmin=vmin, vmax=vmax, cmap=_cmap, **kws)
190+
sct.append(_sct)
191+
else:
192+
if isinstance(dot_color, str):
193+
dot_color = COLORS_PRESET
194+
if len(dot_color) < n_masks:
195+
raise ValueError('too many groups to draw with limited color')
196+
for _, (mask, _dot_color) in enumerate(zip(masks, dot_color)):
197+
masked_resized_size_data_array = np.ma.masked_array(resized_size_data_array, mask=mask)
198+
_sct = ax.scatter(X, Y, c=_dot_color, s=masked_resized_size_data_array,
199+
edgecolors='none', linewidths=0, vmin=vmin, vmax=vmax, **kws)
200+
sct.append(_sct)
201+
172202
else:
173-
sct = ax.scatter(X, Y, c=self.color_data.values.flatten(), s=self.resized_size_data.values.flatten(),
203+
sct = ax.scatter(X, Y, c=color_data_array_or_str, s=resized_size_data_array,
174204
edgecolors='none', linewidths=0, vmin=vmin, vmax=vmax, cmap=cmap, **kws)
175205
sct_circle = None
176206
if self.circle_data is not None:
@@ -250,7 +280,6 @@ def __cluster_matrix(self, axis=0, **kwargs):
250280
setattr(self, _obj_attr, _obj)
251281

252282
def __preprocess_data(self, size_factor, cluster_row=False, cluster_col=False, **kwargs):
253-
254283
if cluster_row or cluster_col:
255284
if cluster_row:
256285
self.__cluster_matrix(axis=0, **kwargs)
@@ -263,7 +292,8 @@ def __preprocess_data(self, size_factor, cluster_row=False, cluster_col=False, *
263292
def plot(self, size_factor: float = 15,
264293
vmin: float = 0, vmax: float = None,
265294
path: Union[PathLike, None] = None,
266-
cmap: Union[str, mpl.colors.Colormap] = 'Reds',
295+
cmap: Union[str, mpl.colors.Colormap,
296+
Sequence[Union[str, mpl.colors.Colormap]]] = 'Reds',
267297
cluster_row: bool = False, cluster_col: bool = False,
268298
cluster_kws: Union[Dict, None] = None,
269299
color_band_kws: Union[Dict, None] = None,
@@ -275,31 +305,37 @@ def plot(self, size_factor: float = 15,
275305
:param vmin: `vmin` in `matplotlib.pyplot.scatter`
276306
:param vmax: `vmax` in `matplotlib.pyplot.scatter`
277307
:param path: path to save the figure
278-
:param cmap: color map supported by matplotlib
308+
:param cmap: color map supported by matplotlib, can be sequence of cmap when drawing grouped dotplots
279309
:param cluster_row, whether to cluster the row
280310
:param cluster_col, whether to cluster the col
281311
:param cluster_kws, key args for cluster, including `cluster_method`, `cluster_metric`, 'cluster_n'
282312
:param kwargs: dot_title, circle_title, colorbar_title, dot_color, circle_color
283-
other kwargs are passed to `matplotlib.Axes.scatter`
313+
other kwargs are passed to `matplotlib.Axes.scatter`. Notably, dot_color can be a
314+
color sequence when drawing grouped dotplots
284315
:param color_band_kws: this kwargs was passed to `matplotlib.axes.Axes.pcolormesh`
285316
:return:
286317
"""
287318
self.__preprocess_data(size_factor, cluster_row=cluster_row, cluster_col=cluster_col,
288319
**cluster_kws if cluster_kws is not None else {}
289320
)
290321
ax, ax_cbar, ax_sizes, ax_circles, ax_row_bands, ax_col_bands, fig = self.__get_figure()
291-
scatter, sct_circle = self.__draw_dotplot(ax, size_factor, cmap, vmin, vmax)
292-
self.__draw_legend(ax_sizes, scatter, size_factor,
293-
color=kwargs.get('dot_color', '#58000C'), # dot legend color
294-
title=kwargs.get('dot_title', 'Sizes'))
322+
scatter, sct_circle = self.__draw_dotplot(ax, cmap, vmin, vmax, **kwargs)
323+
# todo
324+
if isinstance(scatter, Sequence):
325+
pass
326+
else:
327+
self.__draw_legend(ax_sizes, scatter, size_factor,
328+
color=kwargs.get('dot_color', '#58000C'), # dot legend color
329+
title=kwargs.get('dot_title', 'Sizes'))
330+
# todo
331+
if self.color_data is not None:
332+
self.__draw_color_bar(ax_cbar, scatter, cmap, vmin, vmax,
333+
ylabel=kwargs.get('colorbar_title', '-log10(pvalue)'))
295334
if sct_circle is not None:
296335
self.__draw_legend(ax_circles, sct_circle, size_factor,
297336
color=kwargs.get('circle_color', '#000000'),
298337
title=kwargs.get('circle_title', 'Circles'),
299338
circle=True)
300-
if self.color_data is not None:
301-
self.__draw_color_bar(ax_cbar, scatter, cmap, vmin, vmax,
302-
ylabel=kwargs.get('colorbar_title', '-log10(pvalue)'))
303339

304340
if self.col_colors is not None:
305341
from .annotation_bands import draw_heatmap
@@ -311,10 +347,28 @@ def plot(self, size_factor: float = 15,
311347
from .annotation_bands import draw_heatmap
312348
draw_heatmap(self.row_colors, axes=ax_row_bands,
313349
index_order=self.size_data.index.tolist(), axis=0, **color_band_kws)
314-
315350
if path:
316351
fig.savefig(path, dpi=300, bbox_inches='tight')
317-
return scatter
352+
return fig
353+
354+
@staticmethod
355+
def resolve_mask(mask_dataframe: pd.DataFrame):
356+
mask_dataframe = mask_dataframe.applymap(func=lambda x: str(x))
357+
groups = np.unique(mask_dataframe.to_numpy())
358+
mappings = dict(zip(groups, [0] * len(groups)))
359+
group_masks = []
360+
n_group = 0
361+
for group in groups:
362+
if group == 'nan':
363+
continue
364+
n_group += 1
365+
_mappings = mappings.copy()
366+
_mappings.update({group: 1})
367+
group_mask = mask_dataframe.applymap(func=lambda x: _mappings[x])
368+
group_masks.append(group_mask)
369+
if n_group < 2:
370+
raise ValueError('group number<2')
371+
return group_masks, n_group
318372

319373
def __str__(self):
320374
return 'DotPlot object with data point in shape %s' % str(self.size_data.shape)

dotplot/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,4 @@ def merge_clusterprofile_results(dataframes, groups, group_key='group', term_lis
2626
merged_df['pvalue'] = merged_df['pvalue'].map(lambda x: -np.log10(x))
2727
merged_df['p.adjust'] = merged_df['p.adjust'].map(lambda x: -np.log10(x))
2828
return merged_df
29+

0 commit comments

Comments
 (0)