|
12 | 12 |
|
13 | 13 |
|
14 | 14 | def cluster_subclusters( |
15 | | - adata, |
| 15 | + adata: ad.AnnData, |
16 | 16 | cluster_column: str = 'leiden', |
| 17 | + cluster_name: str = None, |
| 18 | + layer: str = 'counts', |
17 | 19 | n_hvg: int = 2000, |
18 | 20 | n_pcs: int = 40, |
19 | 21 | n_neighbors: int = 10, |
20 | | - leiden_resolution: float = 0.25 |
21 | | -) -> sc.AnnData: |
| 22 | + leiden_resolution: float = 0.25, |
| 23 | + subcluster_col_name: str = 'subcluster' |
| 24 | +) -> None: |
22 | 25 | """ |
23 | | - For each original cluster in `adata.obs[cluster_column]`, recompute highly variable genes |
24 | | - (from the 'counts' layer, flavor='seurat_v3'), run PCA, build kNN, and re-cluster with Leiden. |
25 | | - Writes a new .obs column 'subcluster' whose labels are prefixed with the original cluster. Assumes that original counts are stored in layer['counts'] |
| 26 | + Subcluster a specified cluster (or all clusters) within an AnnData object by recomputing HVGs, PCA, |
| 27 | + kNN graph, and Leiden clustering. Updates the AnnData object in-place, adding or updating |
| 28 | + the `subcluster_col_name` column in `.obs` with new labels prefixed by the original cluster. |
26 | 29 | |
27 | | - Parameters |
28 | | - ---------- |
29 | | - adata |
30 | | - Input AnnData with a pre-existing clustering in `adata.obs[cluster_column]`. |
31 | | - cluster_column |
32 | | - `.obs` column name holding the original cluster assignment. |
33 | | - n_hvg |
34 | | - Number of highly variable genes per original cluster. |
35 | | - n_pcs |
36 | | - Number of PCs to compute. |
37 | | - n_neighbors |
38 | | - Number of neighbors for kNN graph. |
39 | | - leiden_resolution |
40 | | - Resolution parameter passed to `sc.tl.leiden`. |
| 30 | + Args: |
| 31 | + adata: AnnData |
| 32 | + The AnnData object containing precomputed clusters in `.obs[cluster_column]`. |
| 33 | + cluster_column: str, optional |
| 34 | + Name of the `.obs` column holding the original cluster assignments. Default is 'leiden'. |
| 35 | + cluster_name: str or None, optional |
| 36 | + Specific cluster label to subcluster. If `None`, applies to all clusters. Default is None. |
| 37 | + layer: str, optional |
| 38 | + Layer name in `adata.layers` to use for HVG detection. Default is 'counts'. |
| 39 | + n_hvg: int, optional |
| 40 | + Number of highly variable genes to select per cluster. Default is 2000. |
| 41 | + n_pcs: int, optional |
| 42 | + Number of principal components to compute. Default is 40. |
| 43 | + n_neighbors: int, optional |
| 44 | + Number of neighbors for the kNN graph. Default is 10. |
| 45 | + leiden_resolution: float, optional |
| 46 | + Resolution parameter for Leiden clustering. Default is 0.25. |
| 47 | + subcluster_col_name: str, optional |
| 48 | + Name of the `.obs` column to store subcluster labels. Default is 'subcluster'. |
41 | 49 | |
42 | | - Returns |
43 | | - ------- |
44 | | - None. |
45 | | - Populates .obs['subcluster'] |
| 50 | + Raises: |
| 51 | + ValueError: If `cluster_column` not in `adata.obs`. |
| 52 | + ValueError: If `layer` not in `adata.layers`. |
| 53 | + ValueError: If `cluster_name` is specified but not found in `adata.obs[cluster_column]`. |
46 | 54 | """ |
47 | | - # keep a copy of the original |
| 55 | + # Error checking |
| 56 | + if cluster_column not in adata.obs: |
| 57 | + raise ValueError(f"Cluster column '{cluster_column}' not found in adata.obs") |
| 58 | + if layer not in adata.layers: |
| 59 | + raise ValueError(f"Layer '{layer}' not found in adata.layers") |
| 60 | + |
| 61 | + # Convert original clusters to string |
48 | 62 | adata.obs['original_cluster'] = adata.obs[cluster_column].astype(str) |
49 | 63 |
|
50 | | - # prepare the column |
51 | | - adata.obs['subcluster'] = None |
| 64 | + # Ensure subcluster column exists |
| 65 | + adata.obs[subcluster_col_name] = "" |
52 | 66 |
|
53 | | - for orig in adata.obs['original_cluster'].unique(): |
| 67 | + # Validate cluster_name |
| 68 | + unique_clusters = adata.obs['original_cluster'].unique() |
| 69 | + if cluster_name is not None: |
| 70 | + if str(cluster_name) not in unique_clusters: |
| 71 | + raise ValueError( |
| 72 | + f"Cluster '{cluster_name}' not found in adata.obs['{cluster_column}']" |
| 73 | + ) |
| 74 | + clusters_to_process = [str(cluster_name)] |
| 75 | + else: |
| 76 | + clusters_to_process = unique_clusters |
| 77 | + |
| 78 | + # Iterate and subcluster |
| 79 | + for orig in clusters_to_process: |
54 | 80 | mask = adata.obs['original_cluster'] == orig |
55 | 81 | sub = adata[mask].copy() |
56 | 82 |
|
57 | | - # 1) HVG |
| 83 | + # 1) Compute HVGs |
58 | 84 | sc.pp.highly_variable_genes( |
59 | 85 | sub, |
60 | 86 | flavor='seurat_v3', |
61 | 87 | n_top_genes=n_hvg, |
62 | | - layer='counts' |
| 88 | + layer=layer |
63 | 89 | ) |
| 90 | + |
64 | 91 | # 2) PCA |
65 | 92 | sc.pp.pca(sub, n_comps=n_pcs, use_highly_variable=True) |
| 93 | + |
66 | 94 | # 3) kNN |
67 | 95 | sc.pp.neighbors(sub, n_neighbors=n_neighbors, use_rep='X_pca') |
| 96 | + |
68 | 97 | # 4) Leiden |
69 | 98 | sc.tl.leiden( |
70 | 99 | sub, |
71 | 100 | resolution=leiden_resolution, |
72 | 101 | flavor='igraph', |
73 | | - n_iterations=2 |
| 102 | + n_iterations=2, |
| 103 | + key_added='leiden_sub' |
74 | 104 | ) |
75 | 105 |
|
76 | | - # build labels like "2_0", "2_1", etc. |
77 | | - prefixed = orig + "_" + sub.obs['leiden'].astype(str) |
78 | | - adata.obs.loc[mask, 'subcluster'] = prefixed.values |
79 | | - |
80 | | - return adata |
81 | | - |
| 106 | + # Prefix and assign back |
| 107 | + labels = (orig + "_" + sub.obs['leiden_sub'].astype(str)).values |
| 108 | + adata.obs.loc[mask, subcluster_col_name] = labels |
82 | 109 |
|
83 | 110 |
|
84 | 111 |
|
|
0 commit comments