1414import ome_types
1515import packaging .version
1616import pandas as pd
17+ import pyarrow .compute as pc
1718import pyarrow .parquet as pq
1819import tifffile
1920import zarr
2021from dask .dataframe import read_parquet
2122from dask_image .imread import imread
2223from geopandas import GeoDataFrame
23- from pyarrow import Table
2424from shapely import GeometryType , Polygon , from_ragged_array
2525from spatialdata import SpatialData
2626from spatialdata ._core .query .relational_query import get_element_instances
4444if TYPE_CHECKING :
4545 from collections .abc import Mapping
4646
47+ import pyarrow as pa
4748 from anndata import AnnData
4849 from spatialdata ._types import ArrayLike
4950
@@ -210,7 +211,11 @@ def xenium(
210211 if version is not None and version >= packaging .version .parse ("2.0.0" ) and table is not None :
211212 assert cells_zarr is not None
212213 cell_summary_table = _get_cells_metadata_table_from_zarr (cells_zarr , specs , cells_zarr_cell_id_str )
213- if not np .array_equal (cell_summary_table [XeniumKeys .CELL_ID ].values , table .obs [XeniumKeys .CELL_ID ].values ):
214+ try :
215+ _assert_arrays_equal_sampled (
216+ cell_summary_table [XeniumKeys .CELL_ID ].values , table .obs [XeniumKeys .CELL_ID ].values
217+ )
218+ except AssertionError :
214219 warnings .warn (
215220 'The "cell_id" column in the cells metadata table does not match the "cell_id" column in the annotation'
216221 " table. This could be due to trying to read a new version that is not supported yet. Please "
@@ -254,9 +259,11 @@ def xenium(
254259 cell_id_str = cells_zarr_cell_id_str ,
255260 )
256261 if cell_labels_indices_mapping is not None and table is not None :
257- if not np .array_equal (
258- cell_labels_indices_mapping ["cell_id" ].values , table .obs [str (XeniumKeys .CELL_ID )].values
259- ):
262+ try :
263+ _assert_arrays_equal_sampled (
264+ cell_labels_indices_mapping ["cell_id" ].values , table .obs [str (XeniumKeys .CELL_ID )].values
265+ )
266+ except AssertionError :
260267 warnings .warn (
261268 "The cell_id column in the cell_labels_table does not match the cell_id column derived from the "
262269 "cell labels data. This could be due to trying to read a new version that is not supported yet. "
@@ -274,7 +281,7 @@ def xenium(
274281 path ,
275282 XeniumKeys .NUCLEUS_BOUNDARIES_FILE ,
276283 specs ,
277- idx = table . obs [ str ( XeniumKeys . CELL_ID )]. copy () ,
284+ idx = None ,
278285 )
279286
280287 if cells_boundaries :
@@ -415,6 +422,13 @@ def filter(self, record: logging.LogRecord) -> bool:
415422 return _set_reader_metadata (sdata , "xenium" )
416423
417424
425+ def _assert_arrays_equal_sampled (a : ArrayLike , b : ArrayLike , n : int = 100 ) -> None :
426+ """Assert two arrays are equal by checking a random sample of entries."""
427+ assert len (a ) == len (b ), f"Array lengths differ: { len (a )} != { len (b )} "
428+ idx = np .random .default_rng (0 ).choice (len (a ), size = min (n , len (a )), replace = False )
429+ np .testing .assert_array_equal (np .asarray (a [idx ]), np .asarray (b [idx ]))
430+
431+
418432def _decode_cell_id_column (cell_id_column : pd .Series ) -> pd .Series :
419433 if isinstance (cell_id_column .iloc [0 ], bytes ):
420434 return cell_id_column .str .decode ("utf-8" )
@@ -429,28 +443,35 @@ def _get_polygons(
429443 specs : dict [str , Any ],
430444 idx : pd .Series | None = None ,
431445) -> GeoDataFrame :
432- # seems to be faster than pd.read_parquet
433- df = pq .read_table (path / file ).to_pandas ()
434- cell_ids = df [XeniumKeys .CELL_ID ].to_numpy ()
435- x = df [XeniumKeys .BOUNDARIES_VERTEX_X ].to_numpy ()
436- y = df [XeniumKeys .BOUNDARIES_VERTEX_Y ].to_numpy ()
446+ # Use PyArrow compute to avoid slow .to_numpy() on Arrow-backed strings in pandas >= 3.0
447+ # The original approach was:
448+ # df = pq.read_table(path / file).to_pandas()
449+ # cell_ids = df[XeniumKeys.CELL_ID].to_numpy()
450+ # which got slow with pandas >= 3.0 (Arrow-backed string .to_numpy() is ~100x slower).
451+ # By doing change detection in Arrow, we avoid allocating Python string objects for all rows.
452+ table = pq .read_table (path / file )
453+ cell_id_col = table .column (str (XeniumKeys .CELL_ID ))
454+
455+ x = table .column (str (XeniumKeys .BOUNDARIES_VERTEX_X )).to_numpy ()
456+ y = table .column (str (XeniumKeys .BOUNDARIES_VERTEX_Y )).to_numpy ()
437457 coords = np .column_stack ([x , y ])
438458
439- change_mask = np .concatenate ([[True ], cell_ids [1 :] != cell_ids [:- 1 ]])
459+ n = len (cell_id_col )
460+ change_mask = np .empty (n , dtype = bool )
461+ change_mask [0 ] = True
462+ change_mask [1 :] = pc .not_equal (cell_id_col .slice (0 , n - 1 ), cell_id_col .slice (1 )).to_numpy (zero_copy_only = False )
440463 group_starts = np .where (change_mask )[0 ]
441- group_ends = np .concatenate ([group_starts [1 :], [len ( cell_ids ) ]])
464+ group_ends = np .concatenate ([group_starts [1 :], [n ]])
442465
443466 # sanity check
444- n_unique_ids = len ( df [ XeniumKeys . CELL_ID ]. drop_duplicates () )
467+ n_unique_ids = pc . count_distinct ( cell_id_col ). as_py ( )
445468 if len (group_starts ) != n_unique_ids :
446469 raise ValueError (
447470 f"In { file } , rows belonging to the same polygon must be contiguous. "
448471 f"Expected { n_unique_ids } group starts, but found { len (group_starts )} . "
449472 f"This indicates non-consecutive polygon rows."
450473 )
451474
452- unique_ids = cell_ids [group_starts ]
453-
454475 # offsets for ragged array:
455476 # offsets[0] (ring_offsets): describing to which rings the vertex positions belong to
456477 # offsets[1] (geom_offsets): describing to which polygons the rings belong to
@@ -459,22 +480,16 @@ def _get_polygons(
459480
460481 geoms = from_ragged_array (GeometryType .POLYGON , coords , offsets = (ring_offsets , geom_offsets ))
461482
462- index = _decode_cell_id_column (pd .Series (unique_ids ))
463- geo_df = GeoDataFrame ({"geometry" : geoms }, index = index .values )
464-
465- version = _parse_version_of_xenium_analyzer (specs )
466- if version is not None and version < packaging .version .parse ("2.0.0" ):
467- assert idx is not None
468- assert len (idx ) == len (geo_df )
469- assert np .array_equal (index .values , idx .values )
483+ # idx is not None for the cells and None for the nuclei (with xenium(cells_table=False), idx is None for both)
484+ if idx is not None :
485+ # Cell IDs already available from the annotation table
486+ assert len (idx ) == len (group_starts ), f"Expected { len (group_starts )} cell IDs, got { len (idx )} "
487+ geo_df = GeoDataFrame ({"geometry" : geoms }, index = idx .values )
470488 else :
471- if np .unique (geo_df .index ).size != len (geo_df ):
472- warnings .warn (
473- "Found non-unique polygon indices, this will be addressed in a future version of the reader. For the "
474- "time being please consider merging polygons with non-unique indices into single multi-polygons." ,
475- UserWarning ,
476- stacklevel = 2 ,
477- )
489+ # Fall back to extracting unique cell IDs from parquet (slow for large_string columns).
490+ unique_ids = cell_id_col .filter (change_mask ).to_pylist ()
491+ index = _decode_cell_id_column (pd .Series (unique_ids ))
492+ geo_df = GeoDataFrame ({"geometry" : geoms }, index = index .values )
478493
479494 scale = Scale ([1.0 / specs ["pixel_size" ], 1.0 / specs ["pixel_size" ]], axes = ("x" , "y" ))
480495 return ShapesModel .parse (geo_df , transformations = {"global" : scale })
@@ -559,7 +574,7 @@ def _get_cells_metadata_table_from_zarr(
559574 return df
560575
561576
562- def _get_points (path : Path , specs : dict [str , Any ]) -> Table :
577+ def _get_points (path : Path , specs : dict [str , Any ]) -> pa . Table :
563578 table = read_parquet (path / XeniumKeys .TRANSCRIPTS_FILE )
564579
565580 # check if we need to decode bytes
@@ -601,7 +616,7 @@ def _get_tables_and_circles(
601616) -> AnnData | tuple [AnnData , AnnData ]:
602617 adata = _read_10x_h5 (path / XeniumKeys .CELL_FEATURE_MATRIX_FILE , gex_only = gex_only )
603618 metadata = pd .read_parquet (path / XeniumKeys .CELL_METADATA_FILE )
604- np . testing . assert_array_equal (metadata .cell_id .astype (str ), adata .obs_names .values )
619+ _assert_arrays_equal_sampled (metadata .cell_id .astype (str ), adata .obs_names .values )
605620 circ = metadata [[XeniumKeys .CELL_X , XeniumKeys .CELL_Y ]].to_numpy ()
606621 adata .obsm ["spatial" ] = circ
607622 metadata .drop ([XeniumKeys .CELL_X , XeniumKeys .CELL_Y ], axis = 1 , inplace = True )
0 commit comments