@@ -117,7 +117,7 @@ ELSE:
117117 """ Im2col wide mode for tensor map descriptors.
118118
119119 This enum is always defined for API stability, but the
120- :meth:`TensorMapDescriptor.from_im2col_wide ` factory requires a CUDA 13+
120+ :meth:`TensorMapDescriptor._from_im2col_wide` factory requires a CUDA 13+
121121 build and will raise otherwise.
122122 """
123123 W = 0
@@ -163,10 +163,20 @@ def _resolve_data_type(view, data_type):
163163 """ Resolve the TMA data type from an explicit value or the view's dtype."""
164164
165165 if data_type is not None :
166- if not isinstance (data_type, TensorMapDataType):
166+ if isinstance (data_type, TensorMapDataType):
167+ return data_type
168+ try :
169+ dt = numpy.dtype(data_type)
170+ except TypeError as e:
167171 raise TypeError (
168- f" data_type must be a TensorMapDataType, got {type(data_type)}" )
169- return data_type
172+ " data_type must be a TensorMapDataType or a numpy/ml_dtypes dtype, "
173+ f" got {type(data_type)}" ) from e
174+ tma_dt = _NUMPY_DTYPE_TO_TMA.get(dt)
175+ if tma_dt is None :
176+ raise ValueError (
177+ f" Unsupported dtype {dt} for TMA; "
178+ f" supported dtypes: {list(_NUMPY_DTYPE_TO_TMA.keys())}." )
179+ return tma_dt
170180
171181 dt = view.dtype
172182 if dt is None :
@@ -243,23 +253,26 @@ cdef inline bint _tma_dtype_to_dlpack(
243253 return False
244254
245255
246- def _get_validated_view (tensor ):
247- """ Obtain a device-accessible StridedMemoryView with a 16-byte-aligned pointer."""
248- if isinstance (tensor, StridedMemoryView):
249- view = tensor
250- else :
251- # stream_ptr=-1: no stream synchronization needed because descriptor
252- # creation only reads tensor metadata, it does not move data.
253- view = StridedMemoryView.from_any_interface(tensor, stream_ptr = - 1 )
254-
cdef inline int _validate_tensor_map_view(view) except -1:
    """Raise ``ValueError`` unless *view* is device-accessible and its
    pointer is 16-byte aligned; return 0 on success."""
    if not view.is_device_accessible:
        raise ValueError("The tensor must be device-accessible")
    # TMA requires the global-memory base address to be 16-byte aligned.
    if view.ptr & 15:
        raise ValueError(
            f"Global memory address must be 16-byte aligned, "
            f"got address 0x{view.ptr:x}")
    return 0
262265
266+
def _get_validated_view(tensor):
    """Coerce *tensor* into a device-accessible StridedMemoryView with a
    16-byte-aligned pointer, raising ``ValueError`` otherwise."""
    # stream_ptr=-1: descriptor creation only reads tensor metadata and never
    # moves data, so no stream synchronization is required.
    view = (
        tensor
        if isinstance(tensor, StridedMemoryView)
        else StridedMemoryView.from_any_interface(tensor, stream_ptr=-1)
    )
    _validate_tensor_map_view(view)
    return view
264277
265278
@@ -292,6 +305,17 @@ cdef inline int _get_current_device_id() except -1:
292305 return < int > dev
293306
294307
cdef inline int _require_view_device(
    view,
    int device_id,
    object caller,
) except -1:
    """Raise ``ValueError`` if *view* does not reside on *device_id*.

    *caller* is interpolated into the error message so users can tell which
    factory rejected the tensor.
    """
    if view.device_id == device_id:
        return 0
    raise ValueError(
        f"{caller} expects tensor on device {device_id}, got {view.device_id}")
317+
318+
295319def _compute_byte_strides (shape , strides , elem_size ):
296320 """ Compute byte strides from element strides or C-contiguous fallback.
297321
@@ -328,16 +352,17 @@ cdef class TensorMapDescriptor:
328352 used by the hardware TMA unit for efficient bulk data movement between
329353 global and shared memory.
330354
331- Instances are created via the class methods :meth:`from_tiled` and
332- :meth:`from_im2col`, and can be passed directly to
333- :func:`~cuda.core.launch` as a kernel argument.
355+ Public tiled descriptors are created via
356+ :meth:`cuda.core.StridedMemoryView.as_tensor_map`. Specialized
357+ ``_from_*`` helpers remain private while this API surface settles, and
358+ descriptors can be passed directly to :func:`~cuda.core.launch` as a
359+ kernel argument.
334360 """
335361
336362 def __init__ (self ):
337363 raise RuntimeError (
338364 " TensorMapDescriptor cannot be instantiated directly. "
339- " Use TensorMapDescriptor.from_tiled() or "
340- " TensorMapDescriptor.from_im2col()." )
365+ " Use StridedMemoryView.as_tensor_map() instead." )
341366
342367 cdef void * _get_data_ptr(self ):
343368 return < void * > & self ._tensor_map
@@ -364,22 +389,27 @@ cdef class TensorMapDescriptor:
364389 f" but current device is {current_dev_id}" )
365390 return 0
366391
392+ @property
393+ def device (self ):
394+ """ Return the :obj:`~cuda.core.Device` associated with this descriptor."""
395+ if self ._device_id >= 0 :
396+ from cuda.core._device import Device
397+ return Device(self ._device_id)
398+
367399 @classmethod
368- def from_tiled (cls , tensor , box_dim , *,
400+ def _from_tiled (cls , view , box_dim , *,
369401 element_strides = None ,
370402 data_type = None ,
371403 interleave = TensorMapInterleave.NONE,
372404 swizzle = TensorMapSwizzle.NONE,
373405 l2_promotion = TensorMapL2Promotion.NONE,
374406 oob_fill = TensorMapOOBFill.NONE):
375- """ Create a tiled TMA descriptor from a tensor object .
407+ """ Create a tiled TMA descriptor from a validated view .
376408
377409 Parameters
378410 ----------
379- tensor : object
380- Any object supporting DLPack or ``__cuda_array_interface__``,
381- or a :obj:`~cuda.core.StridedMemoryView`. Must refer to
382- device-accessible memory with a 16-byte-aligned pointer.
411+ view : StridedMemoryView
412+ A device-accessible view with a 16-byte-aligned pointer.
383413 box_dim : tuple of int
384414 The size of each tile dimension (in elements). Must have the
385415 same rank as the tensor and each value must be in [1, 256].
@@ -411,15 +441,15 @@ cdef class TensorMapDescriptor:
411441 """
412442 cdef TensorMapDescriptor desc = cls .__new__ (cls )
413443
414- view = _get_validated_view(tensor )
444+ _validate_tensor_map_view(view )
415445 # Keep both the original tensor object and the validated view alive.
416446 # For DLPack exporters, the view may hold the owning capsule whose
417447 # deleter can free the backing allocation when released.
418- desc._source_ref = tensor
448+ desc._source_ref = view.exporting_obj
419449 desc._view_ref = view
420450 desc._context = _get_current_context_ptr()
421451 desc._device_id = _get_current_device_id()
422- _require_view_device(view, desc._device_id, " TensorMapDescriptor.from_tiled " )
452+ _require_view_device(view, desc._device_id, " TensorMapDescriptor._from_tiled " )
423453
424454 tma_dt = _resolve_data_type(view, data_type)
425455 cdef int c_data_type_int = int (tma_dt)
@@ -591,24 +621,22 @@ cdef class TensorMapDescriptor:
591621 return desc
592622
593623 @classmethod
594- def from_im2col (cls , tensor , pixel_box_lower_corner , pixel_box_upper_corner ,
624+ def _from_im2col (cls , view , pixel_box_lower_corner , pixel_box_upper_corner ,
595625 channels_per_pixel , pixels_per_column , *,
596626 element_strides = None ,
597627 data_type = None ,
598628 interleave = TensorMapInterleave.NONE,
599629 swizzle = TensorMapSwizzle.NONE,
600630 l2_promotion = TensorMapL2Promotion.NONE,
601631 oob_fill = TensorMapOOBFill.NONE):
602- """ Create an im2col TMA descriptor from a tensor object .
632+ """ Create an im2col TMA descriptor from a validated view .
603633
604634 Im2col layout is used for convolution-style data access patterns.
605635
606636 Parameters
607637 ----------
608- tensor : object
609- Any object supporting DLPack or ``__cuda_array_interface__``,
610- or a :obj:`~cuda.core.StridedMemoryView`. Must refer to
611- device-accessible memory with a 16-byte-aligned pointer.
638+ view : StridedMemoryView
639+ A device-accessible view with a 16-byte-aligned pointer.
612640 pixel_box_lower_corner : tuple of int
613641 Lower corner of the pixel bounding box for each spatial
614642 dimension (rank - 2 elements). Specified in row-major order
@@ -647,12 +675,12 @@ cdef class TensorMapDescriptor:
647675 """
648676 cdef TensorMapDescriptor desc = cls .__new__ (cls )
649677
650- view = _get_validated_view(tensor )
651- desc._source_ref = tensor
678+ _validate_tensor_map_view(view )
679+ desc._source_ref = view.exporting_obj
652680 desc._view_ref = view
653681 desc._context = _get_current_context_ptr()
654682 desc._device_id = _get_current_device_id()
655- _require_view_device(view, desc._device_id, " TensorMapDescriptor.from_im2col " )
683+ _require_view_device(view, desc._device_id, " TensorMapDescriptor._from_im2col " )
656684
657685 tma_dt = _resolve_data_type(view, data_type)
658686 cdef int c_data_type_int = int (tma_dt)
@@ -746,7 +774,7 @@ cdef class TensorMapDescriptor:
746774 return desc
747775
748776 @classmethod
749- def from_im2col_wide (cls , tensor , pixel_box_lower_corner_width , pixel_box_upper_corner_width ,
777+ def _from_im2col_wide (cls , view , pixel_box_lower_corner_width , pixel_box_upper_corner_width ,
750778 channels_per_pixel , pixels_per_column , *,
751779 element_strides = None ,
752780 data_type = None ,
@@ -755,18 +783,16 @@ cdef class TensorMapDescriptor:
755783 swizzle = TensorMapSwizzle.SWIZZLE_128B,
756784 l2_promotion = TensorMapL2Promotion.NONE,
757785 oob_fill = TensorMapOOBFill.NONE):
758- """ Create an im2col-wide TMA descriptor from a tensor object .
786+ """ Create an im2col-wide TMA descriptor from a validated view .
759787
760788 Im2col-wide layout loads elements exclusively along the W (width)
761789 dimension. This variant is supported on compute capability 10.0+
762790 (Blackwell and later).
763791
764792 Parameters
765793 ----------
766- tensor : object
767- Any object supporting DLPack or ``__cuda_array_interface__``,
768- or a :obj:`~cuda.core.StridedMemoryView`. Must refer to
769- device-accessible memory with a 16-byte-aligned pointer.
794+ view : StridedMemoryView
795+ A device-accessible view with a 16-byte-aligned pointer.
770796 pixel_box_lower_corner_width : int
771797 Lower corner of the pixel bounding box along the W dimension.
772798 pixel_box_upper_corner_width : int
@@ -803,16 +829,16 @@ cdef class TensorMapDescriptor:
803829 """
804830 IF CUDA_CORE_BUILD_MAJOR < 13 :
805831 raise RuntimeError (
806- " TensorMapDescriptor.from_im2col_wide requires a CUDA 13+ build" )
832+ " TensorMapDescriptor._from_im2col_wide requires a CUDA 13+ build" )
807833 ELSE :
808834 cdef TensorMapDescriptor desc = cls .__new__ (cls )
809835
810- view = _get_validated_view(tensor )
811- desc._source_ref = tensor
836+ _validate_tensor_map_view(view )
837+ desc._source_ref = view.exporting_obj
812838 desc._view_ref = view
813839 desc._context = _get_current_context_ptr()
814840 desc._device_id = _get_current_device_id()
815- _require_view_device(view, desc._device_id, " TensorMapDescriptor.from_im2col_wide " )
841+ _require_view_device(view, desc._device_id, " TensorMapDescriptor._from_im2col_wide " )
816842
817843 tma_dt = _resolve_data_type(view, data_type)
818844 cdef int c_data_type_int = int (tma_dt)
@@ -917,7 +943,7 @@ cdef class TensorMapDescriptor:
917943 # Update the source reference only after the driver call succeeds,
918944 # so we don't drop the old tensor (risking a dangling pointer in the
919945 # CUtensorMap struct) if the call fails.
920- self ._source_ref = tensor
946+ self ._source_ref = view.exporting_obj
921947 self ._view_ref = view
922948
923949 def __repr__ (self ):
0 commit comments