Skip to content

Commit cd63d98

Browse files
committed
fix the cluster bug
1 parent a5cdb6e commit cd63d98

1 file changed

Lines changed: 21 additions & 18 deletions

File tree

dotplot/core.py

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ def __init__(self, df_size: pd.DataFrame,
4848
self.color_data = df_color
4949
self.circle_data = df_circle
5050
self.height_item, self.width_item = df_size.shape
51-
# TODO code logic need to argument
5251
self.row_colors = row_colors
5352
self.col_colors = col_colors
5453
self.resized_size_data: Union[pd.DataFrame, None] = None
@@ -227,29 +226,33 @@ def __draw_legend(ax, sct: mpl.collections.PathCollection, size_factor, title, c
227226
ax.spines['left'].set_visible(False)
228227
ax.spines['right'].set_visible(False)
229228

230-
def __preprocess_data(self, size_factor, cluster_row=False, cluster_col=False, **kwargs):
231-
229+
def __cluster_matrix(self, axis=0, **kwargs):
230+
from .hierarchical import cluster_hierarchy
232231
method = kwargs.get('cluster_method', 'ward')
233232
metric = kwargs.get('cluster_metric', 'euclidean')
234233
n_clusters = kwargs.get('cluster_n', None)
234+
_index = cluster_hierarchy(self.size_data, axis=axis, method=method,
235+
metric=metric, n_clusters=n_clusters)
236+
obj_data = self.__dict__.copy()
237+
for _obj_attr, _obj in obj_data.items():
238+
if (not _obj_attr.startswith('__')) and isinstance(_obj, (pd.DataFrame, pd.Series)):
239+
if _obj_attr in ('row_colors', 'col_colors'): # TODO may change the action in the future
240+
continue
241+
if axis == 0:
242+
_obj = _obj.loc[_index, :]
243+
elif axis == 1:
244+
_obj = _obj.loc[:, _index]
245+
else:
246+
raise ValueError('axis should be 0 or 1.')
247+
setattr(self, _obj_attr, _obj)
248+
249+
def __preprocess_data(self, size_factor, cluster_row=False, cluster_col=False, **kwargs):
235250

236251
if cluster_row or cluster_col:
237-
from .hierarchical import cluster_hierarchy
238252
if cluster_row:
239-
_index = cluster_hierarchy(self.size_data, axis=0, method=method,
240-
metric=metric, n_clusters=n_clusters)
241-
else:
242-
_index = cluster_hierarchy(self.size_data, axis=1, method=method,
243-
metric=metric, n_clusters=n_clusters)
244-
obj_data = self.__dict__.copy()
245-
for _obj_attr, _obj in obj_data.items():
246-
if not _obj_attr.startswith('__'):
247-
if isinstance(_obj, pd.DataFrame):
248-
if cluster_row:
249-
_obj = _obj.loc[_index, :]
250-
if cluster_col:
251-
_obj = _obj.loc[:, _index]
252-
setattr(self, _obj_attr, _obj)
253+
self.__cluster_matrix(axis=0, **kwargs)
254+
if cluster_col:
255+
self.__cluster_matrix(axis=1, **kwargs)
253256
self.resized_size_data = self.size_data.applymap(func=lambda x: x * size_factor)
254257
if self.circle_data is not None:
255258
self.resized_circle_data = self.circle_data.applymap(func=lambda x: x * size_factor)

0 commit comments

Comments
 (0)