From 82ad5980bd3cb970e73ab4d63755485edbf1338c Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Thu, 9 Apr 2026 07:15:02 +0000 Subject: [PATCH 01/17] Add torch.Tensor fast path for StridedMemoryView via AOTI tensor bridge Provide a fast path for constructing a StridedMemoryView from a torch.Tensor by reading tensor metadata directly through PyTorch's AOT Inductor (AOTI) stable C ABI, avoiding DLPack/CAI protocol overhead (~10 ns per tensor via pointer arithmetic). Key design: - Vendored AOTI shim header (aoti_shim.h) with extern "C" wrapping - _tensor_bridge.pyx loaded lazily (only when a torch.Tensor is first passed) to avoid undefined AOTI symbols at import time - RTLD_GLOBAL bootstrap via sys.modules["torch._C"] before loading _tensor_bridge.so - torch detection via type(obj).__module__.startswith("torch") - PyTorch is NOT a build-time or run-time dependency of cuda.core Closes NVIDIA/cuda-python#749 Co-Authored-By: Emilio Castillo Co-Authored-By: Claude Opus 4.6 (1M context) --- cuda_core/cuda/core/_include/aoti_shim.h | 84 ++++++++ cuda_core/cuda/core/_memoryview.pyx | 64 +++++- cuda_core/cuda/core/_tensor_bridge.pxd | 3 + cuda_core/cuda/core/_tensor_bridge.pyx | 238 +++++++++++++++++++++++ cuda_core/tests/test_utils.py | 131 +++++++++++++ 5 files changed, 514 insertions(+), 6 deletions(-) create mode 100644 cuda_core/cuda/core/_include/aoti_shim.h create mode 100644 cuda_core/cuda/core/_tensor_bridge.pxd create mode 100644 cuda_core/cuda/core/_tensor_bridge.pyx diff --git a/cuda_core/cuda/core/_include/aoti_shim.h b/cuda_core/cuda/core/_include/aoti_shim.h new file mode 100644 index 0000000000..4a9343d241 --- /dev/null +++ b/cuda_core/cuda/core/_include/aoti_shim.h @@ -0,0 +1,84 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Vendored subset of PyTorch's AOT Inductor (AOTI) stable C ABI. + * Original: torch/csrc/inductor/aoti_torch/c/shim.h + * + * These are declarations only -- no definitions are provided. The actual + * symbols are exported by libtorch (loaded via torch._C with RTLD_GLOBAL) + * and resolved at runtime by the dynamic linker. This means PyTorch is + * NOT required at compile time. + */ + +#ifndef CUDA_CORE_AOTI_SHIM_H +#define CUDA_CORE_AOTI_SHIM_H + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +typedef int32_t AOTITorchError; + +/* Opaque tensor handle -- corresponds to at::Tensor on the C++ side. */ +struct AtenTensorOpaque; +typedef struct AtenTensorOpaque* AtenTensorHandle; + +/* ---- tensor metadata --------------------------------------------------- */ + +AOTITorchError aoti_torch_get_data_ptr( + AtenTensorHandle tensor, void** ret_data_ptr); + +AOTITorchError aoti_torch_get_dim( + AtenTensorHandle tensor, int64_t* ret_dim); + +AOTITorchError aoti_torch_get_numel( + AtenTensorHandle tensor, int64_t* ret_numel); + +AOTITorchError aoti_torch_get_sizes( + AtenTensorHandle tensor, int64_t** ret_sizes); + +AOTITorchError aoti_torch_get_strides( + AtenTensorHandle tensor, int64_t** ret_strides); + +AOTITorchError aoti_torch_get_storage_offset( + AtenTensorHandle tensor, int64_t* ret_storage_offset); + +/* ---- dtype ------------------------------------------------------------- */ + +AOTITorchError aoti_torch_get_dtype( + AtenTensorHandle tensor, int32_t* ret_dtype); + +int32_t aoti_torch_dtype_float16(void); +int32_t aoti_torch_dtype_float32(void); +int32_t aoti_torch_dtype_float64(void); +int32_t aoti_torch_dtype_bfloat16(void); +int32_t aoti_torch_dtype_uint8(void); +int32_t aoti_torch_dtype_int8(void); +int32_t aoti_torch_dtype_int16(void); +int32_t aoti_torch_dtype_int32(void); +int32_t aoti_torch_dtype_int64(void); +int32_t aoti_torch_dtype_bool(void); +int32_t aoti_torch_dtype_complex32(void); +int32_t aoti_torch_dtype_complex64(void); +int32_t aoti_torch_dtype_complex128(void); + +/* ---- device ------------------------------------------------------------ */ + +AOTITorchError aoti_torch_get_device_type( + AtenTensorHandle tensor, int32_t* ret_device_type); + +AOTITorchError aoti_torch_get_device_index( + AtenTensorHandle tensor, int32_t* ret_device_index); + +int32_t aoti_torch_device_type_cpu(void); +int32_t aoti_torch_device_type_cuda(void); + +#ifdef __cplusplus +} /* extern "C" */ +#endif + +#endif /* CUDA_CORE_AOTI_SHIM_H */ diff --git a/cuda_core/cuda/core/_memoryview.pyx b/cuda_core/cuda/core/_memoryview.pyx index e0439ef23c..24533eb6e3 100644 --- a/cuda_core/cuda/core/_memoryview.pyx +++ b/cuda_core/cuda/core/_memoryview.pyx @@ -29,6 +29,35 @@ from cuda.core._utils.cuda_utils cimport HANDLE_RETURN from cuda.core._memory import Buffer +# --------------------------------------------------------------------------- +# Lazy tensor bridge (avoids loading _tensor_bridge.so until torch is used) +# --------------------------------------------------------------------------- + +cdef object _tensor_bridge = None + + +cdef inline bint _is_torch_tensor(object obj): + cdef str mod = type(obj).__module__ or "" + return mod.startswith("torch") and hasattr(obj, "data_ptr") + + +cdef object _get_tensor_bridge(): + """Bootstrap AOTI symbols, then import _tensor_bridge on first use.""" + global _tensor_bridge + if _tensor_bridge is not None: + return _tensor_bridge + import ctypes, sys + torch_C = sys.modules.get("torch._C") + if torch_C is None: + raise RuntimeError( + "torch._C is not loaded; cannot initialise the tensor bridge. " + "Make sure PyTorch is imported before passing a torch.Tensor.") + ctypes.CDLL(torch_C.__file__, mode=ctypes.RTLD_GLOBAL) + from cuda.core import _tensor_bridge as tb + _tensor_bridge = tb + return _tensor_bridge + + try: from ml_dtypes import bfloat16 except ImportError: @@ -112,7 +141,15 @@ cdef class StridedMemoryView: cdef str clsname = self.__class__.__name__ if obj is not None: # populate self's attributes - if check_has_dlpack(obj): + if _is_torch_tensor(obj): + warnings.warn( + f"Constructing a {clsname} directly from a torch.Tensor is deprecated; " + "Use `StridedMemoryView.from_any_interface` instead.", + DeprecationWarning, + stacklevel=2, + ) + _get_tensor_bridge().view_as_torch_tensor(obj, stream_ptr, self) + elif check_has_dlpack(obj): warnings.warn( f"Constructing a {clsname} directly from a DLPack-supporting object is deprecated; " "Use `StridedMemoryView.from_dlpack` or `StridedMemoryView.from_any_interface` instead.", @@ -185,17 +222,24 @@ cdef class StridedMemoryView: def from_any_interface(cls, obj: object, stream_ptr: int | None = None) -> StridedMemoryView: """Create a view by automatically selecting the best available protocol. - Tries `DLPack `_ first, then falls back to + For ``torch.Tensor`` objects a fast path via the AOTI stable C ABI is + used (no DLPack/CAI overhead). Otherwise tries + `DLPack `_ first, then falls back to `__cuda_array_interface__ `_. Parameters ---------- obj : object - An object implementing `DLPack `_ or - `__cuda_array_interface__ `_. + An object implementing `DLPack `_, + `__cuda_array_interface__ `_, + or a ``torch.Tensor``. stream_ptr : int, optional Stream pointer for synchronization. If ``None``, no synchronization is performed. """ + if _is_torch_tensor(obj): + buf = StridedMemoryView.__new__(cls) + _get_tensor_bridge().view_as_torch_tensor(obj, stream_ptr, buf) + return buf if check_has_dlpack(obj): return cls.from_dlpack(obj, stream_ptr) return cls.from_cuda_array_interface(obj, stream_ptr) @@ -921,13 +965,21 @@ cdef class _StridedMemoryViewProxy: cdef readonly: object obj bint has_dlpack + cdef: + bint _is_torch def __init__(self, obj): self.obj = obj - self.has_dlpack = check_has_dlpack(obj) + self._is_torch = _is_torch_tensor(obj) + if not self._is_torch: + self.has_dlpack = check_has_dlpack(obj) + else: + self.has_dlpack = False cpdef StridedMemoryView view(self, stream_ptr=None): - if self.has_dlpack: + if self._is_torch: + return _get_tensor_bridge().view_as_torch_tensor(self.obj, stream_ptr) + elif self.has_dlpack: return StridedMemoryView.from_dlpack(self.obj, stream_ptr) else: return StridedMemoryView.from_cuda_array_interface(self.obj, stream_ptr) diff --git a/cuda_core/cuda/core/_tensor_bridge.pxd b/cuda_core/cuda/core/_tensor_bridge.pxd new file mode 100644 index 0000000000..a555d95123 --- /dev/null +++ b/cuda_core/cuda/core/_tensor_bridge.pxd @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/cuda_core/cuda/core/_tensor_bridge.pyx b/cuda_core/cuda/core/_tensor_bridge.pyx new file mode 100644 index 0000000000..6dba91f335 --- /dev/null +++ b/cuda_core/cuda/core/_tensor_bridge.pyx @@ -0,0 +1,238 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Tensor bridge: extract PyTorch tensor metadata via the AOTI stable C ABI. + +PyTorch is NOT required at build time. At runtime the AOTI symbols are +resolved from ``torch._C`` (which is loaded with ``RTLD_GLOBAL``). + +The ``pyobj_to_aten_handle`` trick exploits the internal layout of +``THPVariable`` (PyTorch's Python tensor wrapper):: + + struct THPVariable { + PyObject_HEAD + MaybeOwned cdata; // <-- this IS the AtenTensorHandle + }; + +Offsetting past ``PyObject_HEAD`` gives us the ``at::Tensor`` pointer +without any Python attribute access or method calls (~10 ns per tensor). + +Credit: Emilio Castillo (ecastillo@nvidia.com) – original tensor-bridge POC. + +.. note:: + + This module must NOT be imported at ``cuda.core`` load time. It is + loaded lazily (by ``_memoryview.pyx``) only when the user actually + passes a ``torch.Tensor``. The caller must ensure that + ``torch._C`` has been re-opened with ``RTLD_GLOBAL`` *before* + importing this module so that the AOTI symbols are visible. +""" + +from libc.stdint cimport intptr_t, int32_t, int64_t + +from cuda.core._memoryview cimport StridedMemoryView +from cuda.core._layout cimport _StridedLayout + +cdef extern from "Python.h": + ctypedef struct PyObject: + pass + +cdef extern from "_include/aoti_shim.h": + ctypedef int32_t AOTITorchError + + ctypedef struct AtenTensorOpaque: + pass + ctypedef AtenTensorOpaque* AtenTensorHandle + + # tensor metadata + AOTITorchError aoti_torch_get_data_ptr(AtenTensorHandle, void**) + AOTITorchError aoti_torch_get_dim(AtenTensorHandle, int64_t*) + AOTITorchError aoti_torch_get_sizes(AtenTensorHandle, int64_t**) + AOTITorchError aoti_torch_get_strides(AtenTensorHandle, int64_t**) + AOTITorchError aoti_torch_get_storage_offset(AtenTensorHandle, int64_t*) + + # dtype + AOTITorchError aoti_torch_get_dtype(AtenTensorHandle, int32_t*) + int32_t aoti_torch_dtype_float16() + int32_t aoti_torch_dtype_float32() + int32_t aoti_torch_dtype_float64() + int32_t aoti_torch_dtype_bfloat16() + int32_t aoti_torch_dtype_uint8() + int32_t aoti_torch_dtype_int8() + int32_t aoti_torch_dtype_int16() + int32_t aoti_torch_dtype_int32() + int32_t aoti_torch_dtype_int64() + int32_t aoti_torch_dtype_bool() + int32_t aoti_torch_dtype_complex32() + int32_t aoti_torch_dtype_complex64() + int32_t aoti_torch_dtype_complex128() + + # device + AOTITorchError aoti_torch_get_device_type(AtenTensorHandle, int32_t*) + AOTITorchError aoti_torch_get_device_index(AtenTensorHandle, int32_t*) + int32_t aoti_torch_device_type_cpu() + int32_t aoti_torch_device_type_cuda() + +import numpy + + +# --------------------------------------------------------------------------- +# Module-level state (initialised at import time — AOTI symbols are +# guaranteed visible because _memoryview bootstraps RTLD_GLOBAL before +# importing us) +# --------------------------------------------------------------------------- + +cdef int32_t _DEVICE_TYPE_CPU = aoti_torch_device_type_cpu() +cdef int32_t _DEVICE_TYPE_CUDA = aoti_torch_device_type_cuda() +cdef dict _aoti_dtype_map = None + + +# --------------------------------------------------------------------------- +# pointer extraction +# --------------------------------------------------------------------------- + +cdef inline AtenTensorHandle pyobj_to_aten_handle(object obj): + """Extract AtenTensorHandle by offsetting past PyObject_HEAD.""" + return (obj + sizeof(PyObject)) + + +# --------------------------------------------------------------------------- +# dtype mapping (AOTI int32 -> numpy dtype) +# --------------------------------------------------------------------------- + +cdef dict _build_dtype_map(): + try: + from ml_dtypes import bfloat16 as _bf16 + has_bfloat16 = True + except ImportError: + has_bfloat16 = False + + cdef dict m = { + aoti_torch_dtype_float16(): numpy.dtype(numpy.float16), + aoti_torch_dtype_float32(): numpy.dtype(numpy.float32), + aoti_torch_dtype_float64(): numpy.dtype(numpy.float64), + aoti_torch_dtype_uint8(): numpy.dtype(numpy.uint8), + aoti_torch_dtype_int8(): numpy.dtype(numpy.int8), + aoti_torch_dtype_int16(): numpy.dtype(numpy.int16), + aoti_torch_dtype_int32(): numpy.dtype(numpy.int32), + aoti_torch_dtype_int64(): numpy.dtype(numpy.int64), + aoti_torch_dtype_bool(): numpy.dtype(numpy.bool_), + aoti_torch_dtype_complex64(): numpy.dtype(numpy.complex64), + aoti_torch_dtype_complex128(): numpy.dtype(numpy.complex128), + } + if has_bfloat16: + m[aoti_torch_dtype_bfloat16()] = numpy.dtype("bfloat16") + return m + + +cdef object _get_aoti_dtype(int32_t dtype_code): + global _aoti_dtype_map + if _aoti_dtype_map is None: + _aoti_dtype_map = _build_dtype_map() + result = _aoti_dtype_map.get(dtype_code) + if result is None: + raise TypeError(f"Unsupported AOTI dtype code: {dtype_code}") + return result + + +# --------------------------------------------------------------------------- +# Public API: construct StridedMemoryView from a torch.Tensor +# --------------------------------------------------------------------------- + +def view_as_torch_tensor(object obj, object stream_ptr, view=None): + """Create/populate a :class:`StridedMemoryView` from a ``torch.Tensor``. + + This is a fast path that avoids DLPack/CAI protocol overhead by + reading tensor metadata directly through the AOTI stable C ABI. + + Parameters + ---------- + obj : torch.Tensor + The source tensor. + stream_ptr : int or None + Consumer stream pointer (currently unused — no stream ordering + is performed for torch tensors). + view : StridedMemoryView, optional + If provided, populate this existing view in-place. Otherwise a + new instance is created. + """ + cdef AtenTensorHandle handle = pyobj_to_aten_handle(obj) + cdef AOTITorchError err + + # -- data pointer -- + cdef void* data_ptr + err = aoti_torch_get_data_ptr(handle, &data_ptr) + if err != 0: + raise RuntimeError("aoti_torch_get_data_ptr failed") + + # -- ndim -- + cdef int64_t ndim + err = aoti_torch_get_dim(handle, &ndim) + if err != 0: + raise RuntimeError("aoti_torch_get_dim failed") + + # -- shape / strides (borrowed pointers, valid while obj alive) -- + cdef int64_t* sizes_ptr + cdef int64_t* strides_ptr + err = aoti_torch_get_sizes(handle, &sizes_ptr) + if err != 0: + raise RuntimeError("aoti_torch_get_sizes failed") + err = aoti_torch_get_strides(handle, &strides_ptr) + if err != 0: + raise RuntimeError("aoti_torch_get_strides failed") + + # -- dtype -- + cdef int32_t dtype_code + err = aoti_torch_get_dtype(handle, &dtype_code) + if err != 0: + raise RuntimeError("aoti_torch_get_dtype failed") + + # -- device -- + cdef int32_t device_type, device_index + err = aoti_torch_get_device_type(handle, &device_type) + if err != 0: + raise RuntimeError("aoti_torch_get_device_type failed") + err = aoti_torch_get_device_index(handle, &device_index) + if err != 0: + raise RuntimeError("aoti_torch_get_device_index failed") + + # -- populate StridedMemoryView -- + cdef StridedMemoryView buf + if view is not None: + buf = view + else: + buf = StridedMemoryView.__new__(StridedMemoryView) + + buf.ptr = data_ptr + buf.readonly = False + buf.exporting_obj = obj + buf.dl_tensor = NULL + buf.metadata = None + buf._buffer = None + + if device_type == _DEVICE_TYPE_CPU: + buf.device_id = -1 + buf.is_device_accessible = False + elif device_type == _DEVICE_TYPE_CUDA: + buf.device_id = device_index + buf.is_device_accessible = True + else: + raise BufferError( + f"Unsupported device type from torch tensor " + f"(AOTI device type id: {device_type})") + + buf._dtype = _get_aoti_dtype(dtype_code) + + # Build _StridedLayout. init_from_ptr copies shape/strides so we are + # safe even though they are borrowed pointers. + cdef _StridedLayout layout = _StridedLayout.__new__(_StridedLayout) + layout.init_from_ptr( + ndim, + sizes_ptr, + strides_ptr, + buf._dtype.itemsize, + ) + buf._layout = layout + + return buf diff --git a/cuda_core/tests/test_utils.py b/cuda_core/tests/test_utils.py index 59829f8fb3..146a3f303f 100644 --- a/cuda_core/tests/test_utils.py +++ b/cuda_core/tests/test_utils.py @@ -712,3 +712,134 @@ def test_ml_dtypes_bfloat16_dlpack_requires_ml_dtypes(init_cuda, no_ml_dtypes, a smv = api(a, stream_ptr=0) with pytest.raises(NotImplementedError, match=r"requires `ml_dtypes`"): smv.dtype # noqa: B018 + + +# =================================================================== +# Tensor bridge (torch.Tensor fast path via AOTI stable C ABI) +# =================================================================== + +_torch_skip = pytest.mark.skipif(torch is None, reason="PyTorch is not installed") + + +@_torch_skip +@pytest.mark.parametrize( + "dtype", + [ + pytest.param("float16", id="float16"), + pytest.param("float32", id="float32"), + pytest.param("float64", id="float64"), + pytest.param("int8", id="int8"), + pytest.param("int16", id="int16"), + pytest.param("int32", id="int32"), + pytest.param("int64", id="int64"), + pytest.param("uint8", id="uint8"), + pytest.param("bool", id="bool"), + pytest.param("complex64", id="complex64"), + pytest.param("complex128", id="complex128"), + ], +) +def test_torch_tensor_bridge_dtypes(init_cuda, dtype): + """Verify that dtype mapping via the tensor bridge matches torch's own dtype.""" + torch_dtype = getattr(torch, dtype) + a = torch.tensor([1, 0, 1], dtype=torch_dtype, device="cuda") + smv = StridedMemoryView.from_any_interface(a, stream_ptr=0) + assert smv.dtype.itemsize == a.element_size() + assert smv.ptr == a.data_ptr() + + +@_torch_skip +@pytest.mark.skipif(ml_dtypes is None, reason="ml_dtypes is not installed") +def test_torch_tensor_bridge_bfloat16(init_cuda): + a = torch.tensor([1, 2, 3], dtype=torch.bfloat16, device="cuda") + smv = StridedMemoryView.from_any_interface(a, stream_ptr=0) + assert smv.dtype == np.dtype("bfloat16") + assert smv.ptr == a.data_ptr() + + +@_torch_skip +def test_torch_tensor_bridge_cuda_1d(init_cuda): + a = torch.arange(12, dtype=torch.float32, device="cuda") + smv = StridedMemoryView.from_any_interface(a, stream_ptr=0) + assert smv.ptr == a.data_ptr() + assert smv.shape == (12,) + assert smv.strides in (None, (1,)) # C-contiguous may be None + assert smv.dtype == np.dtype(np.float32) + assert smv.device_id == init_cuda.device_id + assert smv.is_device_accessible is True + assert smv.readonly is False + assert smv.exporting_obj is a + + +@_torch_skip +def test_torch_tensor_bridge_cuda_nd(init_cuda): + a = torch.arange(24, dtype=torch.float32, device="cuda").reshape(2, 3, 4) + smv = StridedMemoryView.from_any_interface(a, stream_ptr=0) + assert smv.ptr == a.data_ptr() + assert smv.shape == (2, 3, 4) + assert smv.dtype == np.dtype(np.float32) + assert smv.device_id == init_cuda.device_id + assert smv.is_device_accessible is True + + +@_torch_skip +def test_torch_tensor_bridge_non_contiguous(init_cuda): + """Transposed tensor should have non-trivial strides.""" + a = torch.arange(12, dtype=torch.float32, device="cuda").reshape(3, 4).t() + smv = StridedMemoryView.from_any_interface(a, stream_ptr=0) + assert smv.shape == (4, 3) + # torch.stride() returns element counts, same as StridedMemoryView + assert smv.strides == tuple(a.stride()) + assert smv.ptr == a.data_ptr() + + +@_torch_skip +def test_torch_tensor_bridge_sliced(init_cuda): + """Sliced tensor should have correct data_ptr (accounts for storage offset).""" + base = torch.arange(100, dtype=torch.int64, device="cuda") + a = base[10:20] + smv = StridedMemoryView.from_any_interface(a, stream_ptr=0) + assert smv.ptr == a.data_ptr() + assert smv.shape == (10,) + assert smv.dtype == np.dtype(np.int64) + + +@_torch_skip +def test_torch_tensor_bridge_scalar(init_cuda): + a = torch.tensor(42.0, dtype=torch.float32, device="cuda") + smv = StridedMemoryView.from_any_interface(a, stream_ptr=0) + assert smv.ptr == a.data_ptr() + assert smv.shape == () + assert smv.dtype == np.dtype(np.float32) + + +@_torch_skip +def test_torch_tensor_bridge_empty(init_cuda): + a = torch.empty(0, dtype=torch.float32, device="cuda") + smv = StridedMemoryView.from_any_interface(a, stream_ptr=0) + assert smv.shape == (0,) + assert smv.dtype == np.dtype(np.float32) + + +@_torch_skip +def test_torch_tensor_bridge_cpu(init_cuda): + a = torch.arange(5, dtype=torch.float32, device="cpu") + smv = StridedMemoryView.from_any_interface(a, stream_ptr=-1) + assert smv.ptr == a.data_ptr() + assert smv.shape == (5,) + assert smv.device_id == -1 + assert smv.is_device_accessible is False + + +@_torch_skip +def test_torch_tensor_bridge_decorator(init_cuda): + """Verify tensor bridge works through the args_viewable_as_strided_memory decorator.""" + @args_viewable_as_strided_memory((0,)) + def fn(tensor, stream): + return tensor.view(stream.handle) + + a = torch.arange(6, dtype=torch.float32, device="cuda").reshape(2, 3) + stream = Device().create_stream() + smv = fn(a, stream) + assert smv.ptr == a.data_ptr() + assert smv.shape == (2, 3) + assert smv.dtype == np.dtype(np.float32) From f8f8d8c614750c475693053096345ec1a2b8b4f4 Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Thu, 9 Apr 2026 18:03:46 +0000 Subject: [PATCH 02/17] Clean up tensor bridge: remove unused AOTI decls, lazy dtype, drop empty .pxd - Remove unused aoti_torch_get_numel and aoti_torch_get_storage_offset declarations from aoti_shim.h and _tensor_bridge.pyx - Fix license headers on new files to 2026 (not 2024-2026) - Delete empty _tensor_bridge.pxd (nothing cimports from it) - Defer numpy dtype resolution for torch tensors: store raw AOTI dtype code in metadata, compute itemsize from a cheap lookup table, and only resolve the full numpy dtype on first .dtype access via get_dtype() Co-Authored-By: Emilio Castillo Co-Authored-By: Claude Opus 4.6 (1M context) --- cuda_core/cuda/core/_include/aoti_shim.h | 8 +---- cuda_core/cuda/core/_memoryview.pyx | 4 +++ cuda_core/cuda/core/_tensor_bridge.pxd | 3 -- cuda_core/cuda/core/_tensor_bridge.pyx | 46 +++++++++++++++++++++--- 4 files changed, 46 insertions(+), 15 deletions(-) delete mode 100644 cuda_core/cuda/core/_tensor_bridge.pxd diff --git a/cuda_core/cuda/core/_include/aoti_shim.h b/cuda_core/cuda/core/_include/aoti_shim.h index 4a9343d241..7a79079135 100644 --- a/cuda_core/cuda/core/_include/aoti_shim.h +++ b/cuda_core/cuda/core/_include/aoti_shim.h @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. + * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. * All rights reserved. * SPDX-License-Identifier: Apache-2.0 * @@ -35,18 +35,12 @@ AOTITorchError aoti_torch_get_data_ptr( AOTITorchError aoti_torch_get_dim( AtenTensorHandle tensor, int64_t* ret_dim); -AOTITorchError aoti_torch_get_numel( - AtenTensorHandle tensor, int64_t* ret_numel); - AOTITorchError aoti_torch_get_sizes( AtenTensorHandle tensor, int64_t** ret_sizes); AOTITorchError aoti_torch_get_strides( AtenTensorHandle tensor, int64_t** ret_strides); -AOTITorchError aoti_torch_get_storage_offset( - AtenTensorHandle tensor, int64_t* ret_storage_offset); - /* ---- dtype ------------------------------------------------------------- */ AOTITorchError aoti_torch_get_dtype( diff --git a/cuda_core/cuda/core/_memoryview.pyx b/cuda_core/cuda/core/_memoryview.pyx index 24533eb6e3..c75458818f 100644 --- a/cuda_core/cuda/core/_memoryview.pyx +++ b/cuda_core/cuda/core/_memoryview.pyx @@ -524,6 +524,10 @@ cdef class StridedMemoryView: if self._dtype is None: if self.dl_tensor != NULL: self._dtype = dtype_dlpack_to_numpy(&self.dl_tensor.dtype) + elif isinstance(self.metadata, int): + # AOTI dtype code stored by the torch tensor bridge + self._dtype = _get_tensor_bridge().resolve_aoti_dtype( + self.metadata) elif self.metadata is not None: self._dtype = _typestr2dtype(self.metadata["typestr"]) return self._dtype diff --git a/cuda_core/cuda/core/_tensor_bridge.pxd b/cuda_core/cuda/core/_tensor_bridge.pxd deleted file mode 100644 index a555d95123..0000000000 --- a/cuda_core/cuda/core/_tensor_bridge.pxd +++ /dev/null @@ -1,3 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# SPDX-License-Identifier: Apache-2.0 diff --git a/cuda_core/cuda/core/_tensor_bridge.pyx b/cuda_core/cuda/core/_tensor_bridge.pyx index 6dba91f335..a9599c6582 100644 --- a/cuda_core/cuda/core/_tensor_bridge.pyx +++ b/cuda_core/cuda/core/_tensor_bridge.pyx @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # SPDX-License-Identifier: Apache-2.0 @@ -29,7 +29,7 @@ Credit: Emilio Castillo (ecastillo@nvidia.com) – original tensor-bridge POC. importing this module so that the AOTI symbols are visible. """ -from libc.stdint cimport intptr_t, int32_t, int64_t +from libc.stdint cimport intptr_t, int8_t, int16_t, int32_t, int64_t, uint8_t from cuda.core._memoryview cimport StridedMemoryView from cuda.core._layout cimport _StridedLayout @@ -50,7 +50,6 @@ cdef extern from "_include/aoti_shim.h": AOTITorchError aoti_torch_get_dim(AtenTensorHandle, int64_t*) AOTITorchError aoti_torch_get_sizes(AtenTensorHandle, int64_t**) AOTITorchError aoti_torch_get_strides(AtenTensorHandle, int64_t**) - AOTITorchError aoti_torch_get_storage_offset(AtenTensorHandle, int64_t*) # dtype AOTITorchError aoti_torch_get_dtype(AtenTensorHandle, int32_t*) @@ -86,6 +85,7 @@ import numpy cdef int32_t _DEVICE_TYPE_CPU = aoti_torch_device_type_cpu() cdef int32_t _DEVICE_TYPE_CUDA = aoti_torch_device_type_cuda() cdef dict _aoti_dtype_map = None +cdef dict _aoti_itemsize_map = None # --------------------------------------------------------------------------- @@ -136,6 +136,39 @@ cdef object _get_aoti_dtype(int32_t dtype_code): return result +def resolve_aoti_dtype(int32_t dtype_code): + """Python-callable wrapper around _get_aoti_dtype (for lazy resolution).""" + return _get_aoti_dtype(dtype_code) + + +cdef dict _build_itemsize_map(): + return { + aoti_torch_dtype_bool(): sizeof(uint8_t), + aoti_torch_dtype_uint8(): sizeof(uint8_t), + aoti_torch_dtype_int8(): sizeof(int8_t), + aoti_torch_dtype_float16(): sizeof(int16_t), # no C float16 + aoti_torch_dtype_bfloat16(): sizeof(int16_t), # no C bfloat16 + aoti_torch_dtype_int16(): sizeof(int16_t), + aoti_torch_dtype_complex32(): 2 * sizeof(int16_t), # no C complex32 + aoti_torch_dtype_float32(): sizeof(float), + aoti_torch_dtype_int32(): sizeof(int32_t), + aoti_torch_dtype_complex64(): 2 * sizeof(float), + aoti_torch_dtype_float64(): sizeof(double), + aoti_torch_dtype_int64(): sizeof(int64_t), + aoti_torch_dtype_complex128(): 2 * sizeof(double), + } + + +cdef int _get_aoti_itemsize(int32_t dtype_code) except -1: + global _aoti_itemsize_map + if _aoti_itemsize_map is None: + _aoti_itemsize_map = _build_itemsize_map() + result = _aoti_itemsize_map.get(dtype_code) + if result is None: + raise TypeError(f"Unsupported AOTI dtype code: {dtype_code}") + return result + + # --------------------------------------------------------------------------- # Public API: construct StridedMemoryView from a torch.Tensor # --------------------------------------------------------------------------- @@ -222,16 +255,19 @@ def view_as_torch_tensor(object obj, object stream_ptr, view=None): f"Unsupported device type from torch tensor " f"(AOTI device type id: {device_type})") - buf._dtype = _get_aoti_dtype(dtype_code) + # Defer full numpy dtype resolution until first .dtype access. + # Store the raw AOTI dtype code in metadata for lazy lookup. + buf.metadata = dtype_code # Build _StridedLayout. init_from_ptr copies shape/strides so we are # safe even though they are borrowed pointers. + cdef int itemsize = _get_aoti_itemsize(dtype_code) cdef _StridedLayout layout = _StridedLayout.__new__(_StridedLayout) layout.init_from_ptr( ndim, sizes_ptr, strides_ptr, - buf._dtype.itemsize, + itemsize, ) buf._layout = layout From af06e9b539cbaa4e2d6ece7187093d7edb6660c3 Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Thu, 9 Apr 2026 18:54:24 +0000 Subject: [PATCH 03/17] Move torch tensor fast path into each from_* classmethod Instead of short-circuiting in __init__ and from_any_interface, add the AOTI fast path check to from_dlpack, from_cuda_array_interface, and from_array_interface. This ensures torch tensors always take the fast path regardless of which constructor the user calls. Simplify from_any_interface and _StridedMemoryViewProxy to just delegate to the from_* methods (which now handle torch internally). Co-Authored-By: Emilio Castillo Co-Authored-By: Claude Opus 4.6 (1M context) --- cuda_core/cuda/core/_memoryview.pyx | 46 +++++++++++------------------ 1 file changed, 17 insertions(+), 29 deletions(-) diff --git a/cuda_core/cuda/core/_memoryview.pyx b/cuda_core/cuda/core/_memoryview.pyx index c75458818f..30f9a70957 100644 --- a/cuda_core/cuda/core/_memoryview.pyx +++ b/cuda_core/cuda/core/_memoryview.pyx @@ -141,15 +141,7 @@ cdef class StridedMemoryView: cdef str clsname = self.__class__.__name__ if obj is not None: # populate self's attributes - if _is_torch_tensor(obj): - warnings.warn( - f"Constructing a {clsname} directly from a torch.Tensor is deprecated; " - "Use `StridedMemoryView.from_any_interface` instead.", - DeprecationWarning, - stacklevel=2, - ) - _get_tensor_bridge().view_as_torch_tensor(obj, stream_ptr, self) - elif check_has_dlpack(obj): + if check_has_dlpack(obj): warnings.warn( f"Constructing a {clsname} directly from a DLPack-supporting object is deprecated; " "Use `StridedMemoryView.from_dlpack` or `StridedMemoryView.from_any_interface` instead.", @@ -187,6 +179,9 @@ cdef class StridedMemoryView: Stream pointer for synchronization. If ``None``, no synchronization is performed. """ cdef StridedMemoryView buf = StridedMemoryView.__new__(cls) + if _is_torch_tensor(obj): + _get_tensor_bridge().view_as_torch_tensor(obj, stream_ptr, buf) + return buf view_as_dlpack(obj, stream_ptr, buf) return buf @@ -202,6 +197,9 @@ cdef class StridedMemoryView: Stream pointer for synchronization. If ``None``, no synchronization is performed. """ cdef StridedMemoryView buf = StridedMemoryView.__new__(cls) + if _is_torch_tensor(obj): + _get_tensor_bridge().view_as_torch_tensor(obj, stream_ptr, buf) + return buf view_as_cai(obj, stream_ptr, buf) return buf @@ -215,6 +213,9 @@ cdef class StridedMemoryView: An object implementing the `__array_interface__ `_ protocol (e.g., a numpy array). """ cdef StridedMemoryView buf = StridedMemoryView.__new__(cls) + if _is_torch_tensor(obj): + _get_tensor_bridge().view_as_torch_tensor(obj, None, buf) + return buf view_as_array_interface(obj, buf) return buf @@ -222,24 +223,19 @@ cdef class StridedMemoryView: def from_any_interface(cls, obj: object, stream_ptr: int | None = None) -> StridedMemoryView: """Create a view by automatically selecting the best available protocol. - For ``torch.Tensor`` objects a fast path via the AOTI stable C ABI is - used (no DLPack/CAI overhead). Otherwise tries - `DLPack `_ first, then falls back to + Tries `DLPack `_ first, then falls back to `__cuda_array_interface__ `_. + ``torch.Tensor`` objects are transparently handled via a fast AOTI path + regardless of which protocol is selected. Parameters ---------- obj : object - An object implementing `DLPack `_, - `__cuda_array_interface__ `_, - or a ``torch.Tensor``. + An object implementing `DLPack `_ or + `__cuda_array_interface__ `_. stream_ptr : int, optional Stream pointer for synchronization. If ``None``, no synchronization is performed. """ - if _is_torch_tensor(obj): - buf = StridedMemoryView.__new__(cls) - _get_tensor_bridge().view_as_torch_tensor(obj, stream_ptr, buf) - return buf if check_has_dlpack(obj): return cls.from_dlpack(obj, stream_ptr) return cls.from_cuda_array_interface(obj, stream_ptr) @@ -969,21 +965,13 @@ cdef class _StridedMemoryViewProxy: cdef readonly: object obj bint has_dlpack - cdef: - bint _is_torch def __init__(self, obj): self.obj = obj - self._is_torch = _is_torch_tensor(obj) - if not self._is_torch: - self.has_dlpack = check_has_dlpack(obj) - else: - self.has_dlpack = False + self.has_dlpack = check_has_dlpack(obj) cpdef StridedMemoryView view(self, stream_ptr=None): - if self._is_torch: - return _get_tensor_bridge().view_as_torch_tensor(self.obj, stream_ptr) - elif self.has_dlpack: + if self.has_dlpack: return StridedMemoryView.from_dlpack(self.obj, stream_ptr) else: return StridedMemoryView.from_cuda_array_interface(self.obj, stream_ptr) From 44be5800c3143c6044d85bee4bccc663c383d915 Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Thu, 9 Apr 2026 20:11:48 +0000 Subject: [PATCH 04/17] Add stream ordering for torch tensor bridge When stream_ptr is not -1, establish stream ordering between PyTorch's current CUDA stream (the producer) and the consumer stream, using the same event record + stream wait pattern as the CAI path. Uses aoti_torch_get_current_cuda_stream to get the producer stream, matching what PyTorch's own __dlpack__ does internally. Co-Authored-By: Emilio Castillo Co-Authored-By: Claude Opus 4.6 (1M context) --- cuda_core/cuda/core/_include/aoti_shim.h | 5 +++ cuda_core/cuda/core/_tensor_bridge.pyx | 57 +++++++++++++++++++----- 2 files changed, 51 insertions(+), 11 deletions(-) diff --git a/cuda_core/cuda/core/_include/aoti_shim.h b/cuda_core/cuda/core/_include/aoti_shim.h index 7a79079135..26ebd1164c 100644 --- a/cuda_core/cuda/core/_include/aoti_shim.h +++ b/cuda_core/cuda/core/_include/aoti_shim.h @@ -71,6 +71,11 @@ AOTITorchError aoti_torch_get_device_index( int32_t aoti_torch_device_type_cpu(void); int32_t aoti_torch_device_type_cuda(void); +/* ---- stream -------------------------------------------------------------- */ + +AOTITorchError aoti_torch_get_current_cuda_stream( + int32_t device_index, void** ret_stream); + #ifdef __cplusplus } /* extern "C" */ #endif diff --git a/cuda_core/cuda/core/_tensor_bridge.pyx b/cuda_core/cuda/core/_tensor_bridge.pyx index a9599c6582..05c999f06f 100644 --- a/cuda_core/cuda/core/_tensor_bridge.pyx +++ b/cuda_core/cuda/core/_tensor_bridge.pyx @@ -33,6 +33,13 @@ from libc.stdint cimport intptr_t, int8_t, int16_t, int32_t, int64_t, uint8_t from cuda.core._memoryview cimport StridedMemoryView from cuda.core._layout cimport _StridedLayout +from cuda.bindings cimport cydriver +from cuda.core._resource_handles cimport ( + EventHandle, + create_event_handle_noctx, + as_cu, +) +from cuda.core._utils.cuda_utils cimport HANDLE_RETURN cdef extern from "Python.h": ctypedef struct PyObject: @@ -73,6 +80,9 @@ cdef extern from "_include/aoti_shim.h": int32_t aoti_torch_device_type_cpu() int32_t aoti_torch_device_type_cuda() + # stream + AOTITorchError aoti_torch_get_current_cuda_stream(int32_t, void**) + import numpy @@ -184,30 +194,39 @@ def view_as_torch_tensor(object obj, object stream_ptr, view=None): obj : torch.Tensor The source tensor. stream_ptr : int or None - Consumer stream pointer (currently unused — no stream ordering - is performed for torch tensors). + Consumer stream pointer. When not ``-1``, stream ordering is + established between PyTorch's current CUDA stream (the producer) + and the consumer stream, matching the DLPack contract. view : StridedMemoryView, optional If provided, populate this existing view in-place. Otherwise a new instance is created. """ cdef AtenTensorHandle handle = pyobj_to_aten_handle(obj) cdef AOTITorchError err + cdef void* data_ptr + cdef int64_t ndim + cdef int64_t* sizes_ptr + cdef int64_t* strides_ptr + cdef int32_t dtype_code + cdef int32_t device_type, device_index + cdef StridedMemoryView buf + cdef void* producer_s + cdef intptr_t consumer_s + cdef EventHandle h_event + cdef int itemsize + cdef _StridedLayout layout # -- data pointer -- - cdef void* data_ptr err = aoti_torch_get_data_ptr(handle, &data_ptr) if err != 0: raise RuntimeError("aoti_torch_get_data_ptr failed") # -- ndim -- - cdef int64_t ndim err = aoti_torch_get_dim(handle, &ndim) if err != 0: raise RuntimeError("aoti_torch_get_dim failed") # -- shape / strides (borrowed pointers, valid while obj alive) -- - cdef int64_t* sizes_ptr - cdef int64_t* strides_ptr err = aoti_torch_get_sizes(handle, &sizes_ptr) if err != 0: raise RuntimeError("aoti_torch_get_sizes failed") @@ -216,13 +235,11 @@ def view_as_torch_tensor(object obj, object stream_ptr, view=None): raise RuntimeError("aoti_torch_get_strides failed") # -- dtype -- - cdef int32_t dtype_code err = aoti_torch_get_dtype(handle, &dtype_code) if err != 0: raise RuntimeError("aoti_torch_get_dtype failed") # -- device -- - cdef int32_t device_type, device_index err = aoti_torch_get_device_type(handle, &device_type) if err != 0: raise RuntimeError("aoti_torch_get_device_type failed") @@ -231,7 +248,6 @@ def view_as_torch_tensor(object obj, object stream_ptr, view=None): raise RuntimeError("aoti_torch_get_device_index failed") # -- populate StridedMemoryView -- - cdef StridedMemoryView buf if view is not None: buf = view else: @@ -250,6 +266,25 @@ def view_as_torch_tensor(object obj, object stream_ptr, view=None): elif device_type == _DEVICE_TYPE_CUDA: buf.device_id = device_index buf.is_device_accessible = True + + # -- stream ordering (matches the DLPack contract) -- + if stream_ptr is not None and int(stream_ptr) != -1: + err = aoti_torch_get_current_cuda_stream(device_index, + &producer_s) + if err != 0: + raise RuntimeError( + "aoti_torch_get_current_cuda_stream failed") + consumer_s = (int(stream_ptr)) + if producer_s != consumer_s: + with nogil: + h_event = create_event_handle_noctx( + cydriver.CUevent_flags.CU_EVENT_DISABLE_TIMING) + HANDLE_RETURN(cydriver.cuEventRecord( + as_cu(h_event), + producer_s)) + HANDLE_RETURN(cydriver.cuStreamWaitEvent( + consumer_s, + as_cu(h_event), 0)) else: raise BufferError( f"Unsupported device type from torch tensor " @@ -261,8 +296,8 @@ def view_as_torch_tensor(object obj, object stream_ptr, view=None): # Build _StridedLayout. init_from_ptr copies shape/strides so we are # safe even though they are borrowed pointers. - cdef int itemsize = _get_aoti_itemsize(dtype_code) - cdef _StridedLayout layout = _StridedLayout.__new__(_StridedLayout) + itemsize = _get_aoti_itemsize(dtype_code) + layout = _StridedLayout.__new__(_StridedLayout) layout.init_from_ptr( ndim, sizes_ptr, From 6e6b8a6a08603762ea3d3e2ab7f38b9542256ac4 Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Thu, 9 Apr 2026 21:06:06 +0000 Subject: [PATCH 05/17] Extract reusable sync_torch_stream and apply to CAI path MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Factor out stream ordering into a cpdef sync_torch_stream() helper in _tensor_bridge.pyx, callable from both C (view_as_torch_tensor) and Python (_memoryview.pyx). Apply the same stream ordering in view_as_cai for torch tensors: PyTorch's __cuda_array_interface__ reports version 2 and omits the "stream" field, so the standard CAI sync path is a no-op — leaving the consumer with no guarantee that the producer's work is visible. We now detect torch tensors in the CAI path and query PyTorch's current CUDA stream via AOTI to establish proper ordering. Co-Authored-By: Emilio Castillo Co-Authored-By: Claude Opus 4.6 (1M context) --- cuda_core/cuda/core/_memoryview.pyx | 10 ++++++ cuda_core/cuda/core/_tensor_bridge.pyx | 50 ++++++++++++++++---------- 2 files changed, 41 insertions(+), 19 deletions(-) diff --git a/cuda_core/cuda/core/_memoryview.pyx b/cuda_core/cuda/core/_memoryview.pyx index 30f9a70957..c29f7ae8bf 100644 --- a/cuda_core/cuda/core/_memoryview.pyx +++ b/cuda_core/cuda/core/_memoryview.pyx @@ -1166,6 +1166,16 @@ cpdef StridedMemoryView view_as_cai(obj, stream_ptr, view=None): as_cu(h_event), producer_s)) HANDLE_RETURN(cydriver.cuStreamWaitEvent( consumer_s, as_cu(h_event), 0)) + elif _is_torch_tensor(obj): + # PyTorch's __cuda_array_interface__ reports version 2 and + # omits the "stream" field, so the standard CAI sync path + # above is a no-op for torch tensors. This is unsafe: the + # consumer has no guarantee that the producer's work is + # visible. We fix this by querying PyTorch's current CUDA + # stream via the AOTI stable C ABI and performing the same + # event-based stream ordering. + _get_tensor_bridge().sync_torch_stream( + buf.device_id, (stream_ptr)) return buf diff --git a/cuda_core/cuda/core/_tensor_bridge.pyx b/cuda_core/cuda/core/_tensor_bridge.pyx index 05c999f06f..67a2aa3805 100644 --- a/cuda_core/cuda/core/_tensor_bridge.pyx +++ b/cuda_core/cuda/core/_tensor_bridge.pyx @@ -179,6 +179,36 @@ cdef int _get_aoti_itemsize(int32_t dtype_code) except -1: return result +# --------------------------------------------------------------------------- +# Stream ordering helper +# --------------------------------------------------------------------------- + +cpdef void sync_torch_stream(int32_t device_index, + intptr_t consumer_s) except *: + """Establish stream ordering between PyTorch's current CUDA stream + and the given consumer stream. + + Records an event on PyTorch's current stream (the producer) and makes + the consumer stream wait on it. This is a no-op if both streams are + the same. + """ + cdef AOTITorchError err + cdef void* producer_s + cdef EventHandle h_event + + err = aoti_torch_get_current_cuda_stream(device_index, &producer_s) + if err != 0: + raise RuntimeError("aoti_torch_get_current_cuda_stream failed") + if producer_s != consumer_s: + with nogil: + h_event = create_event_handle_noctx( + cydriver.CUevent_flags.CU_EVENT_DISABLE_TIMING) + HANDLE_RETURN(cydriver.cuEventRecord( + as_cu(h_event), producer_s)) + HANDLE_RETURN(cydriver.cuStreamWaitEvent( + consumer_s, as_cu(h_event), 0)) + + # --------------------------------------------------------------------------- # Public API: construct StridedMemoryView from a torch.Tensor # --------------------------------------------------------------------------- @@ -210,9 +240,6 @@ def view_as_torch_tensor(object obj, object stream_ptr, view=None): cdef int32_t dtype_code cdef int32_t device_type, device_index cdef StridedMemoryView buf - cdef void* producer_s - cdef intptr_t consumer_s - cdef EventHandle h_event cdef int itemsize cdef _StridedLayout layout @@ -269,22 +296,7 @@ def view_as_torch_tensor(object obj, object stream_ptr, view=None): # -- stream ordering (matches the DLPack contract) -- if stream_ptr is not None and int(stream_ptr) != -1: - err = aoti_torch_get_current_cuda_stream(device_index, - &producer_s) - if err != 0: - raise RuntimeError( - "aoti_torch_get_current_cuda_stream failed") - consumer_s = (int(stream_ptr)) - if producer_s != consumer_s: - with nogil: - h_event = create_event_handle_noctx( - cydriver.CUevent_flags.CU_EVENT_DISABLE_TIMING) - HANDLE_RETURN(cydriver.cuEventRecord( - as_cu(h_event), - producer_s)) - HANDLE_RETURN(cydriver.cuStreamWaitEvent( - consumer_s, - as_cu(h_event), 0)) + sync_torch_stream(device_index, (int(stream_ptr))) else: raise BufferError( f"Unsupported device type from torch tensor " From 85caaaf230b9cb3fdef68d49221f0637f4198726 Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Thu, 9 Apr 2026 21:23:02 +0000 Subject: [PATCH 06/17] Nits: add check_aoti helper, size_t itemsize, 2D sliced test - Add check_aoti() inline helper to replace repetitive err/raise patterns for AOTI calls (one-liner per call) - Change itemsize type from int to size_t - Add test_torch_tensor_bridge_sliced_2d test case Co-Authored-By: Emilio Castillo Co-Authored-By: Claude Opus 4.6 (1M context) --- cuda_core/cuda/core/_tensor_bridge.pyx | 64 +++++++++++--------------- cuda_core/tests/test_utils.py | 12 +++++ 2 files changed, 38 insertions(+), 38 deletions(-) diff --git a/cuda_core/cuda/core/_tensor_bridge.pyx b/cuda_core/cuda/core/_tensor_bridge.pyx index 67a2aa3805..ff15a2c1e4 100644 --- a/cuda_core/cuda/core/_tensor_bridge.pyx +++ b/cuda_core/cuda/core/_tensor_bridge.pyx @@ -30,6 +30,7 @@ Credit: Emilio Castillo (ecastillo@nvidia.com) – original tensor-bridge POC. """ from libc.stdint cimport intptr_t, int8_t, int16_t, int32_t, int64_t, uint8_t +from libc.stddef cimport size_t from cuda.core._memoryview cimport StridedMemoryView from cuda.core._layout cimport _StridedLayout @@ -107,6 +108,12 @@ cdef inline AtenTensorHandle pyobj_to_aten_handle(object obj): return (obj + sizeof(PyObject)) +cdef inline void check_aoti(AOTITorchError err, const char* name) except *: + """Raise RuntimeError if an AOTI call returned a non-zero error code.""" + if err != 0: + raise RuntimeError(f"{name.decode()} failed") + + # --------------------------------------------------------------------------- # dtype mapping (AOTI int32 -> numpy dtype) # --------------------------------------------------------------------------- @@ -169,14 +176,14 @@ cdef dict _build_itemsize_map(): } -cdef int _get_aoti_itemsize(int32_t dtype_code) except -1: +cdef size_t _get_aoti_itemsize(int32_t dtype_code) except 0: global _aoti_itemsize_map if _aoti_itemsize_map is None: _aoti_itemsize_map = _build_itemsize_map() result = _aoti_itemsize_map.get(dtype_code) if result is None: raise TypeError(f"Unsupported AOTI dtype code: {dtype_code}") - return result + return result # --------------------------------------------------------------------------- @@ -192,13 +199,11 @@ cpdef void sync_torch_stream(int32_t device_index, the consumer stream wait on it. This is a no-op if both streams are the same. """ - cdef AOTITorchError err cdef void* producer_s cdef EventHandle h_event - err = aoti_torch_get_current_cuda_stream(device_index, &producer_s) - if err != 0: - raise RuntimeError("aoti_torch_get_current_cuda_stream failed") + check_aoti(aoti_torch_get_current_cuda_stream(device_index, &producer_s), + b"aoti_torch_get_current_cuda_stream") if producer_s != consumer_s: with nogil: h_event = create_event_handle_noctx( @@ -232,7 +237,6 @@ def view_as_torch_tensor(object obj, object stream_ptr, view=None): new instance is created. """ cdef AtenTensorHandle handle = pyobj_to_aten_handle(obj) - cdef AOTITorchError err cdef void* data_ptr cdef int64_t ndim cdef int64_t* sizes_ptr @@ -240,39 +244,23 @@ def view_as_torch_tensor(object obj, object stream_ptr, view=None): cdef int32_t dtype_code cdef int32_t device_type, device_index cdef StridedMemoryView buf - cdef int itemsize + cdef size_t itemsize cdef _StridedLayout layout - # -- data pointer -- - err = aoti_torch_get_data_ptr(handle, &data_ptr) - if err != 0: - raise RuntimeError("aoti_torch_get_data_ptr failed") - - # -- ndim -- - err = aoti_torch_get_dim(handle, &ndim) - if err != 0: - raise RuntimeError("aoti_torch_get_dim failed") - - # -- shape / strides (borrowed pointers, valid while obj alive) -- - err = aoti_torch_get_sizes(handle, &sizes_ptr) - if err != 0: - raise RuntimeError("aoti_torch_get_sizes failed") - err = aoti_torch_get_strides(handle, &strides_ptr) - if err != 0: - raise RuntimeError("aoti_torch_get_strides failed") - - # -- dtype -- - err = aoti_torch_get_dtype(handle, &dtype_code) - if err != 0: - raise RuntimeError("aoti_torch_get_dtype failed") - - # -- device -- - err = aoti_torch_get_device_type(handle, &device_type) - if err != 0: - raise RuntimeError("aoti_torch_get_device_type failed") - err = aoti_torch_get_device_index(handle, &device_index) - if err != 0: - raise RuntimeError("aoti_torch_get_device_index failed") + check_aoti(aoti_torch_get_data_ptr(handle, &data_ptr), + b"aoti_torch_get_data_ptr") + check_aoti(aoti_torch_get_dim(handle, &ndim), + b"aoti_torch_get_dim") + check_aoti(aoti_torch_get_sizes(handle, &sizes_ptr), + b"aoti_torch_get_sizes") + check_aoti(aoti_torch_get_strides(handle, &strides_ptr), + b"aoti_torch_get_strides") + check_aoti(aoti_torch_get_dtype(handle, &dtype_code), + b"aoti_torch_get_dtype") + check_aoti(aoti_torch_get_device_type(handle, &device_type), + b"aoti_torch_get_device_type") + check_aoti(aoti_torch_get_device_index(handle, &device_index), + b"aoti_torch_get_device_index") # -- populate StridedMemoryView -- if view is not None: diff --git a/cuda_core/tests/test_utils.py b/cuda_core/tests/test_utils.py index 146a3f303f..c868033268 100644 --- a/cuda_core/tests/test_utils.py +++ b/cuda_core/tests/test_utils.py @@ -803,6 +803,18 @@ def test_torch_tensor_bridge_sliced(init_cuda): assert smv.dtype == np.dtype(np.int64) +@_torch_skip +def test_torch_tensor_bridge_sliced_2d(init_cuda): + """2D sliced tensor should have correct data_ptr, shape, and strides.""" + base = torch.arange(60, dtype=torch.float32, device="cuda").reshape(6, 10) + a = base[1:4, 2:7] + smv = StridedMemoryView.from_any_interface(a, stream_ptr=0) + assert smv.ptr == a.data_ptr() + assert smv.shape == (3, 5) + assert smv.strides == (10, 1) # element strides + assert smv.dtype == np.dtype(np.float32) + + @_torch_skip def test_torch_tensor_bridge_scalar(init_cuda): a = torch.tensor(42.0, dtype=torch.float32, device="cuda") From 9fad471540779727d7941b29b925131eddfac96e Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Thu, 9 Apr 2026 21:28:33 +0000 Subject: [PATCH 07/17] Revert itemsize to int, memoize int(stream_ptr) - Revert itemsize back to int (size_t was unnecessary for small values) - Memoize int(stream_ptr) to avoid redundant Python operator conversion Co-Authored-By: Emilio Castillo Co-Authored-By: Claude Opus 4.6 (1M context) --- cuda_core/cuda/core/_tensor_bridge.pyx | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/cuda_core/cuda/core/_tensor_bridge.pyx b/cuda_core/cuda/core/_tensor_bridge.pyx index ff15a2c1e4..9accb281d5 100644 --- a/cuda_core/cuda/core/_tensor_bridge.pyx +++ b/cuda_core/cuda/core/_tensor_bridge.pyx @@ -30,7 +30,6 @@ Credit: Emilio Castillo (ecastillo@nvidia.com) – original tensor-bridge POC. """ from libc.stdint cimport intptr_t, int8_t, int16_t, int32_t, int64_t, uint8_t -from libc.stddef cimport size_t from cuda.core._memoryview cimport StridedMemoryView from cuda.core._layout cimport _StridedLayout @@ -176,14 +175,14 @@ cdef dict _build_itemsize_map(): } -cdef size_t _get_aoti_itemsize(int32_t dtype_code) except 0: +cdef int _get_aoti_itemsize(int32_t dtype_code) except -1: global _aoti_itemsize_map if _aoti_itemsize_map is None: _aoti_itemsize_map = _build_itemsize_map() result = _aoti_itemsize_map.get(dtype_code) if result is None: raise TypeError(f"Unsupported AOTI dtype code: {dtype_code}") - return result + return result # --------------------------------------------------------------------------- @@ -244,7 +243,8 @@ def view_as_torch_tensor(object obj, object stream_ptr, view=None): cdef int32_t dtype_code cdef int32_t device_type, device_index cdef StridedMemoryView buf - cdef size_t itemsize + cdef int itemsize + cdef intptr_t _stream_ptr_int cdef _StridedLayout layout check_aoti(aoti_torch_get_data_ptr(handle, &data_ptr), @@ -283,8 +283,10 @@ def view_as_torch_tensor(object obj, object stream_ptr, view=None): buf.is_device_accessible = True # -- stream ordering (matches the DLPack contract) -- - if stream_ptr is not None and int(stream_ptr) != -1: - sync_torch_stream(device_index, (int(stream_ptr))) + if stream_ptr is not None: + _stream_ptr_int = int(stream_ptr) + if _stream_ptr_int != -1: + sync_torch_stream(device_index, _stream_ptr_int) else: raise BufferError( f"Unsupported device type from torch tensor " From cc4558aff568451d00e6f11addf4a3447a68a6c9 Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Thu, 9 Apr 2026 21:32:42 +0000 Subject: [PATCH 08/17] Use except?-1 instead of except* for check_aoti Better Cython 3 performance: except?-1 avoids the overhead of except* which always checks for exceptions. Co-Authored-By: Emilio Castillo Co-Authored-By: Claude Opus 4.6 (1M context) --- cuda_core/cuda/core/_tensor_bridge.pyx | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cuda_core/cuda/core/_tensor_bridge.pyx b/cuda_core/cuda/core/_tensor_bridge.pyx index 9accb281d5..480b77dac5 100644 --- a/cuda_core/cuda/core/_tensor_bridge.pyx +++ b/cuda_core/cuda/core/_tensor_bridge.pyx @@ -107,10 +107,11 @@ cdef inline AtenTensorHandle pyobj_to_aten_handle(object obj): return (obj + sizeof(PyObject)) -cdef inline void check_aoti(AOTITorchError err, const char* name) except *: +cdef inline int check_aoti(AOTITorchError err, const char* name) except? -1: """Raise RuntimeError if an AOTI call returned a non-zero error code.""" if err != 0: raise RuntimeError(f"{name.decode()} failed") + return 0 # --------------------------------------------------------------------------- From 5f49e7a58363b9d9a1c3cc62f4de1fa833113ccf Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Thu, 9 Apr 2026 21:45:05 +0000 Subject: [PATCH 09/17] Require PyTorch >= 2.3 for tensor bridge, move imports to module level The AOTI stable C ABI functions we use (get_dim, get_dtype, get_device_type, get_device_index, get_current_cuda_stream, complex dtype constants) were all introduced in PyTorch 2.3.0. Earlier versions are missing some or all of them. _is_torch_tensor now returns False when torch < 2.3, causing a graceful fallback to the standard DLPack/CAI paths. The version check result is memoized in a module-level variable. Also move `import ctypes, sys` from _get_tensor_bridge to module level. Co-Authored-By: Emilio Castillo Co-Authored-By: Claude Opus 4.6 (1M context) --- cuda_core/cuda/core/_memoryview.pyx | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/cuda_core/cuda/core/_memoryview.pyx b/cuda_core/cuda/core/_memoryview.pyx index c29f7ae8bf..4bf6fa6014 100644 --- a/cuda_core/cuda/core/_memoryview.pyx +++ b/cuda_core/cuda/core/_memoryview.pyx @@ -10,7 +10,9 @@ from libc.stdint cimport intptr_t from cuda.core._layout cimport _StridedLayout, get_strides_ptr from cuda.core._stream import Stream +import ctypes import functools +import sys import warnings import numpy @@ -34,11 +36,31 @@ from cuda.core._memory import Buffer # --------------------------------------------------------------------------- cdef object _tensor_bridge = None +# Tri-state: None = not checked, True/False = result of version check +cdef object _torch_version_ok = None + +cdef inline bint _torch_version_check(): + """Return True if torch >= 2.3 (AOTI ABI requirement). Memoized.""" + global _torch_version_ok + if _torch_version_ok is not None: + return _torch_version_ok + torch = sys.modules.get("torch") + if torch is None: + _torch_version_ok = False + return False + try: + major, minor = int(torch.__version__.split(".")[0]), \ + int(torch.__version__.split(".")[1]) + _torch_version_ok = (major, minor) >= (2, 3) + except (ValueError, IndexError): + _torch_version_ok = False + return _torch_version_ok cdef inline bint _is_torch_tensor(object obj): cdef str mod = type(obj).__module__ or "" - return mod.startswith("torch") and hasattr(obj, "data_ptr") + return mod.startswith("torch") and hasattr(obj, "data_ptr") \ + and _torch_version_check() cdef object _get_tensor_bridge(): @@ -46,7 +68,6 @@ cdef object _get_tensor_bridge(): global _tensor_bridge if _tensor_bridge is not None: return _tensor_bridge - import ctypes, sys torch_C = sys.modules.get("torch._C") if torch_C is None: raise RuntimeError( From b98fe710cbe83ef2119f530c455c87f47502ea64 Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Thu, 9 Apr 2026 21:55:48 +0000 Subject: [PATCH 10/17] Add tensor bridge entry to 1.0.0 release notes Document the AOTI-based fast path for torch.Tensor in StridedMemoryView with ~10-20x speedup and stream ordering support. Co-Authored-By: Emilio Castillo Co-Authored-By: Claude Opus 4.6 (1M context) --- cuda_core/docs/source/release/1.0.0-notes.rst | 35 +++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 cuda_core/docs/source/release/1.0.0-notes.rst diff --git a/cuda_core/docs/source/release/1.0.0-notes.rst b/cuda_core/docs/source/release/1.0.0-notes.rst new file mode 100644 index 0000000000..133eda1272 --- /dev/null +++ b/cuda_core/docs/source/release/1.0.0-notes.rst @@ -0,0 +1,35 @@ +.. SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +.. SPDX-License-Identifier: Apache-2.0 + +.. currentmodule:: cuda.core + +``cuda.core`` 1.0.0 Release Notes +================================= + + +Highlights +---------- + +- TBD + + +New features +------------ + +- TBD + + +Fixes and enhancements +----------------------- + +- :class:`~utils.StridedMemoryView` now provides a fast path for ``torch.Tensor`` + objects via PyTorch's AOT Inductor (AOTI) stable C ABI. When a ``torch.Tensor`` + is passed to any ``from_*`` classmethod (``from_dlpack``, + ``from_cuda_array_interface``, ``from_array_interface``, or + ``from_any_interface``), tensor metadata is read directly from the underlying + C struct, bypassing the DLPack and CUDA Array Interface protocol overhead. + This yields ~10-20x faster ``StridedMemoryView`` construction for PyTorch + tensors. Proper CUDA stream ordering is established between PyTorch's current + stream and the consumer stream, matching the DLPack synchronization contract. + Requires PyTorch >= 2.3. + (`#749 `__) From 30ba7d5ee383e23114ca9e009c4f8ae78bbbe839 Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Thu, 9 Apr 2026 22:19:16 +0000 Subject: [PATCH 11/17] Update speedup range in release notes to match benchmarks Co-Authored-By: Emilio Castillo Co-Authored-By: Claude Opus 4.6 (1M context) --- cuda_core/docs/source/release/1.0.0-notes.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cuda_core/docs/source/release/1.0.0-notes.rst b/cuda_core/docs/source/release/1.0.0-notes.rst index 133eda1272..34eff57100 100644 --- a/cuda_core/docs/source/release/1.0.0-notes.rst +++ b/cuda_core/docs/source/release/1.0.0-notes.rst @@ -28,8 +28,8 @@ Fixes and enhancements ``from_cuda_array_interface``, ``from_array_interface``, or ``from_any_interface``), tensor metadata is read directly from the underlying C struct, bypassing the DLPack and CUDA Array Interface protocol overhead. - This yields ~10-20x faster ``StridedMemoryView`` construction for PyTorch - tensors. Proper CUDA stream ordering is established between PyTorch's current + This yields ~7-20x faster ``StridedMemoryView`` construction for PyTorch + tensors (depending on whether stream ordering is required). Proper CUDA stream ordering is established between PyTorch's current stream and the consumer stream, matching the DLPack synchronization contract. Requires PyTorch >= 2.3. (`#749 `__) From 0f5764603062d7f4b72f1af4c2bde462d1a762d3 Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Thu, 9 Apr 2026 22:25:37 +0000 Subject: [PATCH 12/17] Document THPVariable layout change across PyTorch versions The cdata field changed from MaybeOwned (2.3-2.9) to at::Tensor (2.10+). Both layouts are compatible with our offset trick. Co-Authored-By: Emilio Castillo Co-Authored-By: Claude Opus 4.6 (1M context) --- cuda_core/cuda/core/_tensor_bridge.pyx | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/cuda_core/cuda/core/_tensor_bridge.pyx b/cuda_core/cuda/core/_tensor_bridge.pyx index 480b77dac5..d0ce5ea6b3 100644 --- a/cuda_core/cuda/core/_tensor_bridge.pyx +++ b/cuda_core/cuda/core/_tensor_bridge.pyx @@ -103,7 +103,14 @@ cdef dict _aoti_itemsize_map = None # --------------------------------------------------------------------------- cdef inline AtenTensorHandle pyobj_to_aten_handle(object obj): - """Extract AtenTensorHandle by offsetting past PyObject_HEAD.""" + """Extract AtenTensorHandle by offsetting past PyObject_HEAD. + + In PyTorch 2.3–2.9 the first field after PyObject_HEAD is + ``c10::MaybeOwned cdata``; from 2.10 onward it is + ``at::Tensor cdata``. In both cases the address of ``cdata`` + is usable as an ``AtenTensorHandle`` (``at::Tensor*``) for the + AOTI stable C ABI functions. + """ return (obj + sizeof(PyObject)) From 74798e7be2d460338bd931961fe74bc31b3fd172 Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Thu, 9 Apr 2026 23:35:25 +0000 Subject: [PATCH 13/17] Cache type check in _is_torch_tensor for ~20% speedup Cache the result of the torch tensor type check (module + hasattr + version) keyed by type(obj). Subsequent calls for the same type are a single dict lookup (~76 ns) instead of the full check (~186 ns). Non-torch objects also benefit as the cache returns False immediately after the first miss. Co-Authored-By: Emilio Castillo Co-Authored-By: Claude Opus 4.6 (1M context) --- cuda_core/cuda/core/_memoryview.pyx | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/cuda_core/cuda/core/_memoryview.pyx b/cuda_core/cuda/core/_memoryview.pyx index 4bf6fa6014..bbb8adadc5 100644 --- a/cuda_core/cuda/core/_memoryview.pyx +++ b/cuda_core/cuda/core/_memoryview.pyx @@ -36,6 +36,9 @@ from cuda.core._memory import Buffer # --------------------------------------------------------------------------- cdef object _tensor_bridge = None +# Cache: type(obj) -> True/False for the torch tensor check. +# Once a type is seen, we never re-check. +cdef dict _torch_type_cache = {} # Tri-state: None = not checked, True/False = result of version check cdef object _torch_version_ok = None @@ -58,9 +61,15 @@ cdef inline bint _torch_version_check(): cdef inline bint _is_torch_tensor(object obj): - cdef str mod = type(obj).__module__ or "" - return mod.startswith("torch") and hasattr(obj, "data_ptr") \ + cdef type tp = type(obj) + cdef object cached = _torch_type_cache.get(tp) + if cached is not None: + return cached + cdef str mod = tp.__module__ or "" + cdef bint result = mod.startswith("torch") and hasattr(obj, "data_ptr") \ and _torch_version_check() + _torch_type_cache[tp] = result + return result cdef object _get_tensor_bridge(): From 00b8ec91b12dd4a9fa7062a48ff9c1f323d3edc5 Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Fri, 10 Apr 2026 19:26:14 +0000 Subject: [PATCH 14/17] Add upper bound to torch version check (cap at 2.11) The pyobj_to_aten_handle trick and AtenTensorHandle == at::Tensor* identity are undocumented internals that could change. Cap at the latest tested version so unknown future versions fall back to the standard DLPack/CAI paths. Bump after verifying each new release. Co-Authored-By: Emilio Castillo Co-Authored-By: Claude Opus 4.6 (1M context) --- cuda_core/cuda/core/_memoryview.pyx | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/cuda_core/cuda/core/_memoryview.pyx b/cuda_core/cuda/core/_memoryview.pyx index bbb8adadc5..3ebde8dcff 100644 --- a/cuda_core/cuda/core/_memoryview.pyx +++ b/cuda_core/cuda/core/_memoryview.pyx @@ -43,7 +43,17 @@ cdef dict _torch_type_cache = {} cdef object _torch_version_ok = None cdef inline bint _torch_version_check(): - """Return True if torch >= 2.3 (AOTI ABI requirement). Memoized.""" + """Return True if 2.3 <= torch <= 2.11 (known AOTI ABI range). Memoized. + + Lower bound: AOTI functions we use were introduced in PyTorch 2.3. + Upper bound: the ``pyobj_to_aten_handle`` trick relies on the + THPVariable struct layout (PyObject_HEAD followed by at::Tensor cdata) + and the identity ``AtenTensorHandle == at::Tensor*``. Both are + undocumented internals that could change in a future PyTorch version. + We cap at the latest version we have tested against; unknown versions + fall back to the standard DLPack/CAI paths. Bump the upper bound + after verifying a new PyTorch release. + """ global _torch_version_ok if _torch_version_ok is not None: return _torch_version_ok @@ -54,7 +64,7 @@ cdef inline bint _torch_version_check(): try: major, minor = int(torch.__version__.split(".")[0]), \ int(torch.__version__.split(".")[1]) - _torch_version_ok = (major, minor) >= (2, 3) + _torch_version_ok = (2, 3) <= (major, minor) <= (2, 11) except (ValueError, IndexError): _torch_version_ok = False return _torch_version_ok From 0c31df19dc4adc94c2d4097b90bdf868240def1c Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Fri, 10 Apr 2026 19:27:37 +0000 Subject: [PATCH 15/17] Update module docstring to document both THPVariable layouts Co-Authored-By: Emilio Castillo Co-Authored-By: Claude Opus 4.6 (1M context) --- cuda_core/cuda/core/_tensor_bridge.pyx | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/cuda_core/cuda/core/_tensor_bridge.pyx b/cuda_core/cuda/core/_tensor_bridge.pyx index d0ce5ea6b3..d241b7e592 100644 --- a/cuda_core/cuda/core/_tensor_bridge.pyx +++ b/cuda_core/cuda/core/_tensor_bridge.pyx @@ -12,11 +12,18 @@ The ``pyobj_to_aten_handle`` trick exploits the internal layout of struct THPVariable { PyObject_HEAD - MaybeOwned cdata; // <-- this IS the AtenTensorHandle + at::Tensor cdata; // <-- &cdata is usable as AtenTensorHandle + ... }; -Offsetting past ``PyObject_HEAD`` gives us the ``at::Tensor`` pointer -without any Python attribute access or method calls (~10 ns per tensor). +In PyTorch 2.3–2.9 ``cdata`` was ``c10::MaybeOwned``; +from 2.10 onward it is ``at::Tensor``. In both cases ``&cdata`` +(offset ``sizeof(PyObject)`` from the start of the object) is accepted +by the AOTI stable C ABI functions as an ``AtenTensorHandle``. + +Offsetting past ``PyObject_HEAD`` gives us the handle +without any Python attribute access or method calls (~14 ns for all +7 metadata queries). Credit: Emilio Castillo (ecastillo@nvidia.com) – original tensor-bridge POC. From 8c202370eac25130b4d5588e17a242f5632ed794 Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Fri, 10 Apr 2026 19:34:24 +0000 Subject: [PATCH 16/17] Use except?-1 for sync_torch_stream instead of except* Co-Authored-By: Emilio Castillo Co-Authored-By: Claude Opus 4.6 (1M context) --- cuda_core/cuda/core/_tensor_bridge.pyx | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/cuda_core/cuda/core/_tensor_bridge.pyx b/cuda_core/cuda/core/_tensor_bridge.pyx index d241b7e592..5cb4ea10c0 100644 --- a/cuda_core/cuda/core/_tensor_bridge.pyx +++ b/cuda_core/cuda/core/_tensor_bridge.pyx @@ -204,8 +204,8 @@ cdef int _get_aoti_itemsize(int32_t dtype_code) except -1: # Stream ordering helper # --------------------------------------------------------------------------- -cpdef void sync_torch_stream(int32_t device_index, - intptr_t consumer_s) except *: +cpdef int sync_torch_stream(int32_t device_index, + intptr_t consumer_s) except? -1: """Establish stream ordering between PyTorch's current CUDA stream and the given consumer stream. @@ -226,6 +226,7 @@ cpdef void sync_torch_stream(int32_t device_index, as_cu(h_event), producer_s)) HANDLE_RETURN(cydriver.cuStreamWaitEvent( consumer_s, as_cu(h_event), 0)) + return 0 # --------------------------------------------------------------------------- From 8c019b95ccc7317292bb39fdcb504a5feb4c7478 Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Fri, 10 Apr 2026 20:10:16 +0000 Subject: [PATCH 17/17] Fix linter errors Co-Authored-By: Emilio Castillo Co-Authored-By: Claude Opus 4.6 (1M context) --- cuda_core/cuda/core/_tensor_bridge.pyx | 24 ++++++++++++------------ cuda_core/tests/test_utils.py | 1 + 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/cuda_core/cuda/core/_tensor_bridge.pyx b/cuda_core/cuda/core/_tensor_bridge.pyx index 5cb4ea10c0..93c4aa47a8 100644 --- a/cuda_core/cuda/core/_tensor_bridge.pyx +++ b/cuda_core/cuda/core/_tensor_bridge.pyx @@ -134,26 +134,26 @@ cdef inline int check_aoti(AOTITorchError err, const char* name) except? -1: cdef dict _build_dtype_map(): try: - from ml_dtypes import bfloat16 as _bf16 + from ml_dtypes import bfloat16 as _bf16 # noqa: F811 has_bfloat16 = True except ImportError: has_bfloat16 = False cdef dict m = { - aoti_torch_dtype_float16(): numpy.dtype(numpy.float16), - aoti_torch_dtype_float32(): numpy.dtype(numpy.float32), - aoti_torch_dtype_float64(): numpy.dtype(numpy.float64), - aoti_torch_dtype_uint8(): numpy.dtype(numpy.uint8), - aoti_torch_dtype_int8(): numpy.dtype(numpy.int8), - aoti_torch_dtype_int16(): numpy.dtype(numpy.int16), - aoti_torch_dtype_int32(): numpy.dtype(numpy.int32), - aoti_torch_dtype_int64(): numpy.dtype(numpy.int64), - aoti_torch_dtype_bool(): numpy.dtype(numpy.bool_), - aoti_torch_dtype_complex64(): numpy.dtype(numpy.complex64), + aoti_torch_dtype_float16(): numpy.dtype(numpy.float16), + aoti_torch_dtype_float32(): numpy.dtype(numpy.float32), + aoti_torch_dtype_float64(): numpy.dtype(numpy.float64), + aoti_torch_dtype_uint8(): numpy.dtype(numpy.uint8), + aoti_torch_dtype_int8(): numpy.dtype(numpy.int8), + aoti_torch_dtype_int16(): numpy.dtype(numpy.int16), + aoti_torch_dtype_int32(): numpy.dtype(numpy.int32), + aoti_torch_dtype_int64(): numpy.dtype(numpy.int64), + aoti_torch_dtype_bool(): numpy.dtype(numpy.bool_), + aoti_torch_dtype_complex64(): numpy.dtype(numpy.complex64), aoti_torch_dtype_complex128(): numpy.dtype(numpy.complex128), } if has_bfloat16: - m[aoti_torch_dtype_bfloat16()] = numpy.dtype("bfloat16") + m[aoti_torch_dtype_bfloat16()] = numpy.dtype(_bf16) return m diff --git a/cuda_core/tests/test_utils.py b/cuda_core/tests/test_utils.py index c868033268..8874fe1e0a 100644 --- a/cuda_core/tests/test_utils.py +++ b/cuda_core/tests/test_utils.py @@ -845,6 +845,7 @@ def test_torch_tensor_bridge_cpu(init_cuda): @_torch_skip def test_torch_tensor_bridge_decorator(init_cuda): """Verify tensor bridge works through the args_viewable_as_strided_memory decorator.""" + @args_viewable_as_strided_memory((0,)) def fn(tensor, stream): return tensor.view(stream.handle)