Skip to content

Commit 719f0f3

Browse files
committed
Bundle tiled TensorMap options and type retained views.
Centralize the tiled descriptor arguments in an options object, keep dtype-like inputs on the public path while using raw driver values internally, and declare StridedMemoryView in a pxd so retained views stay typed without extra helper indirection. Made-with: Cursor
1 parent 9ff8d0f commit 719f0f3

5 files changed

Lines changed: 276 additions & 93 deletions

File tree

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from libc.stdint cimport intptr_t
2+
3+
from cuda.core._dlpack cimport DLTensor
4+
from cuda.core._layout cimport _StridedLayout
5+
6+
7+
cdef class StridedMemoryView:
8+
cdef readonly:
9+
intptr_t ptr
10+
int device_id
11+
bint is_device_accessible
12+
bint readonly
13+
object exporting_obj
14+
15+
cdef:
16+
object metadata
17+
DLTensor* dl_tensor
18+
_StridedLayout _layout
19+
object _buffer
20+
object _dtype
21+
22+
cdef inline _StridedLayout get_layout(self)
23+
cdef inline object get_buffer(self)
24+
cdef inline object get_dtype(self)

cuda_core/cuda/core/_memoryview.pyx

Lines changed: 7 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -107,35 +107,6 @@ cdef class StridedMemoryView:
107107
it will be the Buffer instance passed to the method.
108108
109109
"""
110-
cdef readonly:
111-
intptr_t ptr
112-
int device_id
113-
bint is_device_accessible
114-
bint readonly
115-
object exporting_obj
116-
117-
cdef:
118-
# If using dlpack, this is a strong reference to the result of
119-
# obj.__dlpack__() so we can lazily create shape and strides from
120-
# it later. If using CAI, this is a reference to the source
121-
# `__cuda_array_interface__` object.
122-
object metadata
123-
124-
# The tensor object if obj has __dlpack__; otherwise it must be NULL
125-
DLTensor *dl_tensor
126-
127-
# Memoized properties
128-
# Either lazily inferred from dl_tensor/metadata,
129-
# or explicitly provided if created with from_buffer().
130-
_StridedLayout _layout
131-
# Either exporting_obj if it is a Buffer, otherwise a Buffer instance
132-
# with owner set to the exporting object.
133-
object _buffer
134-
# Either lazily inferred from dl_tensor/metadata,
135-
# or explicitly provided if created with from_buffer().
136-
# In the latter case, it can be None.
137-
object _dtype
138-
139110
def __init__(self, obj: object = None, stream_ptr: int | None = None) -> None:
140111
cdef str clsname = self.__class__.__name__
141112
if obj is not None:
@@ -318,8 +289,9 @@ cdef class StridedMemoryView:
318289

319290
def as_tensor_map(
320291
self,
321-
box_dim,
292+
box_dim=None,
322293
*,
294+
options=None,
323295
element_strides=None,
324296
data_type=None,
325297
interleave=None,
@@ -330,11 +302,15 @@ cdef class StridedMemoryView:
330302
"""Create a tiled :obj:`TensorMapDescriptor` from this view.
331303
332304
This is the public entry point for creating tiled tensor map
333-
descriptors in ``cuda.core``.
305+
descriptors in ``cuda.core``. Pass either ``box_dim`` and the
306+
individual keyword arguments directly, or provide bundled tiled
307+
options via ``options=``.
334308
"""
335309
from cuda.core._tensor_map import TensorMapDescriptor
336310

337311
kwargs = {}
312+
if options is not None:
313+
kwargs["options"] = options
338314
if element_strides is not None:
339315
kwargs["element_strides"] = element_strides
340316
if data_type is not None:

cuda_core/cuda/core/_tensor_map.pxd

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,15 @@
44

55
from cuda.bindings cimport cydriver
66
from libc.stdint cimport intptr_t
7+
from cuda.core._memoryview cimport StridedMemoryView
78

89

910
cdef class TensorMapDescriptor:
1011
cdef cydriver.CUtensorMap _tensor_map
1112
cdef int _device_id
1213
cdef intptr_t _context
1314
cdef object _source_ref
14-
cdef object _view_ref
15+
cdef StridedMemoryView _view_ref
1516
cdef object _repr_info
1617

1718
cdef int _check_context_compat(self) except -1

0 commit comments

Comments
 (0)