Skip to content

Commit 1a99ec4

Browse files
committed
add color band annotation, #4
1 parent 49372fe commit 1a99ec4

2 files changed

Lines changed: 151 additions & 23 deletions

File tree

dotplot/annotation_bands.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
"""
2+
Module for drawing annotation bands.
3+
Note: this implementation largely follows the approach used by seaborn.
4+
"""
5+
6+
import itertools
7+
from typing import Union, Sequence
8+
9+
import matplotlib as mpl
10+
import numpy as np
11+
import pandas as pd
12+
from matplotlib.colors import to_rgb
13+
14+
# Import-time, process-wide setting: font type 42 selects TrueType embedding for
# PDF output (per matplotlib's rcParams documentation).
mpl.rcParams['pdf.fonttype'] = 42
15+
# Import-time, process-wide setting: use Arial as the sans-serif font family.
mpl.rcParams["font.sans-serif"] = "Arial"
16+
17+
18+
def _process_colors(colors: Union[pd.DataFrame, pd.Series],
19+
index_order: Union[Sequence[Union[str, int]], None] = None):
20+
if not isinstance(colors, (pd.DataFrame, pd.Series)):
21+
raise TypeError('`colors` should be pandas.DataFrame or pandas.Series')
22+
if index_order is not None:
23+
colors = colors.reindex(index_order)
24+
colors = colors.astype(object).fillna('white') # TODO, to alpha
25+
if isinstance(colors, pd.DataFrame):
26+
labels = list(colors.columns)
27+
colors = colors.T.values
28+
else:
29+
labels = [''] if colors.name is None else [colors.name]
30+
colors = colors.values
31+
try:
32+
to_rgb(colors[0])
33+
colors = list(map(to_rgb, colors))
34+
except ValueError:
35+
colors = [list(map(to_rgb, item)) for item in colors]
36+
return colors, labels
37+
38+
39+
def _color_list_to_matrix_and_cmap(colors, axis=0):
    """
    Turn colours into an integer index matrix plus the matching colormap.

    :param colors: either a flat sequence of hashable colours (a single band)
        or a list of equally long lists of colours (one inner list per band).
    :param axis: ``0`` returns the matrix transposed to
        ``(observations, bands)`` (row-side band); any other value keeps the
        ``(bands, observations)`` layout.
    :return: ``(matrix, cmap)`` where every entry of ``matrix`` is the
        position of that cell's colour in the list given to the
        ``ListedColormap``.  Colours are numbered in order of first
        appearance, so the result is deterministic.
    :raises ValueError: if the bands do not all have the same length.
    """
    if any(isinstance(item, list) for item in colors):
        bands = list(colors)
    else:
        bands = [colors]

    n_bands = len(bands)        # number of fields
    n_obs = len(bands[0])       # number of observations
    # A bare reshape would silently accept ragged bands whose total size
    # happens to equal n_bands * n_obs, so check explicitly.
    if any(len(band) != n_obs for band in bands):
        raise ValueError('all color bands must have the same length')

    # dict preserves insertion order: colours are indexed by first appearance.
    color_to_value = {}
    for color in itertools.chain.from_iterable(bands):
        color_to_value.setdefault(color, len(color_to_value))

    matrix = np.array(
        [color_to_value[color] for band in bands for color in band],
        dtype=int,
    ).reshape((n_bands, n_obs))
    if axis == 0:
        # row-side:
        matrix = matrix.T

    # ListedColormap expects a sequence, so hand it a list rather than a set.
    cmap = mpl.colors.ListedColormap(list(color_to_value))
    return matrix, cmap
62+
63+
64+
def _determine_ticks(labels: Sequence[str]):
65+
num = len(labels)
66+
return [item + .5 for item in range(num)]
67+
68+
69+
def draw_heatmap(colors: Union[pd.DataFrame, pd.Series],
                 axes: mpl.axes.Axes, axis=0,
                 index_order: Union[Sequence[Union[str, int]], None] = None, **kwargs):
    """
    Draw colour annotation band(s) on ``axes`` with ``pcolormesh``.

    :param colors: colours for the band(s), handed to ``_process_colors``
        (a copy is taken first).
    :param axes: the matplotlib axes to draw on.
    :param axis: ``1`` puts the band labels on the y ticks and removes the
        x ticks; ``0`` puts them on the x ticks (rotated 90 degrees) and
        removes the y ticks.  Other values draw the mesh without touching
        the ticks.
    :param index_order: optional order of the index labels; it is reversed
        before being applied.
        NOTE(review): the reversal is applied for both axes, presumably to
        match the orientation of the main plot -- confirm.
    :param kwargs: forwarded to ``axes.pcolormesh``.
    """
    band_source = colors.copy()

    # index_order reverse
    reversed_order = index_order
    if reversed_order is not None:
        reversed_order = list(reversed_order)
        reversed_order.reverse()

    rgb_bands, band_labels = _process_colors(band_source, reversed_order)
    index_matrix, band_cmap = _color_list_to_matrix_and_cmap(rgb_bands, axis=axis)
    axes.pcolormesh(index_matrix, cmap=band_cmap, **kwargs)

    if axis == 1:
        axes.set_yticks(_determine_ticks(band_labels))
        axes.set_yticklabels(band_labels)
        axes.set_xticks([])
    elif axis == 0:
        axes.set_xticks(_determine_ticks(band_labels))
        axes.set_xticklabels(band_labels, rotation=90)
        axes.set_yticks([])

    # Hide the frame on all four sides.
    for side in ('left', 'right', 'bottom', 'top'):
        axes.spines[side].set_visible(False)

dotplot/core.py

Lines changed: 63 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@ class DotPlot(object):
1717
DEFAULT_ITEM_WIDTH = 0.3
1818
DEFAULT_LEGENDS_WIDTH = .45
1919
MIN_FIGURE_HEIGHT = 3
20-
DEFAULT_BAND_ITEM_LENGTH = DEFAULT_ITEM_HEIGHT
20+
DEFAULT_BAND_ITEM_LENGTH = .2
2121

22-
# TODO implement annotation band
2322
def __init__(self, df_size: pd.DataFrame,
2423
df_color: Union[pd.DataFrame, None] = None,
2524
df_circle: Union[pd.DataFrame, None] = None,
26-
df_annotation: Union[pd.DataFrame, None] = None,
25+
row_colors: Union[pd.DataFrame, None] = None,
26+
col_colors: Union[pd.DataFrame, None] = None,
2727
):
2828
"""
2929
Construct a `DotPlot` object from `df_size` and `df_color`
@@ -33,46 +33,71 @@ def __init__(self, df_size: pd.DataFrame,
3333
"""
3434
__slots__ = ['size_data', 'resized_size_data',
3535
'color_data', 'height_item', 'width_item',
36-
'circle_data', 'resized_circle_data', 'annotation_data'
36+
'circle_data', 'resized_circle_data', 'row_colors', 'col_colors'
3737
]
3838
if df_color is not None and df_size.shape != df_color.shape:
3939
raise ValueError('df_size and df_color should have the same dimension')
4040
if df_circle is not None and df_size.shape != df_circle.shape:
4141
raise ValueError('df_size and df_circle should have the same dimension')
42-
if df_annotation is not None and df_size.shape != df_annotation.shape:
43-
raise ValueError('df_size and df_annotation should have the same row number')
42+
if row_colors is not None and df_size.shape[0] != len(row_colors):
43+
raise ValueError('row_colors has the wrong shape')
44+
if col_colors is not None and df_size.shape[1] != len(col_colors):
45+
raise ValueError('col_colors has the wrong shape')
4446

4547
self.size_data = df_size
4648
self.color_data = df_color
4749
self.circle_data = df_circle
4850
self.height_item, self.width_item = df_size.shape
49-
self.annotation_data = df_annotation
51+
# TODO code logic need to argument
52+
self.row_colors = row_colors
53+
self.col_colors = col_colors
5054
self.resized_size_data: Union[pd.DataFrame, None] = None
5155
self.resized_circle_data: Union[pd.DataFrame, None] = None
5256

5357
def __get_figure(self):
    """
    Build the figure and lay out all axes.

    Layout is a 2 x 3 grid: the top row holds the column colour band
    (above the main plot) and an unused corner cell; the bottom row holds
    the main plot, the row colour band and a legend column that is split
    into three stacked axes (sizes, circles, colour bar).

    :return: ``(ax, ax_cbar, ax_sizes, ax_circles, ax_row_bands,
        ax_col_bands, fig)``
    """
    def _band_count(band_colors):
        # A Series is a single band (its shape has no second entry);
        # a DataFrame has one band per column.
        return 1 if band_colors.ndim == 1 else band_colors.shape[1]

    # Extra width units: one per 15 characters of the longest index label.
    _text_max = math.ceil(self.size_data.index.map(len).max() / 15)
    mainplot_height = self.height_item * self.DEFAULT_ITEM_HEIGHT
    mainplot_width = (
        (_text_max + self.width_item) * self.DEFAULT_ITEM_WIDTH
    )
    figure_height = max([self.MIN_FIGURE_HEIGHT, mainplot_height])
    figure_width = mainplot_width + self.DEFAULT_LEGENDS_WIDTH

    # Bands that are absent keep size 0 and contribute nothing to the figure.
    band_width, band_height = 0., 0.
    if self.row_colors is not None:
        band_width = self.DEFAULT_BAND_ITEM_LENGTH * _band_count(self.row_colors)
    if self.col_colors is not None:
        band_height = self.DEFAULT_BAND_ITEM_LENGTH * _band_count(self.col_colors)
    figure_width = figure_width + band_width
    figure_height = figure_height + band_height

    # The seaborn styles were renamed to 'seaborn-v0_8-*' in newer matplotlib
    # releases; prefer the new name and fall back to the legacy one.
    available_styles = plt.style.available
    for style_name in ('seaborn-v0_8-white', 'seaborn-white'):
        if style_name in available_styles:
            plt.style.use(style_name)
            break

    fig = plt.figure(figsize=(figure_width, figure_height))
    gs = gridspec.GridSpec(nrows=2, ncols=3, wspace=0.05, hspace=0.02,
                           width_ratios=[mainplot_width, band_width, self.DEFAULT_LEGENDS_WIDTH],
                           height_ratios=[band_height, mainplot_height]
                           )
    ax = fig.add_subplot(gs[1, 0])
    ax_row_bands = fig.add_subplot(gs[1, 1])
    ax_col_bands = fig.add_subplot(gs[0, 0])
    ax_abandon = fig.add_subplot(gs[0, 1])
    legend_gs = gridspec.GridSpecFromSubplotSpec(3, 1, hspace=.1, subplot_spec=gs[1, 2])
    ax_sizes = fig.add_subplot(legend_gs[0, 0])
    ax_circles = fig.add_subplot(legend_gs[1, 0])
    ax_cbar = fig.add_subplot(legend_gs[2, 0])

    ax_sizes.axis('off')
    ax_circles.axis('off')
    if self.col_colors is None:
        ax_col_bands.axis('off')
    if self.row_colors is None:
        ax_row_bands.axis('off')
    if self.color_data is None:
        ax_cbar.axis('off')
    # The corner cell above the row band is never drawn on.
    ax_abandon.axis('off')
    return ax, ax_cbar, ax_sizes, ax_circles, ax_row_bands, ax_col_bands, fig
76101

77102
@classmethod
78103
def parse_from_tidy_data(cls, data_frame: pd.DataFrame, item_key: str, group_key: str, sizes_key: str,
@@ -126,7 +151,7 @@ class method for conveniently constructing DotPlot from tidy data
126151
circle_df = data_frame.loc[:, data_frame.columns.str.startswith(circle_key)]
127152
return cls(sizes_df, color_df, circle_df)
128153

129-
def __get_coordinates(self, size_factor):
154+
def __get_coordinates(self):
130155
X = list(range(1, self.width_item + 1)) * self.height_item
131156
Y = sorted(list(range(1, self.height_item + 1)) * self.width_item)
132157
return X, Y
@@ -138,7 +163,7 @@ def __draw_dotplot(self, ax, size_factor, cmap, vmin, vmax, **kws):
138163
for _value in ['dot_title', 'circle_title', 'colorbar_title', 'dot_color', 'circle_color']:
139164
_ = kws.pop(_value, None)
140165

141-
X, Y = self.__get_coordinates(size_factor)
166+
X, Y = self.__get_coordinates()
142167
if self.color_data is None:
143168
sct = ax.scatter(X, Y, c=dot_color, s=self.resized_size_data.values.flatten(),
144169
edgecolors='none', linewidths=0, vmin=vmin, vmax=vmax, cmap=cmap, **kws)
@@ -234,7 +259,9 @@ def plot(self, size_factor: float = 15,
234259
path: Union[PathLike, None] = None,
235260
cmap: Union[str, mpl.colors.Colormap] = 'Reds',
236261
cluster_row: bool = False, cluster_col: bool = False,
237-
cluster_kws: Union[Dict, None] = None, **kwargs
262+
cluster_kws: Union[Dict, None] = None,
263+
color_band_kws: Union[Dict, None] = None,
264+
**kwargs
238265
):
239266
"""
240267
@@ -248,12 +275,13 @@ def plot(self, size_factor: float = 15,
248275
:param cluster_kws, key args for cluster, including `cluster_method`, `cluster_metric`, 'cluster_n'
249276
:param kwargs: dot_title, circle_title, colorbar_title, dot_color, circle_color
250277
other kwargs are passed to `matplotlib.Axes.scatter`
278+
:param color_band_kws: keyword arguments passed to `matplotlib.axes.Axes.pcolormesh` when drawing the color bands
251279
:return:
252280
"""
253281
self.__preprocess_data(size_factor, cluster_row=cluster_row, cluster_col=cluster_col,
254282
**cluster_kws if cluster_kws is not None else {}
255283
)
256-
ax, ax_cbar, ax_sizes, ax_circles, fig = self.__get_figure()
284+
ax, ax_cbar, ax_sizes, ax_circles, ax_row_bands, ax_col_bands, fig = self.__get_figure()
257285
scatter, sct_circle = self.__draw_dotplot(ax, size_factor, cmap, vmin, vmax)
258286
self.__draw_legend(ax_sizes, scatter, size_factor,
259287
color=kwargs.get('dot_color', '#58000C'), # dot legend color
@@ -266,8 +294,20 @@ def plot(self, size_factor: float = 15,
266294
if self.color_data is not None:
267295
self.__draw_color_bar(ax_cbar, scatter, cmap, vmin, vmax,
268296
ylabel=kwargs.get('colorbar_title', '-log10(pvalue)'))
297+
298+
if self.col_colors is not None:
299+
from .annotation_bands import draw_heatmap
300+
color_band_kws = {} if color_band_kws is None else color_band_kws
301+
draw_heatmap(self.col_colors, axes=ax_col_bands,
302+
index_order=self.size_data.columns.tolist(), axis=1, **color_band_kws)
303+
if self.row_colors is not None:
304+
color_band_kws = {} if color_band_kws is None else color_band_kws
305+
from .annotation_bands import draw_heatmap
306+
draw_heatmap(self.row_colors, axes=ax_row_bands,
307+
index_order=self.size_data.index.tolist(), axis=0, **color_band_kws)
308+
269309
if path:
270-
fig.savefig(path, dpi=300, bbox_inches='tight') #
310+
fig.savefig(path, dpi=300, bbox_inches='tight')
271311
return scatter
272312

273313
def __str__(self):

0 commit comments

Comments
 (0)