Skip to content

Commit 473b2ae

Browse files
committed
improve: `_get_polygons` is slow on pandas 3 due to string conversion
1 parent 494d801 commit 473b2ae

1 file changed

Lines changed: 48 additions & 33 deletions

File tree

src/spatialdata_io/readers/xenium.py

Lines changed: 48 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,13 @@
1414
import ome_types
1515
import packaging.version
1616
import pandas as pd
17+
import pyarrow.compute as pc
1718
import pyarrow.parquet as pq
1819
import tifffile
1920
import zarr
2021
from dask.dataframe import read_parquet
2122
from dask_image.imread import imread
2223
from geopandas import GeoDataFrame
23-
from pyarrow import Table
2424
from shapely import GeometryType, Polygon, from_ragged_array
2525
from spatialdata import SpatialData
2626
from spatialdata._core.query.relational_query import get_element_instances
@@ -44,6 +44,7 @@
4444
if TYPE_CHECKING:
4545
from collections.abc import Mapping
4646

47+
import pyarrow as pa
4748
from anndata import AnnData
4849
from spatialdata._types import ArrayLike
4950

@@ -210,7 +211,11 @@ def xenium(
210211
if version is not None and version >= packaging.version.parse("2.0.0") and table is not None:
211212
assert cells_zarr is not None
212213
cell_summary_table = _get_cells_metadata_table_from_zarr(cells_zarr, specs, cells_zarr_cell_id_str)
213-
if not np.array_equal(cell_summary_table[XeniumKeys.CELL_ID].values, table.obs[XeniumKeys.CELL_ID].values):
214+
try:
215+
_assert_arrays_equal_sampled(
216+
cell_summary_table[XeniumKeys.CELL_ID].values, table.obs[XeniumKeys.CELL_ID].values
217+
)
218+
except AssertionError:
214219
warnings.warn(
215220
'The "cell_id" column in the cells metadata table does not match the "cell_id" column in the annotation'
216221
" table. This could be due to trying to read a new version that is not supported yet. Please "
@@ -254,9 +259,11 @@ def xenium(
254259
cell_id_str=cells_zarr_cell_id_str,
255260
)
256261
if cell_labels_indices_mapping is not None and table is not None:
257-
if not np.array_equal(
258-
cell_labels_indices_mapping["cell_id"].values, table.obs[str(XeniumKeys.CELL_ID)].values
259-
):
262+
try:
263+
_assert_arrays_equal_sampled(
264+
cell_labels_indices_mapping["cell_id"].values, table.obs[str(XeniumKeys.CELL_ID)].values
265+
)
266+
except AssertionError:
260267
warnings.warn(
261268
"The cell_id column in the cell_labels_table does not match the cell_id column derived from the "
262269
"cell labels data. This could be due to trying to read a new version that is not supported yet. "
@@ -274,7 +281,7 @@ def xenium(
274281
path,
275282
XeniumKeys.NUCLEUS_BOUNDARIES_FILE,
276283
specs,
277-
idx=table.obs[str(XeniumKeys.CELL_ID)].copy(),
284+
idx=None,
278285
)
279286

280287
if cells_boundaries:
@@ -415,6 +422,13 @@ def filter(self, record: logging.LogRecord) -> bool:
415422
return _set_reader_metadata(sdata, "xenium")
416423

417424

425+
def _assert_arrays_equal_sampled(a: ArrayLike, b: ArrayLike, n: int = 100) -> None:
426+
"""Assert two arrays are equal by checking a random sample of entries."""
427+
assert len(a) == len(b), f"Array lengths differ: {len(a)} != {len(b)}"
428+
idx = np.random.default_rng(0).choice(len(a), size=min(n, len(a)), replace=False)
429+
np.testing.assert_array_equal(np.asarray(a[idx]), np.asarray(b[idx]))
430+
431+
418432
def _decode_cell_id_column(cell_id_column: pd.Series) -> pd.Series:
419433
if isinstance(cell_id_column.iloc[0], bytes):
420434
return cell_id_column.str.decode("utf-8")
@@ -429,28 +443,35 @@ def _get_polygons(
429443
specs: dict[str, Any],
430444
idx: pd.Series | None = None,
431445
) -> GeoDataFrame:
432-
# seems to be faster than pd.read_parquet
433-
df = pq.read_table(path / file).to_pandas()
434-
cell_ids = df[XeniumKeys.CELL_ID].to_numpy()
435-
x = df[XeniumKeys.BOUNDARIES_VERTEX_X].to_numpy()
436-
y = df[XeniumKeys.BOUNDARIES_VERTEX_Y].to_numpy()
446+
# Use PyArrow compute to avoid slow .to_numpy() on Arrow-backed strings in pandas >= 3.0
447+
# The original approach was:
448+
# df = pq.read_table(path / file).to_pandas()
449+
# cell_ids = df[XeniumKeys.CELL_ID].to_numpy()
450+
# which got slow with pandas >= 3.0 (Arrow-backed string .to_numpy() is ~100x slower).
451+
# By doing change detection in Arrow, we avoid allocating Python string objects for all rows.
452+
table = pq.read_table(path / file)
453+
cell_id_col = table.column(str(XeniumKeys.CELL_ID))
454+
455+
x = table.column(str(XeniumKeys.BOUNDARIES_VERTEX_X)).to_numpy()
456+
y = table.column(str(XeniumKeys.BOUNDARIES_VERTEX_Y)).to_numpy()
437457
coords = np.column_stack([x, y])
438458

439-
change_mask = np.concatenate([[True], cell_ids[1:] != cell_ids[:-1]])
459+
n = len(cell_id_col)
460+
change_mask = np.empty(n, dtype=bool)
461+
change_mask[0] = True
462+
change_mask[1:] = pc.not_equal(cell_id_col.slice(0, n - 1), cell_id_col.slice(1)).to_numpy(zero_copy_only=False)
440463
group_starts = np.where(change_mask)[0]
441-
group_ends = np.concatenate([group_starts[1:], [len(cell_ids)]])
464+
group_ends = np.concatenate([group_starts[1:], [n]])
442465

443466
# sanity check
444-
n_unique_ids = len(df[XeniumKeys.CELL_ID].drop_duplicates())
467+
n_unique_ids = pc.count_distinct(cell_id_col).as_py()
445468
if len(group_starts) != n_unique_ids:
446469
raise ValueError(
447470
f"In {file}, rows belonging to the same polygon must be contiguous. "
448471
f"Expected {n_unique_ids} group starts, but found {len(group_starts)}. "
449472
f"This indicates non-consecutive polygon rows."
450473
)
451474

452-
unique_ids = cell_ids[group_starts]
453-
454475
# offsets for ragged array:
455476
# offsets[0] (ring_offsets): describing to which rings the vertex positions belong to
456477
# offsets[1] (geom_offsets): describing to which polygons the rings belong to
@@ -459,22 +480,16 @@ def _get_polygons(
459480

460481
geoms = from_ragged_array(GeometryType.POLYGON, coords, offsets=(ring_offsets, geom_offsets))
461482

462-
index = _decode_cell_id_column(pd.Series(unique_ids))
463-
geo_df = GeoDataFrame({"geometry": geoms}, index=index.values)
464-
465-
version = _parse_version_of_xenium_analyzer(specs)
466-
if version is not None and version < packaging.version.parse("2.0.0"):
467-
assert idx is not None
468-
assert len(idx) == len(geo_df)
469-
assert np.array_equal(index.values, idx.values)
483+
# idx is not None for the cells and None for the nuclei (for xenium(cells_table=False), idx is None for both)
484+
if idx is not None:
485+
# Cell IDs already available from the annotation table
486+
assert len(idx) == len(group_starts), f"Expected {len(group_starts)} cell IDs, got {len(idx)}"
487+
geo_df = GeoDataFrame({"geometry": geoms}, index=idx.values)
470488
else:
471-
if np.unique(geo_df.index).size != len(geo_df):
472-
warnings.warn(
473-
"Found non-unique polygon indices, this will be addressed in a future version of the reader. For the "
474-
"time being please consider merging polygons with non-unique indices into single multi-polygons.",
475-
UserWarning,
476-
stacklevel=2,
477-
)
489+
# Fall back to extracting unique cell IDs from parquet (slow for large_string columns).
490+
unique_ids = cell_id_col.filter(change_mask).to_pylist()
491+
index = _decode_cell_id_column(pd.Series(unique_ids))
492+
geo_df = GeoDataFrame({"geometry": geoms}, index=index.values)
478493

479494
scale = Scale([1.0 / specs["pixel_size"], 1.0 / specs["pixel_size"]], axes=("x", "y"))
480495
return ShapesModel.parse(geo_df, transformations={"global": scale})
@@ -559,7 +574,7 @@ def _get_cells_metadata_table_from_zarr(
559574
return df
560575

561576

562-
def _get_points(path: Path, specs: dict[str, Any]) -> Table:
577+
def _get_points(path: Path, specs: dict[str, Any]) -> pa.Table:
563578
table = read_parquet(path / XeniumKeys.TRANSCRIPTS_FILE)
564579

565580
# check if we need to decode bytes
@@ -601,7 +616,7 @@ def _get_tables_and_circles(
601616
) -> AnnData | tuple[AnnData, AnnData]:
602617
adata = _read_10x_h5(path / XeniumKeys.CELL_FEATURE_MATRIX_FILE, gex_only=gex_only)
603618
metadata = pd.read_parquet(path / XeniumKeys.CELL_METADATA_FILE)
604-
np.testing.assert_array_equal(metadata.cell_id.astype(str), adata.obs_names.values)
619+
_assert_arrays_equal_sampled(metadata.cell_id.astype(str), adata.obs_names.values)
605620
circ = metadata[[XeniumKeys.CELL_X, XeniumKeys.CELL_Y]].to_numpy()
606621
adata.obsm["spatial"] = circ
607622
metadata.drop([XeniumKeys.CELL_X, XeniumKeys.CELL_Y], axis=1, inplace=True)

0 commit comments

Comments
 (0)