|
1 | | -import math |
2 | | -from os import PathLike |
3 | | -from typing import Union, Sequence, Callable |
| 1 | +from .cmap import get_colormap |
| 2 | +from .core import DotPlot |
| 3 | +from .utils import merge_clusterprofile_results |
4 | 4 |
|
5 | | -import matplotlib as mpl |
6 | | -import numpy as np |
7 | | -import pandas as pd |
8 | | -from matplotlib import gridspec |
9 | | -from matplotlib import pyplot as plt |
10 | | - |
11 | | -mpl.rcParams['pdf.fonttype'] = 42 |
12 | | -mpl.rcParams["font.sans-serif"] = "Arial" |
13 | | - |
14 | | - |
class DotPlot(object):
    """
    Dot plot of an item x group matrix.

    Every cell is drawn as a filled dot whose area comes from `size_data` and
    whose colour (optionally) comes from `color_data`; an optional dashed,
    unfilled circle sized by `circle_data` can be drawn on top of each dot.
    Rows of the DataFrames are laid out along the y axis (first row at the
    bottom), columns along the x axis.
    """
    DEFAULT_ITEM_HEIGHT = 0.3
    DEFAULT_ITEM_WIDTH = 0.3
    DEFAULT_LEGENDS_WIDTH = .45
    MIN_FIGURE_HEIGHT = 3

    def __init__(self, df_size: pd.DataFrame,
                 df_color: Union[pd.DataFrame, None] = None,
                 df_circle: Union[pd.DataFrame, None] = None,
                 ):
        """
        Construct a `DotPlot` object from `df_size`, `df_color` and `df_circle`

        :param df_size: the DataFrame object represents the scatter size in dotplot
        :param df_color: the DataFrame object represents the color in dotplot
        :param df_circle: the DataFrame object represents the size of the dashed circles in dotplot
        :raises ValueError: if `df_color` is given and its shape differs from `df_size`
        """
        # NOTE: the former `__slots__ = [...]` assignment lived inside this
        # method, where it is just an unused local; it has been removed.
        # `and` (not `&`) so that `df_color.shape` is never touched for None.
        if df_color is not None and df_size.shape != df_color.shape:
            raise ValueError('df_size and df_color should have the same dimension')
        self.size_data = df_size
        self.color_data = df_color
        self.circle_data = df_circle
        self.height_item, self.width_item = df_size.shape
        # Filled in by `plot` (value * size_factor); None until then.
        self.resized_size_data: Union[pd.DataFrame, None] = None
        self.resized_circle_data: Union[pd.DataFrame, None] = None

    def __get_figure(self):
        """
        Create the figure and its four axes.

        :return: (main axes, colour-bar axes, size-legend axes,
                  circle-legend axes, figure)
        """
        # Reserve one extra item-width per 15 characters of the longest row
        # label; `str()` keeps non-string index labels from breaking `len`.
        longest_label = max((len(str(label)) for label in self.size_data.index), default=0)
        _text_max = math.ceil(longest_label / 15)
        mainplot_height = self.height_item * self.DEFAULT_ITEM_HEIGHT
        mainplot_width = (
            (_text_max + self.width_item) * self.DEFAULT_ITEM_WIDTH
        )
        figure_height = max([self.MIN_FIGURE_HEIGHT, mainplot_height])
        figure_width = mainplot_width + self.DEFAULT_LEGENDS_WIDTH
        # 'seaborn-white' was renamed 'seaborn-v0_8-white' in matplotlib 3.6;
        # use whichever name this installation provides (skip if neither).
        for style_name in ('seaborn-v0_8-white', 'seaborn-white'):
            if style_name in plt.style.available:
                plt.style.use(style_name)
                break
        fig = plt.figure(figsize=(figure_width, figure_height))
        gs = gridspec.GridSpec(nrows=3, ncols=2, wspace=0.15, hspace=0.15,
                               width_ratios=[mainplot_width, self.DEFAULT_LEGENDS_WIDTH])
        ax = fig.add_subplot(gs[:, 0])
        ax_cbar = fig.add_subplot(gs[2, 1])
        ax_sizes = fig.add_subplot(gs[0, 1])
        ax_circles = fig.add_subplot(gs[1, 1])
        return ax, ax_cbar, ax_sizes, ax_circles, fig

    @classmethod
    def parse_from_tidy_data(cls, data_frame: pd.DataFrame, item_key: str, group_key: str, sizes_key: str,
                             color_key: Union[None, str] = None, circle_key: Union[None, str] = None,
                             selected_item: Union[None, Sequence] = None,
                             selected_group: Union[None, Sequence] = None, *,
                             sizes_func: Union[None, Callable] = None, color_func: Union[None, Callable] = None
                             ):
        """
        Class method for conveniently constructing DotPlot from tidy data.

        Each value column is pivoted to an item x group table; cells missing
        from `data_frame` are filled with 0. Resulting columns are labelled
        ``<value key>_<group>``. `data_frame` itself is never modified.

        :param data_frame: tidy (long-format) DataFrame, one row per (item, group) pair
        :param item_key: column holding the item labels (rows of the plot)
        :param group_key: column holding the group labels (columns of the plot)
        :param sizes_key: column holding the dot sizes
        :param color_key: column holding the dot colours, optional
        :param selected_item: default None, if specified, this should be subsets of `item_key` in `data_frame`
                              alternatively, this param can be used as self-defined item order definition.
        :param selected_group: Same as `selected_item`, for group order and subset groups
        :param sizes_func: optional function mapped over the `sizes_key` values
        :param color_func: optional function mapped over the `color_key` values
        :param circle_key: column holding the circle sizes, optional
        :return: a `DotPlot` built from the pivoted size / colour / circle tables
        """
        # Default row order: order of first appearance in `data_frame`,
        # reversed. De-duplicating matters: an item occurs once per group in
        # tidy data, and selecting with a list that repeats labels would
        # repeat the pivoted rows.
        item_order = list(dict.fromkeys(data_frame[item_key].tolist()))[::-1]

        def _pivot(value_key, value_func=None):
            # Pivot one value column to an item x group table, or None if unused.
            if value_key is None:
                return None
            values = data_frame[value_key]
            if value_func is not None:
                values = values.map(value_func)
            # Work on a fresh frame so the caller's DataFrame is untouched.
            long_form = data_frame[[item_key, group_key]].copy()
            long_form[value_key] = values
            wide = long_form.pivot(index=item_key, columns=group_key, values=value_key)
            wide = wide.loc[item_order, :]
            if selected_item is not None:
                wide = wide.loc[list(selected_item), :]
            if selected_group is not None:
                wide = wide.loc[:, list(selected_group)]
            # Pivoting per key (instead of prefix-matching flattened names)
            # keeps e.g. sizes_key='n' from also capturing 'n_log' columns.
            wide.columns = ['%s_%s' % (value_key, group) for group in wide.columns]
            return wide.fillna(0)

        sizes_df = _pivot(sizes_key, sizes_func)
        color_df = _pivot(color_key, color_func)
        circle_df = _pivot(circle_key)
        return cls(sizes_df, color_df, circle_df)

    def __get_coordinates(self, size_factor):
        """
        Compute the scatter coordinates and the scaled size tables.

        :param size_factor: multiplier applied to `size_data` / `circle_data`
        :return: (X, Y) lists in row-major order, matching `values.flatten()`
        """
        X = list(range(1, self.width_item + 1)) * self.height_item
        Y = sorted(list(range(1, self.height_item + 1)) * self.width_item)
        self.resized_size_data = self.size_data * size_factor
        # circle_data is optional; only scale it when present.
        if self.circle_data is not None:
            self.resized_circle_data = self.circle_data * size_factor
        else:
            self.resized_circle_data = None
        return X, Y

    def __draw_dotplot(self, ax, size_factor, cmap, vmin, vmax):
        """
        Draw the dots (and circles, if any) on `ax` and set up ticks/labels.

        :return: (dot scatter, circle scatter or None)
        """
        X, Y = self.__get_coordinates(size_factor)
        dot_sizes = self.resized_size_data.values.flatten()
        if self.color_data is None:
            # Single fixed colour: no colour mapping, so cmap/vmin/vmax are
            # not passed (matplotlib would ignore them and warn).
            sct = ax.scatter(X, Y, c='r', s=dot_sizes,
                             edgecolors='none', linewidths=0)
        else:
            sct = ax.scatter(X, Y, c=self.color_data.values.flatten(), s=dot_sizes,
                             edgecolors='none', linewidths=0, vmin=vmin, vmax=vmax, cmap=cmap)
        sct_circle = None
        if self.circle_data is not None:
            # facecolors='none' gives unfilled markers; the former c='' is
            # rejected by current matplotlib.
            sct_circle = ax.scatter(X, Y, facecolors='none', edgecolors='k', marker='o', linestyle='--',
                                    s=self.resized_circle_data.values.flatten())
        width, height = self.width_item, self.height_item
        ax.set_xlim([0.5, width + 0.5])
        ax.set_ylim([0.6, height + 0.6])
        ax.set_xticks(range(1, width + 1))
        ax.set_yticks(range(1, height + 1))
        ax.set_xticklabels(self.size_data.columns.tolist(), rotation='vertical')
        ax.set_yticklabels(self.size_data.index.tolist())
        ax.tick_params(axis='y', length=5, labelsize=15, direction='out')
        ax.tick_params(axis='x', length=5, labelsize=15, direction='out')
        return sct, sct_circle

    @staticmethod
    def __draw_color_bar(ax, sct: "mpl.collections.PathCollection", cmap, vmin, vmax):
        """
        Draw a vertical colour gradient on `ax`, labelled with `vmin`/`vmax`.

        When `vmin` / `vmax` is None it is taken from the colour array of
        `sct`, rounded down / up to an integer.
        """
        gradient = np.linspace(1, 0, 500)
        gradient = gradient[:, np.newaxis]
        _ = ax.imshow(gradient, aspect='auto', cmap=cmap, origin='upper', extent=[.2, 0.3, 0.5, -0.5])
        ax.set_xticks([])
        ax.set_yticks([])
        # A twin axis carries the two tick labels on the right-hand side.
        ax_cbar2 = ax.twinx()
        _ = ax_cbar2.set_yticks([0, 1000])
        if vmax is None:
            vmax = math.ceil(sct.get_array().max())
        if vmin is None:
            vmin = math.floor(sct.get_array().min())
        _ = ax_cbar2.set_yticklabels([vmin, vmax])
        _ = ax_cbar2.set_ylabel('-log10(pvalue)')

    @staticmethod
    def __draw_legend(ax, sct: "mpl.collections.PathCollection", size_factor, title, circle=False, color=None):
        """
        Draw a size legend for `sct` on `ax` and hide the axes decorations.

        :param size_factor: factor the sizes were multiplied by; divided out
                            again so legend labels show the original values
        :param title: legend title
        :param circle: if True, legend markers are drawn as dotted circles
        :param color: colour of the legend markers
        """
        handles, labels = sct.legend_elements(prop="sizes", alpha=1,
                                              func=lambda x: x / size_factor,
                                              color=color
                                              )
        if len(handles) > 3:
            # Keep only the first, middle and last entries.
            picks = [0, math.ceil(len(handles) / 2), -1]
            handles = [handles[i] for i in picks]
            labels = [labels[i] for i in picks]
        if circle:
            from matplotlib.lines import Line2D
            # Rebuild each handle with a dotted-circle glyph of the same size.
            circle_handles = []
            for _item in handles:
                xdata, ydata = _item.get_data()
                circle_handles.append(
                    Line2D(xdata, ydata, color='white', marker='$\u25CC$',
                           markeredgecolor=color, markersize=_item.get_markersize())
                )
            handles = circle_handles
        _ = ax.legend(handles, labels, title=title, loc='center left')
        ax.set_xticks([])
        ax.set_yticks([])
        for side in ('top', 'bottom', 'left', 'right'):
            ax.spines[side].set_visible(False)

    def plot(self, size_factor: float = 15,
             vmin: float = 0, vmax: float = None,
             path: Union[PathLike, None] = None,
             cmap: Union[str, "mpl.colors.Colormap"] = 'Reds'):
        """
        Draw the dot plot, its legends and (when colour data exists) its colour bar.

        :param size_factor: `size factor` * `value` for the actually representation of scatter size in the final figure
        :param vmin: `vmin` in `matplotlib.pyplot.scatter`
        :param vmax: `vmax` in `matplotlib.pyplot.scatter`
        :param path: path to save the figure
        :param cmap: color map supported by matplotlib
        :return: the scatter collection of the dots
        """
        ax, ax_cbar, ax_sizes, ax_circles, fig = self.__get_figure()
        scatter, sct_circle = self.__draw_dotplot(ax, size_factor, cmap, vmin, vmax)
        self.__draw_legend(ax_sizes, scatter, size_factor, title='Sizes', color='#58000C')
        if sct_circle is not None:
            self.__draw_legend(ax_circles, sct_circle, size_factor, title='Circles', circle=True, color='k')
        else:
            # Nothing to show: hide the otherwise empty framed axes.
            ax_circles.set_axis_off()
        if self.color_data is not None:
            self.__draw_color_bar(ax_cbar, scatter, cmap, vmin, vmax)
        else:
            # Without colour data the scatter has no colour array to read
            # the range from, so the colour bar is skipped.
            ax_cbar.set_axis_off()
        if path:
            fig.savefig(path, dpi=300, bbox_inches='tight')
        return scatter

    def __str__(self):
        return 'DotPlot object with data point in shape %s' % str(self.size_data.shape)

    __repr__ = __str__
| 5 | +__all__ = ['DotPlot', 'merge_clusterprofile_results', 'get_colormap'] |
0 commit comments