Skip to content

Commit eef1c7a

Browse files
committed
Handle TensorMap device validation by DLPack type
Reject CUDA device-local tensors from a different GPU while still allowing CUDA host and managed memory. Add regression tests for descriptor creation, replace_address, and the shared validation helper.
1 parent 5a0e141 commit eef1c7a

2 files changed

Lines changed: 145 additions & 5 deletions

File tree

cuda_core/cuda/core/_tensor_map.pyx

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,19 @@ def _get_validated_view(tensor):
263263
return view
264264

265265

266+
def _require_view_device(view, expected_device_id, operation):
267+
"""Ensure device-local tensors match the current CUDA device.
268+
269+
DLPack reports host/managed CUDA memory as ``kDLCUDAHost`` /
270+
``kDLCUDAManaged`` with ``device_id=0`` regardless of the current device,
271+
so only true ``kDLCUDA`` tensors are rejected by device-id mismatch.
272+
"""
273+
device_type, device_id = view.__dlpack_device__()
274+
if device_type == _kDLCUDA and device_id != expected_device_id:
275+
raise ValueError(
276+
f"{operation} expects tensor on device {expected_device_id}, got {device_id}")
277+
278+
266279
cdef inline intptr_t _get_current_context_ptr() except? 0:
267280
cdef cydriver.CUcontext ctx
268281
with nogil:
@@ -406,6 +419,7 @@ cdef class TensorMapDescriptor:
406419
desc._view_ref = view
407420
desc._context = _get_current_context_ptr()
408421
desc._device_id = _get_current_device_id()
422+
_require_view_device(view, desc._device_id, "TensorMapDescriptor.from_tiled")
409423

410424
tma_dt = _resolve_data_type(view, data_type)
411425
cdef int c_data_type_int = int(tma_dt)
@@ -447,6 +461,8 @@ cdef class TensorMapDescriptor:
447461
cdef int i_cccl
448462
cdef int device_type
449463
cdef int c_device_id
464+
cdef int dl_device_type
465+
cdef int dl_device_id
450466
cdef int c_cccl_interleave_int
451467
cdef int c_cccl_swizzle_int
452468
cdef int c_cccl_l2_promotion_int
@@ -471,8 +487,9 @@ cdef class TensorMapDescriptor:
471487
if elem_strides_provided:
472488
c_elem_strides_ptr = &c_elem_strides[0]
473489

474-
device_type = <int>_kDLCUDA
475-
c_device_id = <int>view.device_id
490+
dl_device_type, dl_device_id = view.__dlpack_device__()
491+
device_type = dl_device_type
492+
c_device_id = dl_device_id
476493
c_cccl_interleave_int = int(interleave)
477494
c_cccl_swizzle_int = int(swizzle)
478495
c_cccl_l2_promotion_int = int(l2_promotion)
@@ -635,6 +652,7 @@ cdef class TensorMapDescriptor:
635652
desc._view_ref = view
636653
desc._context = _get_current_context_ptr()
637654
desc._device_id = _get_current_device_id()
655+
_require_view_device(view, desc._device_id, "TensorMapDescriptor.from_im2col")
638656

639657
tma_dt = _resolve_data_type(view, data_type)
640658
cdef int c_data_type_int = int(tma_dt)
@@ -794,6 +812,7 @@ cdef class TensorMapDescriptor:
794812
desc._view_ref = view
795813
desc._context = _get_current_context_ptr()
796814
desc._device_id = _get_current_device_id()
815+
_require_view_device(view, desc._device_id, "TensorMapDescriptor.from_im2col_wide")
797816

798817
tma_dt = _resolve_data_type(view, data_type)
799818
cdef int c_data_type_int = int(tma_dt)
@@ -885,9 +904,7 @@ cdef class TensorMapDescriptor:
885904
"""
886905
self._check_context_compat()
887906
view = _get_validated_view(tensor)
888-
if view.device_id != self._device_id:
889-
raise ValueError(
890-
f"replace_address expects tensor on device {self._device_id}, got {view.device_id}")
907+
_require_view_device(view, self._device_id, "replace_address")
891908

892909
cdef intptr_t global_address = view.ptr
893910

cuda_core/tests/test_tensor_map.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,23 @@
44
import numpy as np
55
import pytest
66

7+
from conftest import create_managed_memory_resource_or_skip, skip_if_managed_memory_unsupported
78
from cuda.core import (
89
Device,
10+
ManagedMemoryResourceOptions,
911
StridedMemoryView,
1012
TensorMapDescriptor,
13+
system,
1114
)
15+
from cuda.core._dlpack import DLDeviceType
1216
from cuda.core._tensor_map import (
1317
TensorMapDataType,
1418
TensorMapIm2ColWideMode,
1519
TensorMapInterleave,
1620
TensorMapL2Promotion,
1721
TensorMapOOBFill,
1822
TensorMapSwizzle,
23+
_require_view_device,
1924
)
2025

2126

@@ -48,6 +53,15 @@ def __init__(self, buf, shape, dtype=np.float32):
4853
}
4954

5055

56+
class _MockTensorMapView:
57+
def __init__(self, device_type, device_id):
58+
self._device_type = device_type
59+
self._device_id = device_id
60+
61+
def __dlpack_device__(self):
62+
return (self._device_type, self._device_id)
63+
64+
5165
class TestTensorMapEnums:
5266
"""Test that enum wrappers expose the expected values."""
5367

@@ -323,6 +337,115 @@ def test_replace_address_requires_device_accessible(self, dev, skip_if_no_tma):
323337
with pytest.raises(ValueError, match="device-accessible"):
324338
desc.replace_address(host_arr)
325339

340+
def test_replace_address_rejects_tensor_from_other_device(self, dev, skip_if_no_tma):
341+
if system.get_num_devices() < 2:
342+
pytest.skip("requires multi-GPU")
343+
344+
dev0 = dev
345+
dev1 = Device(1)
346+
347+
dev0.set_current()
348+
buf0 = dev0.allocate(1024 * 4)
349+
desc = TensorMapDescriptor.from_tiled(
350+
buf0,
351+
box_dim=(64,),
352+
data_type=TensorMapDataType.FLOAT32,
353+
)
354+
355+
dev1.set_current()
356+
buf1 = dev1.allocate(1024 * 4)
357+
dev0.set_current()
358+
359+
with pytest.raises(ValueError, match=r"replace_address expects tensor on device 0, got 1"):
360+
desc.replace_address(buf1)
361+
362+
def test_replace_address_accepts_managed_buffer_on_nonzero_device(self, init_cuda):
363+
if system.get_num_devices() < 2:
364+
pytest.skip("requires multi-GPU")
365+
366+
dev1 = Device(1)
367+
if not dev1.properties.tensor_map_access_supported:
368+
pytest.skip("Device does not support TMA (requires compute capability 9.0+)")
369+
skip_if_managed_memory_unsupported(dev1)
370+
371+
dev1.set_current()
372+
desc = TensorMapDescriptor.from_tiled(
373+
dev1.allocate(1024 * 4),
374+
box_dim=(64,),
375+
data_type=TensorMapDataType.FLOAT32,
376+
)
377+
378+
mr = create_managed_memory_resource_or_skip(
379+
ManagedMemoryResourceOptions(preferred_location=dev1.device_id)
380+
)
381+
managed_buf = mr.allocate(1024 * 4)
382+
383+
desc.replace_address(managed_buf)
384+
385+
386+
class TestTensorMapMultiDeviceValidation:
387+
"""Test multi-device validation for descriptor creation."""
388+
389+
def test_from_tiled_rejects_tensor_from_other_device(self, init_cuda):
390+
if system.get_num_devices() < 2:
391+
pytest.skip("requires multi-GPU")
392+
393+
dev0 = Device(0)
394+
dev1 = Device(1)
395+
396+
dev1.set_current()
397+
buf1 = dev1.allocate(1024 * 4)
398+
dev0.set_current()
399+
400+
with pytest.raises(
401+
ValueError,
402+
match=r"TensorMapDescriptor\.from_tiled expects tensor on device 0, got 1",
403+
):
404+
TensorMapDescriptor.from_tiled(
405+
buf1,
406+
box_dim=(64,),
407+
data_type=TensorMapDataType.FLOAT32,
408+
)
409+
410+
def test_from_tiled_accepts_managed_buffer_on_nonzero_device(self, init_cuda):
411+
if system.get_num_devices() < 2:
412+
pytest.skip("requires multi-GPU")
413+
414+
dev1 = Device(1)
415+
if not dev1.properties.tensor_map_access_supported:
416+
pytest.skip("Device does not support TMA (requires compute capability 9.0+)")
417+
skip_if_managed_memory_unsupported(dev1)
418+
419+
dev1.set_current()
420+
mr = create_managed_memory_resource_or_skip(
421+
ManagedMemoryResourceOptions(preferred_location=dev1.device_id)
422+
)
423+
managed_buf = mr.allocate(1024 * 4)
424+
425+
desc = TensorMapDescriptor.from_tiled(
426+
managed_buf,
427+
box_dim=(64,),
428+
data_type=TensorMapDataType.FLOAT32,
429+
)
430+
assert desc is not None
431+
432+
433+
class TestTensorMapDeviceValidation:
434+
"""Test device validation behavior for tensor-map-compatible views."""
435+
436+
def test_require_view_device_accepts_same_cuda_device(self):
437+
_require_view_device(_MockTensorMapView(DLDeviceType.kDLCUDA, 1), 1, "op")
438+
439+
def test_require_view_device_rejects_different_cuda_device(self):
440+
with pytest.raises(ValueError, match=r"op expects tensor on device 0, got 1"):
441+
_require_view_device(_MockTensorMapView(DLDeviceType.kDLCUDA, 1), 0, "op")
442+
443+
def test_require_view_device_allows_cuda_host_memory(self):
444+
_require_view_device(_MockTensorMapView(DLDeviceType.kDLCUDAHost, 0), 1, "op")
445+
446+
def test_require_view_device_allows_cuda_managed_memory(self):
447+
_require_view_device(_MockTensorMapView(DLDeviceType.kDLCUDAManaged, 0), 1, "op")
448+
326449

327450
class TestTensorMapIm2col:
328451
"""Test im2col TMA descriptor creation."""

0 commit comments

Comments (0)