Skip to content

Commit 358d975

Browse files
committed
Align TensorMap creation and launch behavior with the latest review guidance.

Keep the public TMA entry point on StridedMemoryView, and remove avoidable
launch/build overhead, so that the reviewed API surface stays smaller without
regressing local CUDA builds.

Made-with: Cursor
1 parent 232b621 commit 358d975

5 files changed

Lines changed: 139 additions & 122 deletions

File tree

cuda_core/cuda/core/_cpp/tensor_map.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,18 @@
1010
#include <exception>
1111

1212
#if defined(__has_include)
13+
// Older CTK releases do not ship <cuda/tma>. When it is unavailable we keep
14+
// the CCCL helper compiled out and fall back to the direct driver path.
1315
# if __has_include(<cuda/tma>)
1416
# include <cuda/tma>
1517
# define CUDA_CORE_HAS_CUDA_TMA 1
1618
# else
1719
# define CUDA_CORE_HAS_CUDA_TMA 0
1820
# endif
19-
# if __has_include(<dlpack/dlpack.h>)
21+
# if __has_include("dlpack.h")
22+
# include "dlpack.h"
23+
# define CUDA_CORE_HAS_DLPACK_H 1
24+
# elif __has_include(<dlpack/dlpack.h>)
2025
# include <dlpack/dlpack.h>
2126
# define CUDA_CORE_HAS_DLPACK_H 1
2227
# else

cuda_core/cuda/core/_kernel_arg_handler.pyx

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ from cpython.mem cimport PyMem_Malloc, PyMem_Free
66
from libc.stdint cimport (intptr_t,
77
int8_t, int16_t, int32_t, int64_t,
88
uint8_t, uint16_t, uint32_t, uint64_t,)
9-
from libc.string cimport memcpy
109
from libcpp cimport bool as cpp_bool
1110
from libcpp.complex cimport complex as cpp_complex
1211
from libcpp cimport nullptr
@@ -135,18 +134,9 @@ cdef inline int prepare_tensor_map_arg(
135134
vector.vector[void*]& data_addresses,
136135
TensorMapDescriptor arg,
137136
const size_t idx) except -1:
138-
arg._check_context_compat()
139-
# Allocate a temporary buffer for the 128-byte CUtensorMap struct.
140-
# We copy rather than pointing directly at arg._tensor_map for lifetime
141-
# safety: ParamHolder owns and frees its argument buffers independently.
142-
cdef void* ptr = PyMem_Malloc(sizeof(cydriver.CUtensorMap))
143-
if ptr is NULL:
144-
raise MemoryError("Failed to allocate memory for CUtensorMap")
145-
memcpy(ptr, arg._get_data_ptr(), sizeof(cydriver.CUtensorMap))
146-
# data[idx] is tracked so the allocation is freed in ParamHolder.__dealloc__,
147-
# data_addresses[idx] is the pointer passed to cuLaunchKernel.
148-
data_addresses[idx] = ptr
149-
data[idx] = ptr
137+
# cuLaunchKernel copies argument bytes during launch, so a TensorMap
138+
# descriptor can point directly at its internal CUtensorMap storage.
139+
data_addresses[idx] = arg._get_data_ptr()
150140
return 0
151141

152142

@@ -299,9 +289,6 @@ cdef class ParamHolder:
299289
# it's a CUdeviceptr:
300290
self.data_addresses[i] = <void*><intptr_t>(arg.handle.getPtr())
301291
continue
302-
elif arg_type is tensor_map_descriptor_type:
303-
prepare_tensor_map_arg(self.data, self.data_addresses, <TensorMapDescriptor>arg, i)
304-
continue
305292
elif arg_type is bool:
306293
prepare_arg[cpp_bool](self.data, self.data_addresses, arg, i)
307294
continue
@@ -319,6 +306,9 @@ cdef class ParamHolder:
319306
elif arg_type is complex:
320307
prepare_arg[cpp_double_complex](self.data, self.data_addresses, arg, i)
321308
continue
309+
elif arg_type is tensor_map_descriptor_type:
310+
prepare_tensor_map_arg(self.data, self.data_addresses, <TensorMapDescriptor>arg, i)
311+
continue
322312

323313
not_prepared = prepare_numpy_arg(self.data, self.data_addresses, arg, i)
324314
if not_prepared:

cuda_core/cuda/core/_memoryview.pyx

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -329,8 +329,8 @@ cdef class StridedMemoryView:
329329
):
330330
"""Create a tiled :obj:`TensorMapDescriptor` from this view.
331331
332-
This is a convenience wrapper around
333-
:meth:`cuda.core._tensor_map.TensorMapDescriptor.from_tiled`.
332+
This is the public entry point for creating tiled tensor map
333+
descriptors in ``cuda.core``.
334334
"""
335335
from cuda.core._tensor_map import TensorMapDescriptor
336336

@@ -347,7 +347,7 @@ cdef class StridedMemoryView:
347347
kwargs["l2_promotion"] = l2_promotion
348348
if oob_fill is not None:
349349
kwargs["oob_fill"] = oob_fill
350-
return TensorMapDescriptor.from_tiled(self, box_dim, **kwargs)
350+
return TensorMapDescriptor._from_tiled(self, box_dim, **kwargs)
351351

352352
def copy_from(
353353
self, other : StridedMemoryView, stream : Stream,

cuda_core/cuda/core/_tensor_map.pyx

Lines changed: 73 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ ELSE:
117117
"""Im2col wide mode for tensor map descriptors.
118118
119119
This enum is always defined for API stability, but the
120-
:meth:`TensorMapDescriptor.from_im2col_wide` factory requires a CUDA 13+
120+
:meth:`TensorMapDescriptor._from_im2col_wide` factory requires a CUDA 13+
121121
build and will raise otherwise.
122122
"""
123123
W = 0
@@ -163,10 +163,20 @@ def _resolve_data_type(view, data_type):
163163
"""Resolve the TMA data type from an explicit value or the view's dtype."""
164164

165165
if data_type is not None:
166-
if not isinstance(data_type, TensorMapDataType):
166+
if isinstance(data_type, TensorMapDataType):
167+
return data_type
168+
try:
169+
dt = numpy.dtype(data_type)
170+
except TypeError as e:
167171
raise TypeError(
168-
f"data_type must be a TensorMapDataType, got {type(data_type)}")
169-
return data_type
172+
"data_type must be a TensorMapDataType or a numpy/ml_dtypes dtype, "
173+
f"got {type(data_type)}") from e
174+
tma_dt = _NUMPY_DTYPE_TO_TMA.get(dt)
175+
if tma_dt is None:
176+
raise ValueError(
177+
f"Unsupported dtype {dt} for TMA; "
178+
f"supported dtypes: {list(_NUMPY_DTYPE_TO_TMA.keys())}.")
179+
return tma_dt
170180

171181
dt = view.dtype
172182
if dt is None:
@@ -243,23 +253,26 @@ cdef inline bint _tma_dtype_to_dlpack(
243253
return False
244254

245255

246-
def _get_validated_view(tensor):
247-
"""Obtain a device-accessible StridedMemoryView with a 16-byte-aligned pointer."""
248-
if isinstance(tensor, StridedMemoryView):
249-
view = tensor
250-
else:
251-
# stream_ptr=-1: no stream synchronization needed because descriptor
252-
# creation only reads tensor metadata, it does not move data.
253-
view = StridedMemoryView.from_any_interface(tensor, stream_ptr=-1)
254-
256+
cdef inline int _validate_tensor_map_view(view) except -1:
255257
if not view.is_device_accessible:
256258
raise ValueError("The tensor must be device-accessible")
257259

258260
if view.ptr % 16 != 0:
259261
raise ValueError(
260262
f"Global memory address must be 16-byte aligned, "
261263
f"got address 0x{view.ptr:x}")
264+
return 0
262265

266+
267+
def _get_validated_view(tensor):
268+
"""Obtain a device-accessible StridedMemoryView with a 16-byte-aligned pointer."""
269+
if isinstance(tensor, StridedMemoryView):
270+
view = tensor
271+
else:
272+
# stream_ptr=-1: no stream synchronization needed because descriptor
273+
# creation only reads tensor metadata, it does not move data.
274+
view = StridedMemoryView.from_any_interface(tensor, stream_ptr=-1)
275+
_validate_tensor_map_view(view)
263276
return view
264277

265278

@@ -292,6 +305,17 @@ cdef inline int _get_current_device_id() except -1:
292305
return <int>dev
293306

294307

308+
cdef inline int _require_view_device(
309+
view,
310+
int device_id,
311+
object caller,
312+
) except -1:
313+
if view.device_id != device_id:
314+
raise ValueError(
315+
f"{caller} expects tensor on device {device_id}, got {view.device_id}")
316+
return 0
317+
318+
295319
def _compute_byte_strides(shape, strides, elem_size):
296320
"""Compute byte strides from element strides or C-contiguous fallback.
297321
@@ -328,16 +352,17 @@ cdef class TensorMapDescriptor:
328352
used by the hardware TMA unit for efficient bulk data movement between
329353
global and shared memory.
330354
331-
Instances are created via the class methods :meth:`from_tiled` and
332-
:meth:`from_im2col`, and can be passed directly to
333-
:func:`~cuda.core.launch` as a kernel argument.
355+
Public tiled descriptors are created via
356+
:meth:`cuda.core.StridedMemoryView.as_tensor_map`. Specialized
357+
``_from_*`` helpers remain private while this API surface settles, and
358+
descriptors can be passed directly to :func:`~cuda.core.launch` as a
359+
kernel argument.
334360
"""
335361

336362
def __init__(self):
337363
raise RuntimeError(
338364
"TensorMapDescriptor cannot be instantiated directly. "
339-
"Use TensorMapDescriptor.from_tiled() or "
340-
"TensorMapDescriptor.from_im2col().")
365+
"Use StridedMemoryView.as_tensor_map() instead.")
341366

342367
cdef void* _get_data_ptr(self):
343368
return <void*>&self._tensor_map
@@ -364,22 +389,27 @@ cdef class TensorMapDescriptor:
364389
f"but current device is {current_dev_id}")
365390
return 0
366391

392+
@property
393+
def device(self):
394+
"""Return the :obj:`~cuda.core.Device` associated with this descriptor."""
395+
if self._device_id >= 0:
396+
from cuda.core._device import Device
397+
return Device(self._device_id)
398+
367399
@classmethod
368-
def from_tiled(cls, tensor, box_dim, *,
400+
def _from_tiled(cls, view, box_dim, *,
369401
element_strides=None,
370402
data_type=None,
371403
interleave=TensorMapInterleave.NONE,
372404
swizzle=TensorMapSwizzle.NONE,
373405
l2_promotion=TensorMapL2Promotion.NONE,
374406
oob_fill=TensorMapOOBFill.NONE):
375-
"""Create a tiled TMA descriptor from a tensor object.
407+
"""Create a tiled TMA descriptor from a validated view.
376408
377409
Parameters
378410
----------
379-
tensor : object
380-
Any object supporting DLPack or ``__cuda_array_interface__``,
381-
or a :obj:`~cuda.core.StridedMemoryView`. Must refer to
382-
device-accessible memory with a 16-byte-aligned pointer.
411+
view : StridedMemoryView
412+
A device-accessible view with a 16-byte-aligned pointer.
383413
box_dim : tuple of int
384414
The size of each tile dimension (in elements). Must have the
385415
same rank as the tensor and each value must be in [1, 256].
@@ -411,15 +441,15 @@ cdef class TensorMapDescriptor:
411441
"""
412442
cdef TensorMapDescriptor desc = cls.__new__(cls)
413443

414-
view = _get_validated_view(tensor)
444+
_validate_tensor_map_view(view)
415445
# Keep both the original tensor object and the validated view alive.
416446
# For DLPack exporters, the view may hold the owning capsule whose
417447
# deleter can free the backing allocation when released.
418-
desc._source_ref = tensor
448+
desc._source_ref = view.exporting_obj
419449
desc._view_ref = view
420450
desc._context = _get_current_context_ptr()
421451
desc._device_id = _get_current_device_id()
422-
_require_view_device(view, desc._device_id, "TensorMapDescriptor.from_tiled")
452+
_require_view_device(view, desc._device_id, "TensorMapDescriptor._from_tiled")
423453

424454
tma_dt = _resolve_data_type(view, data_type)
425455
cdef int c_data_type_int = int(tma_dt)
@@ -591,24 +621,22 @@ cdef class TensorMapDescriptor:
591621
return desc
592622

593623
@classmethod
594-
def from_im2col(cls, tensor, pixel_box_lower_corner, pixel_box_upper_corner,
624+
def _from_im2col(cls, view, pixel_box_lower_corner, pixel_box_upper_corner,
595625
channels_per_pixel, pixels_per_column, *,
596626
element_strides=None,
597627
data_type=None,
598628
interleave=TensorMapInterleave.NONE,
599629
swizzle=TensorMapSwizzle.NONE,
600630
l2_promotion=TensorMapL2Promotion.NONE,
601631
oob_fill=TensorMapOOBFill.NONE):
602-
"""Create an im2col TMA descriptor from a tensor object.
632+
"""Create an im2col TMA descriptor from a validated view.
603633
604634
Im2col layout is used for convolution-style data access patterns.
605635
606636
Parameters
607637
----------
608-
tensor : object
609-
Any object supporting DLPack or ``__cuda_array_interface__``,
610-
or a :obj:`~cuda.core.StridedMemoryView`. Must refer to
611-
device-accessible memory with a 16-byte-aligned pointer.
638+
view : StridedMemoryView
639+
A device-accessible view with a 16-byte-aligned pointer.
612640
pixel_box_lower_corner : tuple of int
613641
Lower corner of the pixel bounding box for each spatial
614642
dimension (rank - 2 elements). Specified in row-major order
@@ -647,12 +675,12 @@ cdef class TensorMapDescriptor:
647675
"""
648676
cdef TensorMapDescriptor desc = cls.__new__(cls)
649677

650-
view = _get_validated_view(tensor)
651-
desc._source_ref = tensor
678+
_validate_tensor_map_view(view)
679+
desc._source_ref = view.exporting_obj
652680
desc._view_ref = view
653681
desc._context = _get_current_context_ptr()
654682
desc._device_id = _get_current_device_id()
655-
_require_view_device(view, desc._device_id, "TensorMapDescriptor.from_im2col")
683+
_require_view_device(view, desc._device_id, "TensorMapDescriptor._from_im2col")
656684

657685
tma_dt = _resolve_data_type(view, data_type)
658686
cdef int c_data_type_int = int(tma_dt)
@@ -746,7 +774,7 @@ cdef class TensorMapDescriptor:
746774
return desc
747775

748776
@classmethod
749-
def from_im2col_wide(cls, tensor, pixel_box_lower_corner_width, pixel_box_upper_corner_width,
777+
def _from_im2col_wide(cls, view, pixel_box_lower_corner_width, pixel_box_upper_corner_width,
750778
channels_per_pixel, pixels_per_column, *,
751779
element_strides=None,
752780
data_type=None,
@@ -755,18 +783,16 @@ cdef class TensorMapDescriptor:
755783
swizzle=TensorMapSwizzle.SWIZZLE_128B,
756784
l2_promotion=TensorMapL2Promotion.NONE,
757785
oob_fill=TensorMapOOBFill.NONE):
758-
"""Create an im2col-wide TMA descriptor from a tensor object.
786+
"""Create an im2col-wide TMA descriptor from a validated view.
759787
760788
Im2col-wide layout loads elements exclusively along the W (width)
761789
dimension. This variant is supported on compute capability 10.0+
762790
(Blackwell and later).
763791
764792
Parameters
765793
----------
766-
tensor : object
767-
Any object supporting DLPack or ``__cuda_array_interface__``,
768-
or a :obj:`~cuda.core.StridedMemoryView`. Must refer to
769-
device-accessible memory with a 16-byte-aligned pointer.
794+
view : StridedMemoryView
795+
A device-accessible view with a 16-byte-aligned pointer.
770796
pixel_box_lower_corner_width : int
771797
Lower corner of the pixel bounding box along the W dimension.
772798
pixel_box_upper_corner_width : int
@@ -803,16 +829,16 @@ cdef class TensorMapDescriptor:
803829
"""
804830
IF CUDA_CORE_BUILD_MAJOR < 13:
805831
raise RuntimeError(
806-
"TensorMapDescriptor.from_im2col_wide requires a CUDA 13+ build")
832+
"TensorMapDescriptor._from_im2col_wide requires a CUDA 13+ build")
807833
ELSE:
808834
cdef TensorMapDescriptor desc = cls.__new__(cls)
809835

810-
view = _get_validated_view(tensor)
811-
desc._source_ref = tensor
836+
_validate_tensor_map_view(view)
837+
desc._source_ref = view.exporting_obj
812838
desc._view_ref = view
813839
desc._context = _get_current_context_ptr()
814840
desc._device_id = _get_current_device_id()
815-
_require_view_device(view, desc._device_id, "TensorMapDescriptor.from_im2col_wide")
841+
_require_view_device(view, desc._device_id, "TensorMapDescriptor._from_im2col_wide")
816842

817843
tma_dt = _resolve_data_type(view, data_type)
818844
cdef int c_data_type_int = int(tma_dt)
@@ -917,7 +943,7 @@ cdef class TensorMapDescriptor:
917943
# Update the source reference only after the driver call succeeds,
918944
# so we don't drop the old tensor (risking a dangling pointer in the
919945
# CUtensorMap struct) if the call fails.
920-
self._source_ref = tensor
946+
self._source_ref = view.exporting_obj
921947
self._view_ref = view
922948

923949
def __repr__(self):

0 commit comments

Comments (0)