@@ -263,6 +263,22 @@ def _get_validated_view(tensor):
263263 return view
264264
265265
266+ cdef inline intptr_t _get_current_context_ptr() except ? 0 :
267+ cdef cydriver.CUcontext ctx
268+ with nogil:
269+ HANDLE_RETURN(cydriver.cuCtxGetCurrent(& ctx))
270+ if ctx == NULL :
271+ raise RuntimeError (" TensorMapDescriptor requires an active CUDA context" )
272+ return < intptr_t> ctx
273+
274+
275+ cdef inline int _get_current_device_id() except - 1 :
276+ cdef cydriver.CUdevice dev
277+ with nogil:
278+ HANDLE_RETURN(cydriver.cuCtxGetDevice(& dev))
279+ return < int > dev
280+
281+
266282def _compute_byte_strides (shape , strides , elem_size ):
267283 """ Compute byte strides from element strides or C-contiguous fallback.
268284
@@ -313,6 +329,28 @@ cdef class TensorMapDescriptor:
313329 cdef void * _get_data_ptr(self ):
314330 return < void * > & self ._tensor_map
315331
332+ cdef int _check_context_compat(self ) except - 1 :
333+ cdef cydriver.CUcontext current_ctx
334+ cdef cydriver.CUdevice current_dev
335+ cdef int current_dev_id
336+ if self ._context == 0 and self ._device_id < 0 :
337+ return 0
338+ with nogil:
339+ HANDLE_RETURN(cydriver.cuCtxGetCurrent(& current_ctx))
340+ if current_ctx == NULL :
341+ raise RuntimeError (" TensorMapDescriptor requires an active CUDA context" )
342+ if self ._context != 0 and < intptr_t> current_ctx != self ._context:
343+ raise RuntimeError (
344+ " TensorMapDescriptor was created in a different CUDA context" )
345+ with nogil:
346+ HANDLE_RETURN(cydriver.cuCtxGetDevice(& current_dev))
347+ current_dev_id = < int > current_dev
348+ if self ._device_id >= 0 and current_dev_id != self ._device_id:
349+ raise RuntimeError (
350+ f" TensorMapDescriptor belongs to device {self._device_id}, "
351+ f" but current device is {current_dev_id}" )
352+ return 0
353+
316354 @classmethod
317355 def from_tiled (cls , tensor , box_dim , *,
318356 element_strides = None ,
@@ -366,6 +404,8 @@ cdef class TensorMapDescriptor:
366404 # deleter can free the backing allocation when released.
367405 desc._source_ref = tensor
368406 desc._view_ref = view
407+ desc._context = _get_current_context_ptr()
408+ desc._device_id = _get_current_device_id()
369409
370410 tma_dt = _resolve_data_type(view, data_type)
371411 cdef int c_data_type_int = int (tma_dt)
@@ -593,6 +633,8 @@ cdef class TensorMapDescriptor:
593633 view = _get_validated_view(tensor)
594634 desc._source_ref = tensor
595635 desc._view_ref = view
636+ desc._context = _get_current_context_ptr()
637+ desc._device_id = _get_current_device_id()
596638
597639 tma_dt = _resolve_data_type(view, data_type)
598640 cdef int c_data_type_int = int (tma_dt)
@@ -750,6 +792,8 @@ cdef class TensorMapDescriptor:
750792 view = _get_validated_view(tensor)
751793 desc._source_ref = tensor
752794 desc._view_ref = view
795+ desc._context = _get_current_context_ptr()
796+ desc._device_id = _get_current_device_id()
753797
754798 tma_dt = _resolve_data_type(view, data_type)
755799 cdef int c_data_type_int = int (tma_dt)
@@ -839,7 +883,11 @@ cdef class TensorMapDescriptor:
839883 or a :obj:`~cuda.core.StridedMemoryView`. Must refer to
840884 device-accessible memory with a 16-byte-aligned pointer.
841885 """
886+ self ._check_context_compat()
842887 view = _get_validated_view(tensor)
888+ if view.device_id != self ._device_id:
889+ raise ValueError (
890+ f" replace_address expects tensor on device {self._device_id}, got {view.device_id}" )
843891
844892 cdef intptr_t global_address = view.ptr
845893
0 commit comments