@@ -45,6 +45,8 @@ from ._backend cimport ( # noqa: E211
4545 _device_type,
4646)
4747
48+ from contextvars import ContextVar
49+
4850from ._sycl_device import SyclDeviceCreationError
4951from .enum_types import backend_type
5052from .enum_types import device_type as device_type_t
@@ -59,6 +61,7 @@ __all__ = [
5961 " has_cpu_devices" ,
6062 " has_gpu_devices" ,
6163 " has_accelerator_devices" ,
64+ " _cached_default_device" ,
6265]
6366
6467
@@ -355,3 +358,48 @@ cpdef SyclDevice select_gpu_device():
355358 raise SyclDeviceCreationError(" Device unavailable." )
356359 Device = SyclDevice._create(DRef)
357360 return Device
361+
362+
363+ cdef class _DefaultDeviceCache:
364+ cdef dict __device_map__
365+
366+ def __cinit__ (self ):
367+ self .__device_map__ = dict ()
368+
369+ cdef get_or_create(self ):
370+ """ Return instance of SyclDevice and indicator if cache
371+ has been modified"""
372+ key = 0
373+ if key in self .__device_map__:
374+ return self .__device_map__[key], False
375+ dev = select_default_device()
376+ self .__device_map__[key] = dev
377+ return dev, True
378+
379+ cdef _update_map(self , dev_map):
380+ self .__device_map__.update(dev_map)
381+
382+ def __copy__ (self ):
383+ cdef _DefaultDeviceCache _copy = _DefaultDeviceCache.__new__ (
384+ _DefaultDeviceCache)
385+ _copy._update_map(self .__device_map__)
386+ return _copy
387+
388+
389+ _global_default_device_cache = ContextVar(
390+ ' global_default_device_cache' ,
391+ default = _DefaultDeviceCache()
392+ )
393+
394+
395+ cpdef SyclDevice _cached_default_device():
396+ """ Returns a cached devide selected by default selector.
397+
398+ Returns:
399+ :class:`dpctl.SyclDevice`: A cached default-selected SYCL device.
400+
401+ """
402+ cdef _DefaultDeviceCache _cache = _global_default_device_cache.get()
403+ d_, changed_ = _cache.get_or_create()
404+ if changed_: _global_default_device_cache.set(_cache)
405+ return d_
0 commit comments