55from ._dlpack cimport *
66
77import functools
8+ import warnings
89from typing import Optional
910
1011import numpy
@@ -78,30 +79,78 @@ cdef class StridedMemoryView:
7879 bint readonly
7980 object exporting_obj
8081
81- # If using dlpack, this is a strong reference to the result of
82- # obj.__dlpack__() so we can lazily create shape and strides from
83- # it later. If using CAI, this is a reference to the source
84- # `__cuda_array_interface__` object.
85- cdef object metadata
86-
87- # The tensor object if has obj has __dlpack__, otherwise must be NULL
88- cdef DLTensor * dl_tensor
89-
90- # Memoized properties
91- cdef tuple _shape
92- cdef tuple _strides
93- cdef bint _strides_init # Has the strides tuple been init'ed?
94- cdef object _dtype
95-
96- def __init__ (self , obj = None , stream_ptr = None ):
82+ cdef:
83+ # If using dlpack, this is a strong reference to the result of
84+ # obj.__dlpack__() so we can lazily create shape and strides from
85+ # it later. If using CAI, this is a reference to the source
86+ # `__cuda_array_interface__` object.
87+ object metadata
88+
89+ # The tensor object if has obj has __dlpack__, otherwise must be NULL
90+ DLTensor * dl_tensor
91+
92+ # Memoized properties
93+ tuple _shape
94+ tuple _strides
95+ # a `None` value for _strides has defined meaning in dlpack and
96+ # the cuda array interface, meaning C order, contiguous.
97+ #
98+ # this flag helps prevent unnecessary recompuation of _strides
99+ bint _strides_init
100+ object _dtype
101+
102+ def __init__ (self , obj: object = None , stream_ptr: int | None = None ) -> None:
103+ cdef str clsname = self .__class__.__name__
97104 if obj is not None:
98105 # populate self's attributes
99106 if check_has_dlpack(obj ):
107+ warnings.warn(
108+ f" Constructing a {clsname} directly from a DLPack-supporting object is deprecated; "
109+ " Use `StridedMemoryView.from_dlpack` or `StridedMemoryView.from_any_interface` instead." ,
110+ DeprecationWarning ,
111+ stacklevel = 2 ,
112+ )
100113 view_as_dlpack(obj, stream_ptr, self )
101114 else :
115+ warnings.warn(
116+ f" Constructing a {clsname} directly from a CUDA-array-interface-supporting object is deprecated; "
117+ " Use `StridedMemoryView.from_cuda_array_interface` or `StridedMemoryView.from_any_interface` instead." ,
118+ DeprecationWarning ,
119+ stacklevel = 2 ,
120+ )
102121 view_as_cai(obj, stream_ptr, self )
103122 else :
104- pass
123+ warnings.warn(
124+ f" Constructing an empty {clsname} is deprecated; "
125+ " use one of the classmethods `from_dlpack`, `from_cuda_array_interface` or `from_any_interface` "
126+ " to construct a StridedMemoryView from an object" ,
127+ DeprecationWarning ,
128+ stacklevel = 2 ,
129+ )
130+
131+ @classmethod
132+ def from_dlpack (cls , obj: object , stream_ptr: int | None = None ) -> StridedMemoryView:
133+ cdef StridedMemoryView buf
134+ with warnings.catch_warnings():
135+ warnings.simplefilter(" ignore" )
136+ buf = cls ()
137+ view_as_dlpack(obj, stream_ptr, buf)
138+ return buf
139+
140+ @classmethod
141+ def from_cuda_array_interface (cls , obj: object , stream_ptr: int | None = None ) -> StridedMemoryView:
142+ cdef StridedMemoryView buf
143+ with warnings.catch_warnings():
144+ warnings.simplefilter(" ignore" )
145+ buf = cls ()
146+ view_as_cai(obj, stream_ptr, buf)
147+ return buf
148+
149+ @classmethod
150+ def from_any_interface (cls , obj: object , stream_ptr: int | None = None ) -> StridedMemoryView:
151+ if check_has_dlpack(obj ):
152+ return cls .from_dlpack(obj, stream_ptr)
153+ return cls .from_cuda_array_interface(obj, stream_ptr)
105154
106155 def __dealloc__ (self ):
107156 if self .dl_tensor == NULL :
@@ -121,7 +170,7 @@ cdef class StridedMemoryView:
121170 dlm_tensor.deleter(dlm_tensor)
122171
123172 @property
124- def shape (self ) -> tuple[int]:
173+ def shape (self ) -> tuple[int , ... ]:
125174 if self._shape is None:
126175 if self.exporting_obj is not None:
127176 if self.dl_tensor != NULL:
@@ -136,7 +185,7 @@ cdef class StridedMemoryView:
136185 return self._shape
137186
138187 @property
139- def strides(self ) -> Optional[tuple[int]]:
188+ def strides(self ) -> Optional[tuple[int , ... ]]:
140189 cdef int itemsize
141190 if self._strides_init is False:
142191 if self.exporting_obj is not None:
@@ -193,6 +242,7 @@ cdef str get_simple_repr(obj):
193242 return obj_repr
194243
195244
245+
196246cdef bint check_has_dlpack(obj) except * :
197247 cdef bint has_dlpack
198248 if hasattr (obj, " __dlpack__" ) and hasattr (obj, " __dlpack_device__" ):
@@ -206,8 +256,7 @@ cdef bint check_has_dlpack(obj) except*:
206256
207257
208258cdef class _StridedMemoryViewProxy:
209-
210- cdef:
259+ cdef readonly:
211260 object obj
212261 bint has_dlpack
213262
@@ -217,9 +266,9 @@ cdef class _StridedMemoryViewProxy:
217266
218267 cpdef StridedMemoryView view(self , stream_ptr = None ):
219268 if self .has_dlpack:
220- return view_as_dlpack (self .obj, stream_ptr)
269+ return StridedMemoryView.from_dlpack (self .obj, stream_ptr)
221270 else :
222- return view_as_cai (self .obj, stream_ptr)
271+ return StridedMemoryView.from_cuda_array_interface (self .obj, stream_ptr)
223272
224273
225274cdef StridedMemoryView view_as_dlpack(obj, stream_ptr, view = None ):
@@ -354,7 +403,6 @@ cdef object dtype_dlpack_to_numpy(DLDataType* dtype):
354403 return numpy.dtype(np_dtype)
355404
356405
357- # Also generate for Python so we can test this code path
358406cpdef StridedMemoryView view_as_cai(obj, stream_ptr, view = None ):
359407 cdef dict cai_data = obj.__cuda_array_interface__
360408 if cai_data[" version" ] < 3 :
0 commit comments