Skip to content

Commit b5498b9

Browse files
committed
[Fix] Use CRC layout instead of COO
1 parent 45eff9e commit b5498b9

1 file changed

Lines changed: 64 additions & 8 deletions

File tree

spm/__wrapper__.py

Lines changed: 64 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ def _to_runtime(cls, obj):
344344
return obj.tolist()
345345
return obj
346346

347-
elif sparse and isinstance(obj, sparse.coo_array):
347+
elif sparse and isinstance(obj, sparse.sparray):
348348
return SparseArray.from_any(obj)._as_runtime()
349349

350350
else:
@@ -1328,11 +1328,9 @@ def _from_runtime(cls, dictobj: dict) -> "SparseArray":
13281328
size = np.array(dictobj['size__'], dtype=np.uint64).ravel()
13291329
size = size.tolist()
13301330
dtype = _matlab_array_types()[type(dictobj['values__'])]
1331-
obj = cls.from_shape(size, dtype=dtype)
13321331
indices = np.asarray(dictobj['indices__'], dtype=np.long) - 1
13331332
values = np.asarray(dictobj['values__'], dtype=dtype).ravel()
1334-
obj[tuple(indices.T)] = values
1335-
return obj
1333+
return cls.from_coo(values, indices.T, size)
13361334

13371335

13381336
if sparse:
@@ -1341,9 +1339,9 @@ class WrappedSparseArray(sparse.sparray, AnyWrappedArray):
13411339
"""Base class for sparse arrays."""
13421340

13431341
def to_dense(self) -> "Array":
1344-
return Array.from_any(super().to_dense())
1342+
return Array.from_any(self.todense())
13451343

1346-
class SparseArray(sparse.coo_array, _SparseMixin, WrappedSparseArray):
1344+
class SparseArray(sparse.csc_array, _SparseMixin, WrappedSparseArray):
13471345
"""
13481346
Matlab sparse arrays (scipy.sparse backend).
13491347
@@ -1371,14 +1369,40 @@ class SparseArray(sparse.coo_array, _SparseMixin, WrappedSparseArray):
13711369
def __init__(self, *args, **kwargs) -> None:
13721370
mode, arg, kwargs = self._parse_args(*args, **kwargs)
13731371
if mode == "shape":
1374-
return super().__init__(shape=arg, **kwargs)
1372+
ndim = len(arg)
1373+
return super().__init__(([], [[]]*ndim), shape=arg, **kwargs)
13751374
else:
13761375
if not isinstance(arg, (np.ndarray, sparse.sparray)):
13771376
arg = np.asanyarray(arg)
13781377
return super().__init__(arg, **kwargs)
13791378

13801379
@classmethod
1381-
def from_shape(cls, shape=tuple(), **kwargs) -> "Array":
1380+
def from_coo(cls, values, indices, shape=None, **kw) -> "SparseArray":
1381+
"""
1382+
Build a sparse array from indices and values.
1383+
1384+
Parameters
1385+
----------
1386+
values : (N,) ArrayLike
1387+
Values to set at each index.
1388+
indices : (D, N) ArrayLike
1389+
Indices of nonzero elements.
1390+
shape : list[int] | None
1391+
Shape of the array.
1392+
dtype : np.dtype | None
1393+
Target data type. Same as `values` by default.
1394+
1395+
Returns
1396+
-------
1397+
array : SparseArray
1398+
New array.
1399+
"""
1400+
indices = np.asarray(indices)
1401+
coo = sparse.coo_array((values, indices), shape=shape, **kw)
1402+
return cls.from_any(coo)
1403+
1404+
@classmethod
1405+
def from_shape(cls, shape=tuple(), **kwargs) -> "SparseArray":
13821406
"""
13831407
Build an array of a given shape.
13841408
@@ -1472,6 +1496,38 @@ class SparseArray(_SparseMixin, Array):
14721496
def to_dense(self) -> "Array":
14731497
return np.ndarray.view(self, Array)
14741498

1499+
@classmethod
1500+
def from_coo(cls, values, indices, shape=None, **kw) -> "SparseArray":
1501+
"""
1502+
Build a sparse array from indices and values.
1503+
1504+
Parameters
1505+
----------
1506+
values : (N,) ArrayLike
1507+
Values to set at each index.
1508+
indices : (D, N) ArrayLike
1509+
Indices of nonzero elements.
1510+
shape : list[int] | None
1511+
Shape of the array.
1512+
dtype : np.dtype | None
1513+
Target data type. Same as `values` by default.
1514+
1515+
Returns
1516+
-------
1517+
array : SparseArray
1518+
New array.
1519+
"""
1520+
dtype = kw.get("dtype", None)
1521+
indices = np.asarray(indices)
1522+
values = np.asarray(values, dtype=dtype)
1523+
if shape is None:
1524+
shape = (1 + indices.max(-1)).astype(np.uint64).tolist()
1525+
if dtype is None:
1526+
dtype = values.dtype
1527+
obj = cls.from_shape(shape, dtype=dtype)
1528+
obj[tuple(indices)] = values
1529+
return obj
1530+
14751531

14761532
# ----------------------------------------------------------------------
14771533
# Cell

0 commit comments

Comments
 (0)