Skip to content

Commit 08e33e6

Browse files
committed
Refactoring the code
1 parent eac166d commit 08e33e6

2 files changed

Lines changed: 216 additions & 211 deletions

File tree

dotplot/__init__.py

Lines changed: 4 additions & 211 deletions
Original file line numberDiff line numberDiff line change
@@ -1,212 +1,5 @@
1-
import math
2-
from os import PathLike
3-
from typing import Union, Sequence, Callable
1+
from .cmap import get_colormap
2+
from .core import DotPlot
3+
from .utils import merge_clusterprofile_results
44

5-
import matplotlib as mpl
6-
import numpy as np
7-
import pandas as pd
8-
from matplotlib import gridspec
9-
from matplotlib import pyplot as plt
10-
11-
mpl.rcParams['pdf.fonttype'] = 42
12-
mpl.rcParams["font.sans-serif"] = "Arial"
13-
14-
15-
class DotPlot(object):
16-
DEFAULT_ITEM_HEIGHT = 0.3
17-
DEFAULT_ITEM_WIDTH = 0.3
18-
DEFAULT_LEGENDS_WIDTH = .45
19-
MIN_FIGURE_HEIGHT = 3
20-
21-
def __init__(self, df_size: pd.DataFrame,
22-
df_color: Union[pd.DataFrame, None] = None,
23-
df_circle: Union[pd.DataFrame, None] = None,
24-
):
25-
"""
26-
Construction a `DotPlot` object from `df_size` and `df_color`
27-
28-
:param df_size: the DataFrame object represents the scatter size in dotplot
29-
:param df_color: the DataFrame object represents the color in dotplot
30-
"""
31-
__slots__ = ['size_data', 'resized_size_data',
32-
'color_data', 'height_item', 'width_item',
33-
'circle_data', 'resized_circle_data'
34-
]
35-
if (df_color is not None) & (df_size.shape != df_color.shape):
36-
raise ValueError('df_size and df_color should have the same dimension')
37-
self.size_data = df_size
38-
self.color_data = df_color
39-
self.circle_data = df_circle
40-
self.height_item, self.width_item = df_size.shape
41-
self.resized_size_data: pd.DataFrame
42-
self.resized_circle_data: pd.DataFrame
43-
44-
def __get_figure(self):
45-
_text_max = math.ceil(self.size_data.index.map(len).max() / 15)
46-
mainplot_height = self.height_item * self.DEFAULT_ITEM_HEIGHT
47-
mainplot_width = (
48-
(_text_max + self.width_item) * self.DEFAULT_ITEM_WIDTH
49-
)
50-
figure_height = max([self.MIN_FIGURE_HEIGHT, mainplot_height])
51-
figure_width = mainplot_width + self.DEFAULT_LEGENDS_WIDTH
52-
plt.style.use('seaborn-white')
53-
fig = plt.figure(figsize=(figure_width, figure_height))
54-
gs = gridspec.GridSpec(nrows=3, ncols=2, wspace=0.15, hspace=0.15,
55-
width_ratios=[mainplot_width, self.DEFAULT_LEGENDS_WIDTH])
56-
ax = fig.add_subplot(gs[:, 0])
57-
ax_cbar = fig.add_subplot(gs[2, 1])
58-
ax_sizes = fig.add_subplot(gs[0, 1])
59-
ax_circles = fig.add_subplot(gs[1, 1])
60-
return ax, ax_cbar, ax_sizes, ax_circles, fig
61-
62-
@classmethod
63-
def parse_from_tidy_data(cls, data_frame: pd.DataFrame, item_key: str, group_key: str, sizes_key: str,
64-
color_key: Union[None, str] = None, circle_key: Union[None, str] = None,
65-
selected_item: Union[None, Sequence] = None,
66-
selected_group: Union[None, Sequence] = None, *,
67-
sizes_func: Union[None, Callable] = None, color_func: Union[None, Callable] = None
68-
):
69-
"""
70-
71-
class method for conveniently constructing DotPlot from tidy data
72-
73-
:param data_frame:
74-
:param item_key:
75-
:param group_key:
76-
:param sizes_key:
77-
:param color_key:
78-
:param selected_item: default None, if specified, this should be subsets of `item_key` in `data_frame`
79-
alternatively, this param can be used as self-defined item order definition.
80-
:param selected_group: Same as `selected_item`, for group order and subset groups
81-
:param sizes_func:
82-
:param color_func:
83-
:param circle_key:
84-
:return:
85-
"""
86-
keys = [v for v in [item_key, group_key, sizes_key, color_key, circle_key] if v is not None]
87-
data_frame = data_frame[keys]
88-
_original_item_order = data_frame[item_key].tolist()
89-
_original_item_order = _original_item_order[::-1]
90-
if sizes_func is not None:
91-
data_frame[sizes_key] = data_frame[sizes_key].map(sizes_func)
92-
if color_func is not None:
93-
data_frame[color_key] = data_frame[color_key].map(color_func)
94-
keys.remove(item_key)
95-
keys.remove(group_key)
96-
data_frame = data_frame.pivot(index=item_key, columns=group_key, values=keys)
97-
data_frame = data_frame.loc[_original_item_order, :]
98-
if selected_item is not None:
99-
data_frame = data_frame.loc[selected_item, :]
100-
if selected_group is not None:
101-
data_frame = data_frame.loc[:, selected_group]
102-
data_frame.columns = data_frame.columns.map(lambda x: '_'.join(x))
103-
data_frame = data_frame.fillna(0)
104-
105-
sizes_df, color_df, circle_df = (None, None, None)
106-
sizes_df = data_frame.loc[:, data_frame.columns.str.startswith(sizes_key)]
107-
if color_key is not None:
108-
color_df = data_frame.loc[:, data_frame.columns.str.startswith(color_key)]
109-
if circle_key is not None:
110-
circle_df = data_frame.loc[:, data_frame.columns.str.startswith(circle_key)]
111-
return cls(sizes_df, color_df, circle_df)
112-
113-
def __get_coordinates(self, size_factor):
114-
X = list(range(1, self.width_item + 1)) * self.height_item
115-
Y = sorted(list(range(1, self.height_item + 1)) * self.width_item)
116-
self.resized_size_data = self.size_data.applymap(func=lambda x: x * size_factor)
117-
self.resized_circle_data = self.circle_data.applymap(func=lambda x: x * size_factor)
118-
return X, Y
119-
120-
def __draw_dotplot(self, ax, size_factor, cmap, vmin, vmax):
121-
X, Y = self.__get_coordinates(size_factor)
122-
if self.color_data is None:
123-
sct = ax.scatter(X, Y, c='r', cmap=cmap, s=self.resized_size_data.values.flatten(),
124-
edgecolors='none', linewidths=0, vmin=vmin, vmax=vmax)
125-
else:
126-
sct = ax.scatter(X, Y, c=self.color_data.values.flatten(), s=self.resized_size_data.values.flatten(),
127-
edgecolors='none', linewidths=0, vmin=vmin, vmax=vmax, cmap=cmap)
128-
sct_circle = None
129-
if self.circle_data is not None:
130-
sct_circle = ax.scatter(X, Y, c='', edgecolors='k', marker='o', linestyle='--',
131-
s=self.resized_circle_data.values.flatten())
132-
width, height = self.width_item, self.height_item
133-
ax.set_xlim([0.5, width + 0.5])
134-
ax.set_ylim([0.6, height + 0.6])
135-
ax.set_xticks(range(1, width + 1))
136-
ax.set_yticks(range(1, height + 1))
137-
ax.set_xticklabels(self.size_data.columns.tolist(), rotation='vertical')
138-
ax.set_yticklabels(self.size_data.index.tolist())
139-
ax.tick_params(axis='y', length=5, labelsize=15, direction='out')
140-
ax.tick_params(axis='x', length=5, labelsize=15, direction='out')
141-
return sct, sct_circle
142-
143-
@staticmethod
144-
def __draw_color_bar(ax, sct: mpl.collections.PathCollection, cmap, vmin, vmax):
145-
gradient = np.linspace(1, 0, 500)
146-
gradient = gradient[:, np.newaxis]
147-
_ = ax.imshow(gradient, aspect='auto', cmap=cmap, origin='upper', extent=[.2, 0.3, 0.5, -0.5])
148-
ax.set_xticks([])
149-
ax.set_yticks([])
150-
ax_cbar2 = ax.twinx()
151-
_ = ax_cbar2.set_yticks([0, 1000])
152-
if vmax is None:
153-
vmax = math.ceil(sct.get_array().max())
154-
if vmin is None:
155-
vmin = math.floor(sct.get_array().min())
156-
_ = ax_cbar2.set_yticklabels([vmin, vmax])
157-
_ = ax_cbar2.set_ylabel('-log10(pvalue)')
158-
159-
@staticmethod
160-
def __draw_legend(ax, sct: mpl.collections.PathCollection, size_factor, title, circle=False, color=None):
161-
print(id(sct))
162-
handles, labels = sct.legend_elements(prop="sizes", alpha=1,
163-
func=lambda x: x / size_factor,
164-
color=color
165-
)
166-
if len(handles) > 3:
167-
handles = np.asarray(handles)
168-
labels = np.asarray(labels)
169-
handles = handles[[0, math.ceil(len(handles) / 2), -1]]
170-
labels = labels[[0, math.ceil(len(labels) / 2), -1]]
171-
if circle:
172-
from matplotlib.lines import Line2D
173-
for i, _item in enumerate(handles):
174-
xdata, ydata = _item.get_data()
175-
marker_size = _item.get_markersize()
176-
handles[i] = Line2D(xdata, ydata, color='white', marker='$\u25CC$',
177-
markeredgecolor=color, markersize=marker_size)
178-
_ = ax.legend(handles, labels, title=title, loc='center left') # bbox_to_anchor=(0.9, 0., 0.4, 0.4)
179-
ax.set_xticks([])
180-
ax.set_yticks([])
181-
ax.spines['top'].set_visible(False)
182-
ax.spines['bottom'].set_visible(False)
183-
ax.spines['left'].set_visible(False)
184-
ax.spines['right'].set_visible(False)
185-
186-
def plot(self, size_factor: float = 15,
187-
vmin: float = 0, vmax: float = None,
188-
path: Union[PathLike, None] = None,
189-
cmap: Union[str, mpl.colors.Colormap] = 'Reds'):
190-
"""
191-
192-
:param size_factor: `size factor` * `value` for the actually representation of scatter size in the final figure
193-
:param vmin: `vmin` in `matplotlib.pyplot.scatter`
194-
:param vmax: `vmax` in `matplotlib.pyplot.scatter`
195-
:param path: path to save the figure
196-
:param cmap: color map supported by matplotlib
197-
:return:
198-
"""
199-
ax, ax_cbar, ax_sizes, ax_circles, fig = self.__get_figure()
200-
scatter, sct_circle = self.__draw_dotplot(ax, size_factor, cmap, vmin, vmax)
201-
self.__draw_legend(ax_sizes, scatter, size_factor, title='Sizes', color='#58000C')
202-
if sct_circle is not None:
203-
self.__draw_legend(ax_circles, sct_circle, size_factor, title='Circles', circle=True, color='k')
204-
self.__draw_color_bar(ax_cbar, scatter, cmap, vmin, vmax)
205-
if path:
206-
fig.savefig(path, dpi=300, bbox_inches='tight') #
207-
return scatter
208-
209-
def __str__(self):
210-
return 'DotPlot object with data point in shape %s' % str(self.size_data.shape)
211-
212-
__repr__ = __str__
5+
__all__ = ['DotPlot', 'merge_clusterprofile_results', 'get_colormap']

0 commit comments

Comments
 (0)