@@ -595,11 +595,29 @@ def your_function(x, y):
595595# backwards compatibility alias
596596get_namespace = array_namespace
597597
598- def _check_device (xp , device ):
599- if xp == sys .modules .get ('numpy' ):
600- if device not in ["cpu" , None ]:
598+
599+ def _check_device (bare_xp , device ):
600+ """
601+ Validate dummy device on device-less array backends.
602+
603+ Notes
604+ -----
605+ This function is also invoked by CuPy, which does have multiple devices
606+ if there are multiple GPUs available.
607+ However, CuPy multi-device support is currently impossible
608+ without using the global device or a context manager:
609+
610+ https://github.com/data-apis/array-api-compat/pull/293
611+ """
612+ if bare_xp is sys .modules .get ('numpy' ):
613+ if device not in ("cpu" , None ):
601614 raise ValueError (f"Unsupported device for NumPy: { device !r} " )
602615
616+ elif bare_xp is sys .modules .get ('dask.array' ):
617+ if device not in ("cpu" , _DASK_DEVICE , None ):
618+ raise ValueError (f"Unsupported device for Dask: { device !r} " )
619+
620+
603621# Placeholder object to represent the dask device
604622# when the array backend is not the CPU.
605623# (since it is not easy to tell which device a dask array is on)
0 commit comments