Skip to content

Commit eac166d

Browse files
committed
add circles plot, fixed #3
1 parent 4b78b2e commit eac166d

1 file changed

Lines changed: 52 additions & 19 deletions

File tree

dotplot/__init__.py

Lines changed: 52 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
from matplotlib import gridspec
99
from matplotlib import pyplot as plt
1010

11+
mpl.rcParams['pdf.fonttype'] = 42
12+
mpl.rcParams["font.sans-serif"] = "Arial"
13+
1114

1215
class DotPlot(object):
1316
DEFAULT_ITEM_HEIGHT = 0.3
@@ -17,20 +20,26 @@ class DotPlot(object):
1720

1821
def __init__(self, df_size: pd.DataFrame,
1922
df_color: Union[pd.DataFrame, None] = None,
23+
df_circle: Union[pd.DataFrame, None] = None,
2024
):
2125
"""
2226
Construction a `DotPlot` object from `df_size` and `df_color`
2327
2428
:param df_size: the DataFrame object represents the scatter size in dotplot
2529
:param df_color: the DataFrame object represents the color in dotplot
2630
"""
27-
__slots__ = ['size_data', 'color_data', 'height_item', 'width_item', 'resized_size_data']
31+
__slots__ = ['size_data', 'resized_size_data',
32+
'color_data', 'height_item', 'width_item',
33+
'circle_data', 'resized_circle_data'
34+
]
2835
if (df_color is not None) & (df_size.shape != df_color.shape):
2936
raise ValueError('df_size and df_color should have the same dimension')
3037
self.size_data = df_size
3138
self.color_data = df_color
39+
self.circle_data = df_circle
3240
self.height_item, self.width_item = df_size.shape
3341
self.resized_size_data: pd.DataFrame
42+
self.resized_circle_data: pd.DataFrame
3443

3544
def __get_figure(self):
3645
_text_max = math.ceil(self.size_data.index.map(len).max() / 15)
@@ -46,12 +55,14 @@ def __get_figure(self):
4655
width_ratios=[mainplot_width, self.DEFAULT_LEGENDS_WIDTH])
4756
ax = fig.add_subplot(gs[:, 0])
4857
ax_cbar = fig.add_subplot(gs[2, 1])
49-
ax_legend = fig.add_subplot(gs[0:2, 1])
50-
return ax, ax_cbar, ax_legend, fig
58+
ax_sizes = fig.add_subplot(gs[0, 1])
59+
ax_circles = fig.add_subplot(gs[1, 1])
60+
return ax, ax_cbar, ax_sizes, ax_circles, fig
5161

5262
@classmethod
5363
def parse_from_tidy_data(cls, data_frame: pd.DataFrame, item_key: str, group_key: str, sizes_key: str,
54-
color_key: str, selected_item: Union[None, Sequence] = None,
64+
color_key: Union[None, str] = None, circle_key: Union[None, str] = None,
65+
selected_item: Union[None, Sequence] = None,
5566
selected_group: Union[None, Sequence] = None, *,
5667
sizes_func: Union[None, Callable] = None, color_func: Union[None, Callable] = None
5768
):
@@ -69,34 +80,41 @@ class method for conveniently constructing DotPlot from tidy data
6980
:param selected_group: Same as `selected_item`, for group order and subset groups
7081
:param sizes_func:
7182
:param color_func:
83+
:param circle_key:
7284
:return:
7385
"""
74-
data_frame = data_frame[[item_key, group_key, sizes_key, color_key]]
86+
keys = [v for v in [item_key, group_key, sizes_key, color_key, circle_key] if v is not None]
87+
data_frame = data_frame[keys]
7588
_original_item_order = data_frame[item_key].tolist()
7689
_original_item_order = _original_item_order[::-1]
7790
if sizes_func is not None:
7891
data_frame[sizes_key] = data_frame[sizes_key].map(sizes_func)
7992
if color_func is not None:
8093
data_frame[color_key] = data_frame[color_key].map(color_func)
81-
data_frame = data_frame.pivot(index=item_key, columns=group_key, values=[color_key, sizes_key])
94+
keys.remove(item_key)
95+
keys.remove(group_key)
96+
data_frame = data_frame.pivot(index=item_key, columns=group_key, values=keys)
8297
data_frame = data_frame.loc[_original_item_order, :]
8398
if selected_item is not None:
8499
data_frame = data_frame.loc[selected_item, :]
85100
if selected_group is not None:
86101
data_frame = data_frame.loc[:, selected_group]
87-
88102
data_frame.columns = data_frame.columns.map(lambda x: '_'.join(x))
89103
data_frame = data_frame.fillna(0)
90-
color_df = data_frame.loc[:, data_frame.columns.str.startswith(color_key)]
104+
105+
sizes_df, color_df, circle_df = (None, None, None)
91106
sizes_df = data_frame.loc[:, data_frame.columns.str.startswith(sizes_key)]
92-
color_df.columns = color_df.columns.map(lambda x: '_'.join(x.split('_')[1:]))
93-
sizes_df.columns = sizes_df.columns.map(lambda x: '_'.join(x.split('_')[1:]))
94-
return cls(color_df, sizes_df)
107+
if color_key is not None:
108+
color_df = data_frame.loc[:, data_frame.columns.str.startswith(color_key)]
109+
if circle_key is not None:
110+
circle_df = data_frame.loc[:, data_frame.columns.str.startswith(circle_key)]
111+
return cls(sizes_df, color_df, circle_df)
95112

96113
def __get_coordinates(self, size_factor):
97114
X = list(range(1, self.width_item + 1)) * self.height_item
98115
Y = sorted(list(range(1, self.height_item + 1)) * self.width_item)
99116
self.resized_size_data = self.size_data.applymap(func=lambda x: x * size_factor)
117+
self.resized_circle_data = self.circle_data.applymap(func=lambda x: x * size_factor)
100118
return X, Y
101119

102120
def __draw_dotplot(self, ax, size_factor, cmap, vmin, vmax):
@@ -107,6 +125,10 @@ def __draw_dotplot(self, ax, size_factor, cmap, vmin, vmax):
107125
else:
108126
sct = ax.scatter(X, Y, c=self.color_data.values.flatten(), s=self.resized_size_data.values.flatten(),
109127
edgecolors='none', linewidths=0, vmin=vmin, vmax=vmax, cmap=cmap)
128+
sct_circle = None
129+
if self.circle_data is not None:
130+
sct_circle = ax.scatter(X, Y, c='', edgecolors='k', marker='o', linestyle='--',
131+
s=self.resized_circle_data.values.flatten())
110132
width, height = self.width_item, self.height_item
111133
ax.set_xlim([0.5, width + 0.5])
112134
ax.set_ylim([0.6, height + 0.6])
@@ -116,13 +138,13 @@ def __draw_dotplot(self, ax, size_factor, cmap, vmin, vmax):
116138
ax.set_yticklabels(self.size_data.index.tolist())
117139
ax.tick_params(axis='y', length=5, labelsize=15, direction='out')
118140
ax.tick_params(axis='x', length=5, labelsize=15, direction='out')
119-
return sct
141+
return sct, sct_circle
120142

121143
@staticmethod
122144
def __draw_color_bar(ax, sct: mpl.collections.PathCollection, cmap, vmin, vmax):
123145
gradient = np.linspace(1, 0, 500)
124146
gradient = gradient[:, np.newaxis]
125-
im = ax.imshow(gradient, aspect='auto', cmap=cmap, origin='upper', extent=[.2, 0.3, 0.5, -0.5])
147+
_ = ax.imshow(gradient, aspect='auto', cmap=cmap, origin='upper', extent=[.2, 0.3, 0.5, -0.5])
126148
ax.set_xticks([])
127149
ax.set_yticks([])
128150
ax_cbar2 = ax.twinx()
@@ -135,16 +157,25 @@ def __draw_color_bar(ax, sct: mpl.collections.PathCollection, cmap, vmin, vmax):
135157
_ = ax_cbar2.set_ylabel('-log10(pvalue)')
136158

137159
@staticmethod
138-
def __draw_legend(ax, sct: mpl.collections.PathCollection, size_factor):
160+
def __draw_legend(ax, sct: mpl.collections.PathCollection, size_factor, title, circle=False, color=None):
161+
print(id(sct))
139162
handles, labels = sct.legend_elements(prop="sizes", alpha=1,
140163
func=lambda x: x / size_factor,
141-
color='#58000C')
164+
color=color
165+
)
142166
if len(handles) > 3:
143167
handles = np.asarray(handles)
144168
labels = np.asarray(labels)
145169
handles = handles[[0, math.ceil(len(handles) / 2), -1]]
146170
labels = labels[[0, math.ceil(len(labels) / 2), -1]]
147-
_ = ax.legend(handles, labels, title="Sizes", loc='center left') # bbox_to_anchor=(0.9, 0., 0.4, 0.4)
171+
if circle:
172+
from matplotlib.lines import Line2D
173+
for i, _item in enumerate(handles):
174+
xdata, ydata = _item.get_data()
175+
marker_size = _item.get_markersize()
176+
handles[i] = Line2D(xdata, ydata, color='white', marker='$\u25CC$',
177+
markeredgecolor=color, markersize=marker_size)
178+
_ = ax.legend(handles, labels, title=title, loc='center left') # bbox_to_anchor=(0.9, 0., 0.4, 0.4)
148179
ax.set_xticks([])
149180
ax.set_yticks([])
150181
ax.spines['top'].set_visible(False)
@@ -165,9 +196,11 @@ def plot(self, size_factor: float = 15,
165196
:param cmap: color map supported by matplotlib
166197
:return:
167198
"""
168-
ax, ax_cbar, ax_legend, fig = self.__get_figure()
169-
scatter = self.__draw_dotplot(ax, size_factor, cmap, vmin, vmax)
170-
self.__draw_legend(ax_legend, scatter, size_factor)
199+
ax, ax_cbar, ax_sizes, ax_circles, fig = self.__get_figure()
200+
scatter, sct_circle = self.__draw_dotplot(ax, size_factor, cmap, vmin, vmax)
201+
self.__draw_legend(ax_sizes, scatter, size_factor, title='Sizes', color='#58000C')
202+
if sct_circle is not None:
203+
self.__draw_legend(ax_circles, sct_circle, size_factor, title='Circles', circle=True, color='k')
171204
self.__draw_color_bar(ax_cbar, scatter, cmap, vmin, vmax)
172205
if path:
173206
fig.savefig(path, dpi=300, bbox_inches='tight') #

0 commit comments

Comments
 (0)