Skip to content

Commit f2a4485

Browse files
committed
update plot func with cluster utility
1 parent 717e215 commit f2a4485

1 file changed

Lines changed: 36 additions & 5 deletions

File tree

dotplot/core.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import math
22
from os import PathLike
3-
from typing import Union, Sequence, Callable
3+
from typing import Union, Sequence, Callable, Dict
44

55
import matplotlib as mpl
66
import numpy as np
@@ -128,9 +128,6 @@ class method for conveniently constructing DotPlot from tidy data
128128
def __get_coordinates(self, size_factor):
    """
    Build the flattened, 1-based grid coordinates of every dot, enumerated row by row.

    Reads ``self.width_item`` (number of positions along x) and ``self.height_item``
    (number of positions along y).

    :param size_factor: not used by this method; kept in the signature so existing
                        callers keep working (resizing of the size/circle data is
                        done in ``__preprocess_data``)
    :return: tuple ``(X, Y)`` of two lists with ``width_item * height_item`` entries each;
             X cycles ``1..width_item`` once per row, Y repeats each row number
             ``width_item`` times
    """
    columns = range(1, self.width_item + 1)
    rows = range(1, self.height_item + 1)
    X = list(columns) * self.height_item
    # Emit the row numbers already grouped instead of repeating the whole range and sorting it.
    Y = [row for row in rows for _ in columns]
    return X, Y
135132

136133
def __draw_dotplot(self, ax, size_factor, cmap, vmin, vmax, **kws):
@@ -204,11 +201,39 @@ def __draw_legend(ax, sct: mpl.collections.PathCollection, size_factor, title, c
204201
ax.spines['left'].set_visible(False)
205202
ax.spines['right'].set_visible(False)
206203

204+
def __preprocess_data(self, size_factor, cluster_row=False, cluster_col=False, **kwargs):
    """
    Optionally reorder the stored data frames by hierarchical clustering, then
    compute the size-scaled copies used for drawing.

    Clustering is always computed on ``self.size_data``; the resulting order is applied
    to every ``pd.DataFrame`` attribute listed in ``self.__slots__`` so that all of them
    stay aligned. The reordered frames replace the stored attributes.

    :param size_factor: multiplier applied to every value of ``size_data``
                        (and of ``circle_data`` when it is not None)
    :param cluster_row: whether to reorder the rows by clustering
    :param cluster_col: whether to reorder the columns by clustering
    :param kwargs: clustering options: ``cluster_method`` (default ``'ward'``),
                   ``cluster_metric`` (default ``'euclidean'``),
                   ``cluster_n`` (default ``None``); they are forwarded to
                   ``cluster_hierarchy`` as ``method``, ``metric`` and ``n_clusters``
    :return: None; sets ``self.resized_size_data`` and, when ``circle_data`` is not None,
             ``self.resized_circle_data``
    """
    method = kwargs.get('cluster_method', 'ward')
    metric = kwargs.get('cluster_metric', 'euclidean')
    n_clusters = kwargs.get('cluster_n', None)

    if cluster_row or cluster_col:
        # Imported lazily so the clustering module is only needed when clustering is requested.
        from .hierarchical import cluster_hierarchy

        # NOTE(review): cluster_hierarchy is assumed to return labels usable with `.loc`
        # along the requested axis -- confirm against dotplot/hierarchical.py.
        # Both orderings are derived from the not-yet-reordered size_data and kept
        # separate, so requesting rows and columns together never applies the row
        # order to the columns.
        row_order = None
        col_order = None
        if cluster_row:
            row_order = cluster_hierarchy(self.size_data, axis=0, method=method,
                                          metric=metric, n_clusters=n_clusters)
        if cluster_col:
            col_order = cluster_hierarchy(self.size_data, axis=1, method=method,
                                          metric=metric, n_clusters=n_clusters)

        for item in self.__slots__:
            # A slot that was never assigned raises AttributeError; treat it as absent.
            obj_attr = getattr(self, item, None)
            if not isinstance(obj_attr, pd.DataFrame):
                continue
            if row_order is not None:
                obj_attr = obj_attr.loc[row_order, :]
            if col_order is not None:
                obj_attr = obj_attr.loc[:, col_order]
            setattr(self, item, obj_attr)

    # Vectorised scaling; equivalent to the element-wise applymap without the deprecated API.
    self.resized_size_data = self.size_data * size_factor
    if self.circle_data is not None:
        self.resized_circle_data = self.circle_data * size_factor
230+
207231
def plot(self, size_factor: float = 15,
208232
vmin: float = 0, vmax: float = None,
209233
path: Union[PathLike, None] = None,
210234
cmap: Union[str, mpl.colors.Colormap] = 'Reds',
211-
**kwargs
235+
cluster_row: bool = False, cluster_col: bool = False,
236+
cluster_kws: Union[Dict, None] = None, **kwargs
212237
):
213238
"""
214239
@@ -219,8 +244,14 @@ def plot(self, size_factor: float = 15,
219244
:param cmap: color map supported by matplotlib
220245
:param kwargs: dot_title, circle_title, colorbar_title, dot_color, circle_color
221246
other kwargs are passed to `matplotlib.Axes.scatter`
247+
:param cluster_row: whether to cluster the rows
248+
:param cluster_col: whether to cluster the columns
249+
:param cluster_kws: keyword arguments for clustering, including `cluster_method`, `cluster_metric`, `cluster_n`
222250
:return:
223251
"""
252+
self.__preprocess_data(size_factor, cluster_row=cluster_row, cluster_col=cluster_col,
253+
**cluster_kws if cluster_kws is not None else {}
254+
)
224255
ax, ax_cbar, ax_sizes, ax_circles, fig = self.__get_figure()
225256
scatter, sct_circle = self.__draw_dotplot(ax, size_factor, cmap, vmin, vmax)
226257
self.__draw_legend(ax_sizes, scatter, size_factor,

0 commit comments

Comments
 (0)