-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy path_patches.py
More file actions
26 lines (20 loc) · 930 Bytes
/
_patches.py
File metadata and controls
26 lines (20 loc) · 930 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# SPDX-License-Identifier: MPL-2.0
from __future__ import annotations
import numpy as np
# TODO(flying-sheep): upstream
# https://github.com/dask/dask/issues/11749
def patch_dask() -> None:  # pragma: no cover
    """Make dask concatenate ``scipy.sparse.sparray`` chunks like ``spmatrix`` ones.

    Looks up the implementation dask's ``concatenate_lookup`` dispatches to for
    ``scipy.sparse.spmatrix`` and registers that same callable for
    ``scipy.sparse.sparray``.

    The function is a no-op when either of these holds:

    - dask's dispatch module or scipy's ``sparray``/``spmatrix`` cannot be
      imported;
    - ``sparray`` already dispatches to something other than
      :func:`numpy.concatenate`. That means this patch was already applied,
      or dask gained native support.

    See <https://github.com/dask/dask/blob/4d71629d1f22ced0dd780919f22e70a642ec6753/dask/array/backends.py#L212-L232>
    """
    try:
        # tensordot_lookup and take_lookup are further candidates for the same fix.
        from dask.array.dispatch import concatenate_lookup
        from scipy.sparse import sparray, spmatrix
    except ImportError:
        return  # dask or scipy missing: nothing to patch

    # np.concatenate is the generic fallback; anything else means sparray is handled.
    sparray_handled = concatenate_lookup.dispatch(sparray) is not np.concatenate
    if not sparray_handled:
        spmatrix_concat = concatenate_lookup.dispatch(spmatrix)
        concatenate_lookup.register(sparray, spmatrix_concat)