|
19 | 19 | from palettable.scientific.sequential import Batlow_20 |
20 | 20 | from anndata import AnnData |
21 | 21 | from scipy.sparse import csr_matrix |
22 | | -from sklearn.metrics import f1_score |
| 22 | +from scipy import sparse |
| 23 | +from typing import Optional, Callable |
23 | 24 | from ..utils import * |
24 | 25 |
|
| 26 | +from scipy import sparse |
| 27 | +from typing import Optional, Callable |
| 28 | + |
def spatial_two_genes(
    adata: AnnData,
    gene1: str,
    gene2: str,
    title: Optional[str] = None,
    scale_max_value: float = 2.0,
    spot_size: float = 35,
    cmap: str = 'RdBu_r',
    alpha: float = 0.5,
    copy_adata: bool = True,
    combine_fun: Optional[Callable[[np.ndarray, np.ndarray], np.ndarray]] = None,
    **plot_kwargs
) -> None:
    """Plot a combined score of two genes on a spatial scatter.

    The expression of ``gene1`` and ``gene2`` is read from ``adata.X``, each
    vector is standardized on its own (zero mean, unit variance, population
    standard deviation) and clipped to ``±scale_max_value``. The two scaled
    vectors are then combined into one value per observation, stored in
    ``.obs[title]`` and drawn with ``sc.pl.spatial``.

    Args:
        adata: AnnData whose ``.X`` holds the expression values (dense or
            SciPy sparse). Only ``.X`` is read; layers are not consulted.
        gene1: Name of the first gene (must be in ``adata.var_names``).
        gene2: Name of the second gene (must be in ``adata.var_names``).
        title: Key under which the combined metric is stored in ``.obs``;
            it is also passed as ``color`` to ``sc.pl.spatial``. Defaults to
            ``f"{gene1}_{gene2}"``.
        scale_max_value: Absolute bound used to clip the scaled vectors.
        spot_size: Forwarded to ``sc.pl.spatial`` as ``spot_size``.
        cmap: Forwarded to ``sc.pl.spatial`` as ``cmap``.
        alpha: Forwarded to ``sc.pl.spatial`` as ``alpha``.
        copy_adata: If True, work on ``adata.copy()`` so the caller's object
            is untouched. If False, ``adata.obs[title]`` is written (and
            overwritten if it already exists).
        combine_fun: Callable ``(g1_scaled, g2_scaled) -> combined``. If
            None, ``(g1 * g2) + g1 - g2`` is used.
        **plot_kwargs: Extra keyword arguments forwarded to
            ``sc.pl.spatial``.

    Returns:
        None. The plot is produced by ``sc.pl.spatial``.

    Raises:
        KeyError: If ``gene1`` or ``gene2`` is not in ``adata.var_names``.
    """
    # Determine the obs key / colour key.
    if title is None:
        title = f"{gene1}_{gene2}"

    # Fail early, before the (potentially expensive) copy, with a message that
    # names the offending gene(s). KeyError matches what AnnData indexing
    # raises for unknown variable names.
    missing = [gene for gene in (gene1, gene2) if gene not in adata.var_names]
    if missing:
        raise KeyError(
            f"Gene(s) not found in adata.var_names: {', '.join(missing)}"
        )

    # Copy or operate in place.
    ad = adata.copy() if copy_adata else adata

    def to_array(mat):
        """Return ``mat`` (dense or SciPy sparse) as a 1-D NumPy array."""
        if sparse.issparse(mat):
            # `.toarray()` instead of the `.A` shorthand: `.A` is deprecated
            # and absent from recent SciPy releases, `.toarray()` is stable.
            mat = mat.toarray()
        return np.asarray(mat).ravel()

    def scale_vec(x):
        """Standardize ``x`` and clip it to ``±scale_max_value``."""
        centered = x - x.mean()
        spread = x.std(ddof=0)
        # A constant vector has zero spread; only center it to avoid a
        # division by zero (the result is then all zeros).
        scaled = centered if spread == 0 else centered / spread
        return np.clip(scaled, -scale_max_value, scale_max_value)

    # Extract and scale each gene individually.
    g1_scaled = scale_vec(to_array(ad[:, gene1].X))
    g2_scaled = scale_vec(to_array(ad[:, gene2].X))

    # Combine the two scaled vectors into one per-observation metric.
    if combine_fun is None:
        expr = (g1_scaled * g2_scaled) + g1_scaled - g2_scaled
    else:
        expr = combine_fun(g1_scaled, g2_scaled)

    # Store and plot.
    ad.obs[title] = pd.Series(expr, index=ad.obs.index)
    sc.pl.spatial(
        ad,
        color=title,
        spot_size=spot_size,
        cmap=cmap,
        alpha=alpha,
        **plot_kwargs
    )
| 121 | + |
| 122 | + |
| 123 | + |
| 124 | + |
25 | 125 | def umi_counts_ranked(adata, total_counts_column="total_counts"): |
26 | 126 | """ |
27 | 127 | Identifies and plots the knee point of the UMI count distribution in an AnnData object. |
|
0 commit comments