Skip to content

Commit 0ce7e8b

Browse files
committed
update rough implementation of dotplot
1 parent b47e011 commit 0ce7e8b

4 files changed

Lines changed: 330 additions & 0 deletions

File tree

.gitignore

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
# Byte-compiled / optimized / DLL files
2+
__pycache__/
3+
*.py[cod]
4+
*$py.class
5+
6+
# C extensions
7+
*.so
8+
9+
# Distribution / packaging
10+
.Python
11+
build/
12+
develop-eggs/
13+
dist/
14+
downloads/
15+
eggs/
16+
.eggs/
17+
lib/
18+
lib64/
19+
parts/
20+
sdist/
21+
var/
22+
wheels/
23+
pip-wheel-metadata/
24+
share/python-wheels/
25+
*.egg-info/
26+
.installed.cfg
27+
*.egg
28+
MANIFEST
29+
30+
# PyInstaller
31+
# Usually these files are written by a python script from a template
32+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
33+
*.manifest
34+
*.spec
35+
36+
# Installer logs
37+
pip-log.txt
38+
pip-delete-this-directory.txt
39+
40+
# Unit test / coverage reports
41+
htmlcov/
42+
.tox/
43+
.nox/
44+
.coverage
45+
.coverage.*
46+
.cache
47+
nosetests.xml
48+
coverage.xml
49+
*.cover
50+
*.py,cover
51+
.hypothesis/
52+
.pytest_cache/
53+
54+
# Translations
55+
*.mo
56+
*.pot
57+
58+
# Django stuff:
59+
*.log
60+
local_settings.py
61+
db.sqlite3
62+
db.sqlite3-journal
63+
64+
# Flask stuff:
65+
instance/
66+
.webassets-cache
67+
68+
# Scrapy stuff:
69+
.scrapy
70+
71+
# Sphinx documentation
72+
docs/_build/
73+
74+
# PyBuilder
75+
target/
76+
77+
# Jupyter Notebook
78+
.ipynb_checkpoints
79+
80+
# IPython
81+
profile_default/
82+
ipython_config.py
83+
84+
# pyenv
85+
.python-version
86+
87+
# pipenv
88+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
90+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
91+
# install all needed dependencies.
92+
#Pipfile.lock
93+
94+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
95+
__pypackages__/
96+
97+
# Celery stuff
98+
celerybeat-schedule
99+
celerybeat.pid
100+
101+
# SageMath parsed files
102+
*.sage.py
103+
104+
# Environments
105+
.env
106+
.venv
107+
env/
108+
venv/
109+
ENV/
110+
env.bak/
111+
venv.bak/
112+
113+
# Spyder project settings
114+
.spyderproject
115+
.spyproject
116+
117+
# Rope project settings
118+
.ropeproject
119+
120+
# mkdocs documentation
121+
/site
122+
123+
# mypy
124+
.mypy_cache/
125+
.dmypy.json
126+
dmypy.json
127+
128+
# Pyre type checker
129+
.pyre/
130+
131+
# Project-specific: job files, helper scripts, local data and IDE settings
132+
job*
133+
script*
134+
data/
135+
.idea/

dotplot/__init__.py

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
import math
2+
from os import PathLike
3+
from typing import Union, Sequence, Callable
4+
5+
import matplotlib as mpl
6+
import numpy as np
7+
import pandas as pd
8+
from matplotlib import gridspec
9+
from matplotlib import pyplot as plt
10+
11+
12+
class DotPlot(object):
    """
    Dot plot drawn from two equally shaped DataFrames.

    Every cell of the size DataFrame becomes one scatter marker whose size is the cell
    value multiplied by a size factor. When a color DataFrame is given, the matching cell
    provides the value that is mapped to the marker color. Columns are laid out along the
    x axis and rows along the y axis.
    """

    def __init__(self, df_size: pd.DataFrame,
                 df_color: Union[pd.DataFrame, None] = None,
                 ):
        """
        Construct a `DotPlot` object from `df_size` and `df_color`

        :param df_size: the DataFrame object represents the scatter size in dotplot
        :param df_color: the DataFrame object represents the color in dotplot, or None
            to draw every dot in a single color
        :raises ValueError: if `df_color` is given and its shape differs from `df_size`
        """
        # `and` (not `&`) so that the shape comparison is skipped when df_color is None.
        if df_color is not None and df_size.shape != df_color.shape:
            raise ValueError('df_size and df_color should have the same dimension')
        self.size_data = df_size
        self.color_data = df_color
        self.height, self.width = df_size.shape
        # Filled in when the scatter coordinates are computed for plotting.
        self.resized_size_data: Union[pd.DataFrame, None] = None

    @classmethod
    def parse_from_tidy_data(cls, data_frame: pd.DataFrame, item_key: str, group_key: str,
                             sizes_key: str, color_key: str, selected_item: Union[None, Sequence] = None, *,
                             sizes_func: Union[None, Callable] = None, color_func: Union[None, Callable] = None
                             ):
        """
        Class method for conveniently constructing DotPlot from tidy (long format) data

        :param data_frame: tidy DataFrame with one row per (item, group) pair
        :param item_key: column whose values become the rows of the plot
        :param group_key: column whose values become the columns of the plot
        :param sizes_key: column holding the values used for the dot sizes
        :param color_key: column holding the values used for the dot colors
        :param selected_item: default None, if specified, this should be subsets of `item_key` in `data_frame`
        :param sizes_func: optional function mapped over the `sizes_key` column before pivoting
        :param color_func: optional function mapped over the `color_key` column before pivoting
        :return: a `DotPlot` whose missing (item, group) combinations are filled with 0
        """
        # Work on a copy so the transformations below never touch the caller's frame.
        data_frame = data_frame[[item_key, group_key, sizes_key, color_key]].copy()
        if sizes_func is not None:
            data_frame[sizes_key] = data_frame[sizes_key].map(sizes_func)
        if color_func is not None:
            data_frame[color_key] = data_frame[color_key].map(color_func)
        if selected_item is not None:
            data_frame = data_frame[data_frame[item_key].isin(selected_item)]

        # Pivot each value column on its own: no string joining/splitting of column
        # labels is needed, so keys containing '_' and non-string group labels work.
        sizes_df = data_frame.pivot(index=item_key, columns=group_key, values=sizes_key).fillna(0)
        color_df = data_frame.pivot(index=item_key, columns=group_key, values=color_key).fillna(0)
        sizes_df.columns.name = None
        color_df.columns.name = None
        # The constructor expects the size data first and the color data second.
        return cls(sizes_df, color_df)

    def __determine_figsize(self, **kwargs):
        """
        Compute the figure size from the data shape.

        :param kwargs: may contain `width_factor` (default 4) and `height_factor` (default 0.45)
        :return: tuple of (figure width, figure height)
        """
        width_factor = kwargs.get('width_factor', 4)
        height_factor = kwargs.get('height_factor', 0.45)
        fig_width, fig_height = width_factor * self.width, height_factor * self.height
        fig_width = fig_width / 9 * 10
        return fig_width, fig_height

    def __get_figure_layout(self, **kwargs):
        """
        Create the figure and its three axes on a 2 x 10 grid.

        :param kwargs: forwarded to the figure size computation
        :return: tuple of (main axes, color bar axes, legend axes, figure)
        """
        fig_width, fig_height = self.__determine_figsize(**kwargs)
        # The seaborn style sheets were renamed in newer matplotlib releases; use
        # whichever spelling the installed version provides.
        for style in ('seaborn-v0_8-white', 'seaborn-white'):
            if style in plt.style.available:
                plt.style.use(style)
                break
        fig = plt.figure(figsize=(fig_width, fig_height))
        gs = gridspec.GridSpec(nrows=2, ncols=10, wspace=0.4, hspace=0.1)
        ax = fig.add_subplot(gs[:, :-4])
        ax_cbar = fig.add_subplot(gs[1, -4:-3])
        ax_legend = fig.add_subplot(gs[0, -4:])
        return ax, ax_cbar, ax_legend, fig

    def __get_coordinates(self, size_factor):
        """
        Build the scatter coordinates and store the scaled size data.

        The coordinates follow the row-major order of the flattened data: x cycles
        through 1..width for every row and y repeats each row number `width` times.

        :param size_factor: multiplier applied to every value of the size data
        :return: tuple of (x coordinates, y coordinates) as lists
        """
        X = list(range(1, self.width + 1)) * self.height
        Y = [row for row in range(1, self.height + 1) for _ in range(self.width)]
        self.resized_size_data = self.size_data * size_factor
        return X, Y

    def __draw_dotplot(self, ax, size_factor, cmap, vmin, vmax):
        """
        Draw the dots on `ax` and label the ticks with the column and row names.

        :param ax: axes to draw on
        :param size_factor: multiplier applied to the size data
        :param cmap: color map used when color data is available
        :param vmin: lower bound of the color scale
        :param vmax: upper bound of the color scale
        :return: the collection returned by `ax.scatter`
        """
        X, Y = self.__get_coordinates(size_factor)
        sizes = self.resized_size_data.values.flatten()
        if self.color_data is None:
            # A fixed color carries no data to map, so cmap/vmin/vmax are not passed.
            sct = ax.scatter(X, Y, c='r', s=sizes, edgecolors='none', linewidths=0)
        else:
            sct = ax.scatter(X, Y, c=self.color_data.values.flatten(), s=sizes,
                             edgecolors='none', linewidths=0, vmin=vmin, vmax=vmax, cmap=cmap)
        width, height = self.width, self.height
        ax.set_xlim([0.5, width + 0.5])
        ax.set_ylim([0.6, height + 0.6])
        ax.set_xticks(range(1, width + 1))
        ax.set_yticks(range(1, height + 1))
        ax.set_xticklabels(self.size_data.columns.tolist(), rotation='vertical')
        ax.set_yticklabels(self.size_data.index.tolist())
        ax.tick_params(axis='y', length=5, labelsize=15, direction='out')
        ax.tick_params(axis='x', length=5, labelsize=15, direction='out')
        return sct

    @staticmethod
    def __draw_color_bar(ax, sct: "mpl.collections.PathCollection", cmap, vmin, vmax):
        """
        Draw a vertical gradient of `cmap` on `ax` and label its two ends.

        :param ax: axes that receives the gradient image
        :param sct: scatter collection whose color values supply missing bounds
        :param cmap: color map shown by the gradient
        :param vmin: label of the lower end; None uses the floor of the smallest color value
        :param vmax: label of the upper end; None uses the ceiling of the largest color value
        :raises ValueError: if a bound is None and `sct` carries no color values
        """
        gradient = np.linspace(1, 0, 500)
        gradient = gradient[:, np.newaxis]
        ax.imshow(gradient, aspect='auto', cmap=cmap, origin='upper', extent=[.2, 0.3, 0.5, -0.5])
        ax.set_xticks([])
        ax.set_yticks([])
        ax_cbar2 = ax.twinx()
        ax_cbar2.set_yticks([0, 1000])
        if vmin is None or vmax is None:
            values = sct.get_array()
            if values is None:
                raise ValueError('vmin and vmax are required when the scatter has no color values')
            if vmax is None:
                vmax = math.ceil(values.max())
            if vmin is None:
                vmin = math.floor(values.min())
        ax_cbar2.set_yticklabels([vmin, vmax])
        ax_cbar2.set_ylabel('-log10(pvalue)')

    @staticmethod
    def __draw_legend(ax, sct: "mpl.collections.PathCollection", size_factor):
        """
        Draw the size legend on `ax` and hide the axes decorations.

        When more than three legend entries are produced, only the first, the middle
        and the last one are kept.

        :param ax: axes that receives the legend
        :param sct: scatter collection the legend entries are derived from
        :param size_factor: factor used to convert marker sizes back to data values
        """
        handles, labels = sct.legend_elements(prop="sizes", alpha=1,
                                              func=lambda x: x / size_factor,
                                              color='#58000C')
        if len(handles) > 3:
            picks = (0, math.ceil(len(handles) / 2), -1)
            handles = [handles[i] for i in picks]
            labels = [labels[i] for i in picks]
        ax.legend(handles, labels, title="Sizes", loc='center left')
        ax.set_xticks([])
        ax.set_yticks([])
        for side in ('top', 'bottom', 'left', 'right'):
            ax.spines[side].set_visible(False)

    def plot(self, size_factor: float = 15,
             vmin: float = 0, vmax: Union[float, None] = None,
             path: Union[PathLike, None] = None,
             cmap: "Union[str, mpl.colors.Colormap]" = 'Reds',
             **kwargs):
        """
        Draw the dot plot, its size legend and (when color data exists) its color bar.

        :param size_factor: `size factor` * `value` for the actually representation of scatter size in the final figure
        :param vmin: `vmin` in `matplotlib.pyplot.scatter`
        :param vmax: `vmax` in `matplotlib.pyplot.scatter`
        :param path: path to save the figure
        :param cmap: color map supported by matplotlib
        :param kwargs: may contain `width_factor` and `height_factor` to scale the figure size
        :return: the scatter collection of the main axes
        """
        ax, ax_cbar, ax_legend, fig = self.__get_figure_layout(**kwargs)
        scatter = self.__draw_dotplot(ax, size_factor, cmap, vmin, vmax)
        self.__draw_legend(ax_legend, scatter, size_factor)
        if self.color_data is None:
            # Single-color dots have no color scale to explain.
            ax_cbar.set_axis_off()
        else:
            self.__draw_color_bar(ax_cbar, scatter, cmap, vmin, vmax)
        fig.subplots_adjust(left=0.75)
        if path:
            fig.savefig(path, dpi=300)
        return scatter

    def __str__(self):
        return 'DotPlot object with data point in shape %s' % str(self.size_data.shape)

    __repr__ = __str__

dotplot/cmap.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
import matplotlib as mpl
2+
3+
4+
def get_colormap(color_list: list, segment=1000):
    """
    Build a linearly interpolated colormap named 'color_list' from a list of colors.

    :param color_list: colors the colormap interpolates between
    :param segment: passed as `N` to `LinearSegmentedColormap.from_list`
    :return: the resulting `LinearSegmentedColormap`
    """
    build = mpl.colors.LinearSegmentedColormap.from_list
    colormap = build('color_list', color_list, N=segment)
    return colormap

dotplot/utils.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import numpy as np
2+
import pandas as pd
3+
4+
5+
def parse_from_clusterprofile(cls, dataframes, groups, term_list=None):
    """
    Merge several enrichment result tables into one tidy DataFrame.

    Each DataFrame is tagged with its group label in a new `group` column, optionally
    restricted to the rows whose index is in `term_list`, and all of them are
    concatenated. The input DataFrames are left unmodified.

    :param cls: unused; kept so existing callers keep working
    :param dataframes: DataFrames providing the `Description`, `pvalue` and `GeneRatio`
        columns, with `GeneRatio` as strings of the form 'numerator/denominator'
    :param groups: one group label per DataFrame, in the same order
    :param term_list: default None, if specified, only rows whose index is in this list are kept
    :return: DataFrame with the columns `Description`, `pvalue` (converted to
        -log10(pvalue)), `GeneRatio` (the integer numerator of the ratio) and `group`;
        empty when no rows are selected
    :raises ValueError: if `dataframes` and `groups` differ in length
    """
    if len(dataframes) != len(groups):
        raise ValueError('dataframes and groups should have the same length')

    columns = ['Description', 'pvalue', 'GeneRatio', 'group']
    tagged = []
    for frame, group in zip(dataframes, groups):
        if term_list is not None:
            frame = frame[frame.index.isin(term_list)]
        if frame.empty:
            continue
        # assign() returns a new frame, so the caller's data is never modified.
        tagged.append(frame.assign(group=group))

    if not tagged:
        return pd.DataFrame(columns=columns)

    merged_df = pd.concat(tagged)[columns].copy()
    merged_df['GeneRatio'] = merged_df['GeneRatio'].map(lambda ratio: int(ratio.split('/')[0]))
    merged_df['pvalue'] = -np.log10(merged_df['pvalue'].astype(float))
    return merged_df

0 commit comments

Comments
 (0)