Skip to content

Commit 892ee60

Browse files
committed
Align TensorMap API surface with review feedback and enforce context safety.
Expose only TensorMapDescriptor in cuda.core, add StridedMemoryView.as_tensor_map(), remove redundant tensor-map fallback packing, and track/check descriptor context/device compatibility before replacement and kernel launch argument packing. Made-with: Cursor
1 parent 1a6b416 commit 892ee60

9 files changed

Lines changed: 107 additions & 23 deletions

File tree

.gitignore

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ __pycache__/
1414
!*_impl.cpp
1515
!cuda_bindings/cuda/bindings/_lib/param_packer.cpp
1616
!cuda_bindings/cuda/bindings/_bindings/loader.cpp
17-
!cuda_core/cuda/core/_cpp/*.cpp
1817
cache_driver
1918
cache_runtime
2019
cache_nvrtc

cuda_core/cuda/core/__init__.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -68,12 +68,4 @@
6868
Stream,
6969
StreamOptions,
7070
)
71-
from cuda.core._tensor_map import (
72-
TensorMapDataType,
73-
TensorMapDescriptor,
74-
TensorMapIm2ColWideMode,
75-
TensorMapInterleave,
76-
TensorMapL2Promotion,
77-
TensorMapOOBFill,
78-
TensorMapSwizzle,
79-
)
71+
from cuda.core._tensor_map import TensorMapDescriptor

cuda_core/cuda/core/_kernel_arg_handler.pyx

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ cdef inline int prepare_tensor_map_arg(
135135
vector.vector[void*]& data_addresses,
136136
TensorMapDescriptor arg,
137137
const size_t idx) except -1:
138+
arg._check_context_compat()
138139
# Allocate a temporary buffer for the 128-byte CUtensorMap struct.
139140
# We copy rather than pointing directly at arg._tensor_map for lifetime
140141
# safety: ParamHolder owns and frees its argument buffers independently.
@@ -350,9 +351,6 @@ cdef class ParamHolder:
350351
elif isinstance(arg, driver.CUgraphConditionalHandle):
351352
prepare_arg[cydriver.CUgraphConditionalHandle](self.data, self.data_addresses, arg, i)
352353
continue
353-
elif isinstance(arg, tensor_map_descriptor_type):
354-
prepare_tensor_map_arg(self.data, self.data_addresses, <TensorMapDescriptor>arg, i)
355-
continue
356354
# TODO: support ctypes/numpy struct
357355
raise TypeError("the argument is of unsupported type: " + str(type(arg)))
358356

cuda_core/cuda/core/_memoryview.pyx

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,39 @@ cdef class StridedMemoryView:
316316
view_buffer_strided(view, self.get_buffer(), layout, dtype, self.readonly)
317317
return view
318318

319+
def as_tensor_map(
320+
self,
321+
box_dim,
322+
*,
323+
element_strides=None,
324+
data_type=None,
325+
interleave=None,
326+
swizzle=None,
327+
l2_promotion=None,
328+
oob_fill=None,
329+
):
330+
"""Create a tiled :obj:`TensorMapDescriptor` from this view.
331+
332+
This is a convenience wrapper around
333+
:meth:`cuda.core._tensor_map.TensorMapDescriptor.from_tiled`.
334+
"""
335+
from cuda.core._tensor_map import TensorMapDescriptor
336+
337+
kwargs = {}
338+
if element_strides is not None:
339+
kwargs["element_strides"] = element_strides
340+
if data_type is not None:
341+
kwargs["data_type"] = data_type
342+
if interleave is not None:
343+
kwargs["interleave"] = interleave
344+
if swizzle is not None:
345+
kwargs["swizzle"] = swizzle
346+
if l2_promotion is not None:
347+
kwargs["l2_promotion"] = l2_promotion
348+
if oob_fill is not None:
349+
kwargs["oob_fill"] = oob_fill
350+
return TensorMapDescriptor.from_tiled(self, box_dim, **kwargs)
351+
319352
def copy_from(
320353
self, other : StridedMemoryView, stream : Stream,
321354
allocator = None,

cuda_core/cuda/core/_tensor_map.pxd

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,16 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
from cuda.bindings cimport cydriver
6+
from libc.stdint cimport intptr_t
67

78

89
cdef class TensorMapDescriptor:
910
cdef cydriver.CUtensorMap _tensor_map
11+
cdef int _device_id
12+
cdef intptr_t _context
1013
cdef object _source_ref
1114
cdef object _view_ref
1215
cdef object _repr_info
1316

17+
cdef int _check_context_compat(self) except -1
1418
cdef void* _get_data_ptr(self)

cuda_core/cuda/core/_tensor_map.pyx

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,22 @@ def _get_validated_view(tensor):
263263
return view
264264

265265

266+
cdef inline intptr_t _get_current_context_ptr() except? 0:
267+
cdef cydriver.CUcontext ctx
268+
with nogil:
269+
HANDLE_RETURN(cydriver.cuCtxGetCurrent(&ctx))
270+
if ctx == NULL:
271+
raise RuntimeError("TensorMapDescriptor requires an active CUDA context")
272+
return <intptr_t>ctx
273+
274+
275+
cdef inline int _get_current_device_id() except -1:
276+
cdef cydriver.CUdevice dev
277+
with nogil:
278+
HANDLE_RETURN(cydriver.cuCtxGetDevice(&dev))
279+
return <int>dev
280+
281+
266282
def _compute_byte_strides(shape, strides, elem_size):
267283
"""Compute byte strides from element strides or C-contiguous fallback.
268284
@@ -313,6 +329,28 @@ cdef class TensorMapDescriptor:
313329
cdef void* _get_data_ptr(self):
314330
return <void*>&self._tensor_map
315331

332+
cdef int _check_context_compat(self) except -1:
333+
cdef cydriver.CUcontext current_ctx
334+
cdef cydriver.CUdevice current_dev
335+
cdef int current_dev_id
336+
if self._context == 0 and self._device_id < 0:
337+
return 0
338+
with nogil:
339+
HANDLE_RETURN(cydriver.cuCtxGetCurrent(&current_ctx))
340+
if current_ctx == NULL:
341+
raise RuntimeError("TensorMapDescriptor requires an active CUDA context")
342+
if self._context != 0 and <intptr_t>current_ctx != self._context:
343+
raise RuntimeError(
344+
"TensorMapDescriptor was created in a different CUDA context")
345+
with nogil:
346+
HANDLE_RETURN(cydriver.cuCtxGetDevice(&current_dev))
347+
current_dev_id = <int>current_dev
348+
if self._device_id >= 0 and current_dev_id != self._device_id:
349+
raise RuntimeError(
350+
f"TensorMapDescriptor belongs to device {self._device_id}, "
351+
f"but current device is {current_dev_id}")
352+
return 0
353+
316354
@classmethod
317355
def from_tiled(cls, tensor, box_dim, *,
318356
element_strides=None,
@@ -366,6 +404,8 @@ cdef class TensorMapDescriptor:
366404
# deleter can free the backing allocation when released.
367405
desc._source_ref = tensor
368406
desc._view_ref = view
407+
desc._context = _get_current_context_ptr()
408+
desc._device_id = _get_current_device_id()
369409

370410
tma_dt = _resolve_data_type(view, data_type)
371411
cdef int c_data_type_int = int(tma_dt)
@@ -593,6 +633,8 @@ cdef class TensorMapDescriptor:
593633
view = _get_validated_view(tensor)
594634
desc._source_ref = tensor
595635
desc._view_ref = view
636+
desc._context = _get_current_context_ptr()
637+
desc._device_id = _get_current_device_id()
596638

597639
tma_dt = _resolve_data_type(view, data_type)
598640
cdef int c_data_type_int = int(tma_dt)
@@ -750,6 +792,8 @@ cdef class TensorMapDescriptor:
750792
view = _get_validated_view(tensor)
751793
desc._source_ref = tensor
752794
desc._view_ref = view
795+
desc._context = _get_current_context_ptr()
796+
desc._device_id = _get_current_device_id()
753797

754798
tma_dt = _resolve_data_type(view, data_type)
755799
cdef int c_data_type_int = int(tma_dt)
@@ -839,7 +883,11 @@ cdef class TensorMapDescriptor:
839883
or a :obj:`~cuda.core.StridedMemoryView`. Must refer to
840884
device-accessible memory with a 16-byte-aligned pointer.
841885
"""
886+
self._check_context_compat()
842887
view = _get_validated_view(tensor)
888+
if view.device_id != self._device_id:
889+
raise ValueError(
890+
f"replace_address expects tensor on device {self._device_id}, got {view.device_id}")
843891

844892
cdef intptr_t global_address = view.ptr
845893

cuda_core/examples/tma_replace_address.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
LaunchConfig,
3737
Program,
3838
ProgramOptions,
39-
TensorMapDescriptor,
39+
StridedMemoryView,
4040
launch,
4141
)
4242

@@ -159,7 +159,7 @@
159159
output = cp.zeros(N, dtype=cp.float32)
160160
dev.sync() # cupy uses its own stream
161161

162-
tensor_map = TensorMapDescriptor.from_tiled(a, box_dim=(TILE_SIZE,))
162+
tensor_map = StridedMemoryView.from_any_interface(a, stream_ptr=-1).as_tensor_map(box_dim=(TILE_SIZE,))
163163

164164
n_tiles = N // TILE_SIZE
165165
config = LaunchConfig(grid=n_tiles, block=TILE_SIZE)

cuda_core/examples/tma_tensor_map.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
LaunchConfig,
3232
Program,
3333
ProgramOptions,
34-
TensorMapDescriptor,
34+
StridedMemoryView,
3535
launch,
3636
)
3737

@@ -48,8 +48,6 @@
4848
sys.exit(0)
4949
dev.set_current()
5050

51-
arch_str = "".join(f"{i}" for i in arch)
52-
5351
# ---------------------------------------------------------------------------
5452
# CUDA kernel that uses TMA to load a 1-D tile into shared memory, then
5553
# copies the tile to an output buffer so we can verify correctness.
@@ -141,7 +139,7 @@
141139
prog = Program(
142140
code,
143141
code_type="c++",
144-
options=ProgramOptions(std="c++17", arch=f"sm_{arch_str}"),
142+
options=ProgramOptions(std="c++17", arch=f"sm_{dev.arch}"),
145143
)
146144
mod = prog.compile("cubin")
147145
ker = mod.get_kernel("tma_copy")
@@ -155,11 +153,10 @@
155153
dev.sync() # cupy uses its own stream
156154

157155
# ---------------------------------------------------------------------------
158-
# 2) Create a TMA tiled descriptor
159-
# from_tiled() accepts any DLPack / __cuda_array_interface__ object.
156+
# 2) Create a TMA tiled descriptor from a StridedMemoryView.
160157
# The dtype (float32) is inferred automatically from the CuPy array.
161158
# ---------------------------------------------------------------------------
162-
tensor_map = TensorMapDescriptor.from_tiled(a, box_dim=(TILE_SIZE,))
159+
tensor_map = StridedMemoryView.from_any_interface(a, stream_ptr=-1).as_tensor_map(box_dim=(TILE_SIZE,))
163160

164161
# ---------------------------------------------------------------------------
165162
# 3) Launch the kernel

cuda_core/tests/test_tensor_map.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,11 @@
66

77
from cuda.core import (
88
Device,
9-
TensorMapDataType,
9+
StridedMemoryView,
1010
TensorMapDescriptor,
11+
)
12+
from cuda.core._tensor_map import (
13+
TensorMapDataType,
1114
TensorMapIm2ColWideMode,
1215
TensorMapInterleave,
1316
TensorMapL2Promotion,
@@ -107,6 +110,16 @@ def test_from_tiled_2d(self, dev, skip_if_no_tma):
107110
)
108111
assert desc is not None
109112

113+
def test_strided_memory_view_as_tensor_map(self, dev, skip_if_no_tma):
114+
buf = dev.allocate(64 * 64 * 4)
115+
tensor = _DeviceArray(buf, (64, 64))
116+
view = StridedMemoryView.from_any_interface(tensor, stream_ptr=-1)
117+
desc = view.as_tensor_map(
118+
box_dim=(32, 32),
119+
data_type=TensorMapDataType.FLOAT32,
120+
)
121+
assert desc is not None
122+
110123
def test_from_tiled_3d(self, dev, skip_if_no_tma):
111124
buf = dev.allocate(16 * 16 * 16 * 4) # 16x16x16 float32
112125
tensor = _DeviceArray(buf, (16, 16, 16))

0 commit comments

Comments (0)