Skip to content

Commit 8bb8df8

Browse files
committed
spatial stuff
1 parent 008740f commit 8bb8df8

8 files changed

Lines changed: 1353 additions & 159 deletions

File tree

src/pySingleCellNet/plotting/__init__.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
from .helpers import (
2+
make_bivariate_cmap
3+
)
4+
15
from .bar import (
26
bar_compare_celltype_composition,
37
stackedbar_composition,
@@ -7,7 +11,12 @@
711
bar_classifier_f1,
812
)
913

10-
from .dot import (
14+
from .spatial import (
15+
spatial_contours,
16+
spatial_two_genes
17+
)
18+
19+
from .dot import (
1120
umi_counts_ranked,
1221
ontogeny_graph,
1322
dotplot_deg,
@@ -29,6 +38,10 @@
2938

3039
# API
3140
__all__ = [
41+
"scatter_genes_oneper",
42+
"spatial_contours",
43+
"make_bivariate_cmap",
44+
"spatial_two_genes",
3245
"bar_compare_celltype_composition",
3346
"stackedbar_composition",
3447
"stackedbar_composition_list",

src/pySingleCellNet/plotting/dot.py

Lines changed: 7 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@
88
import umap
99
import anndata as ad
1010
import igraph as ig
11-
from palettable.colorbrewer.qualitative import Set2_6
12-
from palettable.tableau import GreenOrange_6
13-
from palettable.cartocolors.qualitative import Safe_6
14-
from palettable.cartocolors.qualitative import Vivid_4
15-
from palettable.cartocolors.qualitative import Vivid_6
16-
from palettable.cartocolors.qualitative import Vivid_10
17-
from palettable.scientific.diverging import Roma_20
11+
# from palettable.colorbrewer.qualitative import Set2_6
12+
# from palettable.tableau import GreenOrange_6
13+
# from palettable.cartocolors.qualitative import Safe_6
14+
# from palettable.cartocolors.qualitative import Vivid_4
15+
# from palettable.cartocolors.qualitative import Vivid_6
16+
# from palettable.cartocolors.qualitative import Vivid_10
17+
# from palettable.scientific.diverging import Roma_20
1818
from palettable.scientific.sequential import LaJolla_20
1919
from palettable.scientific.sequential import Batlow_20
2020
from anndata import AnnData
@@ -23,103 +23,6 @@
2323
from typing import Optional, Callable
2424
from ..utils import *
2525

26-
from scipy import sparse
27-
from typing import Optional, Callable
28-
29-
def spatial_two_genes(
30-
adata: AnnData,
31-
gene1: str,
32-
gene2: str,
33-
title: Optional[str] = None,
34-
scale_max_value: float = 2.0,
35-
spot_size: float = 35,
36-
cmap: str = 'RdBu_r',
37-
alpha: float = 0.5,
38-
copy_adata: bool = True,
39-
combine_fun: Optional[Callable[[np.ndarray, np.ndarray], np.ndarray]] = None,
40-
**plot_kwargs
41-
) -> None:
42-
"""Plot a custom combination of two gene expressions on a spatial scatter,
43-
scaling only those two genes.
44-
45-
This will (optionally) make a copy of your AnnData, extract expression for
46-
gene1 and gene2, scale each gene vector (zero‐center, unit variance,
47-
clipped to ±scale_max_value), compute a per‐cell combined metric, store it
48-
in `.obs[title]`, and then call `sc.pl.spatial`.
49-
50-
Args:
51-
adata: AnnData with spatial coords and expression in `.X` (or `.layers`).
52-
gene1: Name of the first gene (must be in `adata.var_names`).
53-
gene2: Name of the second gene.
54-
title: Key under which to store the combined metric in `adata.obs` and
55-
plot title. Defaults to `"gene1_gene2"`.
56-
scale_max_value: Maximum absolute value to clip the scaled gene vectors.
57-
Defaults to 2.0.
58-
spot_size: Passed to `sc.pl.spatial(..., spot_size=...)`. Default 35.
59-
cmap: Colormap for `sc.pl.spatial`. Default `'RdBu_r'`.
60-
alpha: Spot transparency for `sc.pl.spatial`. Default 0.5.
61-
copy_adata: If True, operate on a copy. Otherwise overwrite `adata.obs`.
62-
combine_fun: Function `(g1, g2) -> combined`. If None, uses
63-
`(g1 * g2) + g1 - g2`.
64-
**plot_kwargs: Any additional args forwarded to `sc.pl.spatial`.
65-
66-
Returns:
67-
None. Displays a spatial scatter of the combined metric.
68-
"""
69-
# determine obs key / title
70-
if title is None:
71-
title = f"{gene1}_{gene2}"
72-
73-
# copy or in-place
74-
ad = adata.copy() if copy_adata else adata
75-
76-
# extract raw expression
77-
X1 = ad[:, gene1].X
78-
X2 = ad[:, gene2].X
79-
80-
# to 1D numpy arrays
81-
def to_array(mat):
82-
if sparse.issparse(mat):
83-
arr = mat.A.flatten()
84-
else:
85-
arr = np.asarray(mat).flatten()
86-
return arr
87-
88-
g1 = to_array(X1)
89-
g2 = to_array(X2)
90-
91-
# scale each gene individually
92-
def scale_vec(x):
93-
m = x.mean()
94-
s = x.std(ddof=0)
95-
if s == 0:
96-
# avoid divide by zero
97-
scaled = x - m
98-
else:
99-
scaled = (x - m) / s
100-
return np.clip(scaled, -scale_max_value, scale_max_value)
101-
102-
g1_scaled = scale_vec(g1)
103-
g2_scaled = scale_vec(g2)
104-
105-
# combine
106-
if combine_fun is None:
107-
expr = (g1_scaled * g2_scaled) + g1_scaled - g2_scaled
108-
else:
109-
expr = combine_fun(g1_scaled, g2_scaled)
110-
111-
# store and plot
112-
ad.obs[title] = pd.Series(expr, index=ad.obs.index)
113-
sc.pl.spatial(
114-
ad,
115-
color=title,
116-
spot_size=spot_size,
117-
cmap=cmap,
118-
alpha=alpha,
119-
**plot_kwargs
120-
)
121-
122-
12326

12427

12528
def umi_counts_ranked(adata, total_counts_column="total_counts"):
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
import numpy as np
2+
from matplotlib.colors import ListedColormap, to_rgb
3+
import matplotlib.tri as mtri
4+
from scipy.interpolate import griddata
5+
from scipy.ndimage import gaussian_filter
6+
from matplotlib.colors import to_hex
7+
import matplotlib.pyplot as plt
8+
9+
def _smooth_contour(
10+
x: np.ndarray,
11+
y: np.ndarray,
12+
z: np.ndarray,
13+
levels: int = 6,
14+
grid_res: int = 200,
15+
smooth_sigma: float = 2,
16+
contour_kwargs: dict = None
17+
):
18+
"""Overlay smooth contour lines by gridding + Gaussian blur.
19+
20+
Args:
21+
x, y: 1D arrays of spatial coordinates (length n_obs).
22+
z: 1D array of normalized or summarized expression (length n_obs).
23+
levels: Number of contour levels or list of levels.
24+
grid_res: Resolution of the regular grid along each axis.
25+
smooth_sigma: Sigma for Gaussian filter to smooth the gridded field.
26+
contour_kwargs: Extra kwargs passed to plt.contour (e.g. colors, linewidths).
27+
28+
Returns:
29+
The contour set drawn on the current axes.
30+
"""
31+
# 1) create regular grid
32+
xi = np.linspace(x.min(), x.max(), grid_res)
33+
yi = np.linspace(y.min(), y.max(), grid_res)
34+
Xi, Yi = np.meshgrid(xi, yi)
35+
36+
# 2) interpolate scattered z onto the grid
37+
Zi = griddata((x, y), z, (Xi, Yi), method='cubic', fill_value=np.nan)
38+
39+
# 3) smooth the gridded values
40+
Zi_s = gaussian_filter(Zi, sigma=smooth_sigma, mode='nearest')
41+
42+
# 4) draw contours
43+
ctr_kw = {} if contour_kwargs is None else contour_kwargs
44+
cs = plt.contour(Xi, Yi, Zi_s, levels=levels, **ctr_kw)
45+
plt.clabel(cs, inline=True, fontsize=8)
46+
return cs
47+
48+
49+
def make_bivariate_cmap(
50+
c00: str = "#f0f0f0",
51+
c10: str = "#e31a1c",
52+
c01: str = "#1f78b4",
53+
c11: str = "#ffff00",
54+
n: int = 128
55+
) -> ListedColormap:
56+
"""Create a bivariate colormap by bilinear‐interpolating four corner colors.
57+
58+
This builds an (n × n) grid of RGB colors, blending smoothly between
59+
the specified corner colors:
60+
- c00 at (low, low)
61+
- c10 at (high, low)
62+
- c01 at (low, high)
63+
- c11 at (high, high)
64+
65+
Args:
66+
c00: Matplotlib color spec (hex, name, or RGB tuple) for the low/low corner.
67+
c10: Color for the high/low corner.
68+
c01: Color for the low/high corner.
69+
c11: Color for the high/high corner.
70+
n: Resolution per axis. The total length of the returned colormap is n*n.
71+
72+
Returns:
73+
ListedColormap: A colormap with n*n entries blending between the four corners.
74+
"""
75+
# Convert corner colors to RGB arrays
76+
corners = {
77+
(0, 0): np.array(to_rgb(c00)),
78+
(1, 0): np.array(to_rgb(c10)),
79+
(0, 1): np.array(to_rgb(c01)),
80+
(1, 1): np.array(to_rgb(c11)),
81+
}
82+
83+
# Build an (n, n, 3) grid by bilinear interpolation
84+
lut = np.zeros((n, n, 3), dtype=float)
85+
xs = np.linspace(0, 1, n)
86+
ys = np.linspace(0, 1, n)
87+
for j, y in enumerate(ys):
88+
for i, x in enumerate(xs):
89+
lut[j, i] = (
90+
corners[(0, 0)] * (1 - x) * (1 - y) +
91+
corners[(1, 0)] * x * (1 - y) +
92+
corners[(0, 1)] * (1 - x) * y +
93+
corners[(1, 1)] * x * y
94+
)
95+
96+
# Flatten to (n*n, 3) and return as a ListedColormap
97+
return ListedColormap(lut.reshape(n * n, 3))

0 commit comments

Comments
 (0)