Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 83 additions & 0 deletions cuda_core/cuda/core/_include/aoti_shim.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES.
 * All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Vendored subset of PyTorch's AOT Inductor (AOTI) stable C ABI.
 * Original: torch/csrc/inductor/aoti_torch/c/shim.h
 *
 * This header carries declarations only -- there are no definitions.
 * The actual symbols live in libtorch (loaded through torch._C with
 * RTLD_GLOBAL) and are bound by the dynamic linker at runtime, so
 * compiling against this header does not require PyTorch to be present
 * at build time.
 */

#ifndef CUDA_CORE_AOTI_SHIM_H
#define CUDA_CORE_AOTI_SHIM_H

#include <stdint.h>

#ifdef __cplusplus
extern "C" {
#endif

/* Status code returned by the aoti_torch_* accessor functions. */
typedef int32_t AOTITorchError;

/* Opaque tensor handle; corresponds to at::Tensor on the C++ side. */
struct AtenTensorOpaque;
typedef struct AtenTensorOpaque* AtenTensorHandle;

/* ---- tensor metadata --------------------------------------------------- */

AOTITorchError aoti_torch_get_data_ptr(AtenTensorHandle tensor, void** ret_data_ptr);
AOTITorchError aoti_torch_get_dim(AtenTensorHandle tensor, int64_t* ret_dim);
AOTITorchError aoti_torch_get_sizes(AtenTensorHandle tensor, int64_t** ret_sizes);
AOTITorchError aoti_torch_get_strides(AtenTensorHandle tensor, int64_t** ret_strides);

/* ---- dtype ------------------------------------------------------------- */

AOTITorchError aoti_torch_get_dtype(AtenTensorHandle tensor, int32_t* ret_dtype);

/* Each function returns the integer dtype code for the named type. */
int32_t aoti_torch_dtype_float16(void);
int32_t aoti_torch_dtype_float32(void);
int32_t aoti_torch_dtype_float64(void);
int32_t aoti_torch_dtype_bfloat16(void);
int32_t aoti_torch_dtype_uint8(void);
int32_t aoti_torch_dtype_int8(void);
int32_t aoti_torch_dtype_int16(void);
int32_t aoti_torch_dtype_int32(void);
int32_t aoti_torch_dtype_int64(void);
int32_t aoti_torch_dtype_bool(void);
int32_t aoti_torch_dtype_complex32(void);
int32_t aoti_torch_dtype_complex64(void);
int32_t aoti_torch_dtype_complex128(void);

/* ---- device ------------------------------------------------------------ */

AOTITorchError aoti_torch_get_device_type(AtenTensorHandle tensor, int32_t* ret_device_type);
AOTITorchError aoti_torch_get_device_index(AtenTensorHandle tensor, int32_t* ret_device_index);

/* Integer device-type codes for CPU and CUDA. */
int32_t aoti_torch_device_type_cpu(void);
int32_t aoti_torch_device_type_cuda(void);

/* ---- stream ------------------------------------------------------------ */

AOTITorchError aoti_torch_get_current_cuda_stream(int32_t device_index, void** ret_stream);

#ifdef __cplusplus
} /* extern "C" */
#endif

#endif /* CUDA_CORE_AOTI_SHIM_H */
94 changes: 94 additions & 0 deletions cuda_core/cuda/core/_memoryview.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ from libc.stdint cimport intptr_t
from cuda.core._layout cimport _StridedLayout, get_strides_ptr
from cuda.core._stream import Stream

import ctypes
import functools
import sys
import warnings

import numpy
Expand All @@ -29,6 +31,73 @@ from cuda.core._utils.cuda_utils cimport HANDLE_RETURN
from cuda.core._memory import Buffer


# ---------------------------------------------------------------------------
# Lazy tensor bridge (avoids loading _tensor_bridge.so until torch is used)
# ---------------------------------------------------------------------------

# Module handle for cuda.core._tensor_bridge; None until the first call to
# _get_tensor_bridge() populates it.
cdef object _tensor_bridge = None
# Cache: type(obj) -> True/False for the torch tensor check.
# Once a type is seen, we never re-check.
cdef dict _torch_type_cache = {}
# Tri-state: None = not checked, True/False = result of version check
cdef object _torch_version_ok = None

cdef inline bint _torch_version_check():
    """Return True if 2.3 <= torch <= 2.11 (known AOTI ABI range). Memoized.

    Lower bound: AOTI functions we use were introduced in PyTorch 2.3.
    Upper bound: the ``pyobj_to_aten_handle`` trick relies on the
    THPVariable struct layout (PyObject_HEAD followed by at::Tensor cdata)
    and the identity ``AtenTensorHandle == at::Tensor*``. Both are
    undocumented internals that could change in a future PyTorch version.
    We cap at the latest version we have tested against; unknown versions
    fall back to the standard DLPack/CAI paths. Bump the upper bound
    after verifying a new PyTorch release.
    """
    global _torch_version_ok
    if _torch_version_ok is not None:
        return <bint>_torch_version_ok
    torch = sys.modules.get("torch")
    if torch is None:
        # Do NOT memoize this case: torch may be imported later, and caching
        # False here would permanently disable the fast path for the rest of
        # the process.
        return False
    try:
        # Parse the version string once.  Only the first two components are
        # converted, so suffixes on the patch part (e.g. "2.6.0+cu124",
        # "2.7.0a0") are tolerated.
        parts = str(torch.__version__).split(".")
        major, minor = int(parts[0]), int(parts[1])
        _torch_version_ok = (2, 3) <= (major, minor) <= (2, 11)
    except (ValueError, IndexError):
        # Unparseable version string -> be conservative and use the standard
        # DLPack/CAI paths instead of the AOTI fast path.
        _torch_version_ok = False
    return <bint>_torch_version_ok


cdef inline bint _is_torch_tensor(object obj):
    """Return True if *obj* looks like a torch tensor eligible for the
    AOTI fast path.

    The verdict is cached per concrete type in ``_torch_type_cache`` so
    the hot path costs a single dict lookup.
    """
    cdef type tp = type(obj)
    cdef object cached = _torch_type_cache.get(tp)
    if cached is not None:
        return <bint>cached
    cdef str mod = tp.__module__ or ""
    # Match "torch" and "torch.*" exactly.  A bare startswith("torch")
    # would also accept unrelated packages whose names merely begin with
    # "torch" (torchvision, torchaudio, ...).
    cdef bint result = (mod == "torch" or mod.startswith("torch.")) \
        and hasattr(obj, "data_ptr") and _torch_version_check()
    _torch_type_cache[tp] = result
    return result


cdef object _get_tensor_bridge():
    """Bootstrap the AOTI symbols, then import _tensor_bridge on first use.

    Raises
    ------
    RuntimeError
        If ``torch._C`` has not been loaded yet (PyTorch must be imported
        by the caller first).
    """
    global _tensor_bridge
    if _tensor_bridge is None:
        torch_ext = sys.modules.get("torch._C")
        if torch_ext is None:
            raise RuntimeError(
                "torch._C is not loaded; cannot initialise the tensor bridge. "
                "Make sure PyTorch is imported before passing a torch.Tensor.")
        # Re-open torch's extension module with RTLD_GLOBAL so that the
        # aoti_torch_* symbols resolve when _tensor_bridge is loaded.
        ctypes.CDLL(torch_ext.__file__, mode=ctypes.RTLD_GLOBAL)
        from cuda.core import _tensor_bridge as tb
        _tensor_bridge = tb
    return _tensor_bridge


try:
from ml_dtypes import bfloat16
except ImportError:
Expand Down Expand Up @@ -150,6 +219,9 @@ cdef class StridedMemoryView:
Stream pointer for synchronization. If ``None``, no synchronization is performed.
"""
cdef StridedMemoryView buf = StridedMemoryView.__new__(cls)
if _is_torch_tensor(obj):
_get_tensor_bridge().view_as_torch_tensor(obj, stream_ptr, buf)
return buf
view_as_dlpack(obj, stream_ptr, buf)
return buf

Expand All @@ -165,6 +237,9 @@ cdef class StridedMemoryView:
Stream pointer for synchronization. If ``None``, no synchronization is performed.
"""
cdef StridedMemoryView buf = StridedMemoryView.__new__(cls)
if _is_torch_tensor(obj):
_get_tensor_bridge().view_as_torch_tensor(obj, stream_ptr, buf)
return buf
view_as_cai(obj, stream_ptr, buf)
return buf

Expand All @@ -178,6 +253,9 @@ cdef class StridedMemoryView:
An object implementing the `__array_interface__ <https://numpy.org/doc/stable/reference/arrays.interface.html>`_ protocol (e.g., a numpy array).
"""
cdef StridedMemoryView buf = StridedMemoryView.__new__(cls)
if _is_torch_tensor(obj):
_get_tensor_bridge().view_as_torch_tensor(obj, None, buf)
return buf
view_as_array_interface(obj, buf)
return buf

Expand All @@ -187,6 +265,8 @@ cdef class StridedMemoryView:

Tries `DLPack <https://dmlc.github.io/dlpack/latest/>`_ first, then falls back to
`__cuda_array_interface__ <https://numba.readthedocs.io/en/stable/cuda/cuda_array_interface.html>`_.
``torch.Tensor`` objects are transparently handled via a fast AOTI path
regardless of which protocol is selected.

Parameters
----------
Expand Down Expand Up @@ -480,6 +560,10 @@ cdef class StridedMemoryView:
if self._dtype is None:
if self.dl_tensor != NULL:
self._dtype = dtype_dlpack_to_numpy(&self.dl_tensor.dtype)
elif isinstance(self.metadata, int):
# AOTI dtype code stored by the torch tensor bridge
self._dtype = _get_tensor_bridge().resolve_aoti_dtype(
self.metadata)
elif self.metadata is not None:
self._dtype = _typestr2dtype(self.metadata["typestr"])
return self._dtype
Expand Down Expand Up @@ -1122,6 +1206,16 @@ cpdef StridedMemoryView view_as_cai(obj, stream_ptr, view=None):
as_cu(h_event), <cydriver.CUstream>producer_s))
HANDLE_RETURN(cydriver.cuStreamWaitEvent(
<cydriver.CUstream>consumer_s, as_cu(h_event), 0))
elif _is_torch_tensor(obj):
# PyTorch's __cuda_array_interface__ reports version 2 and
# omits the "stream" field, so the standard CAI sync path
# above is a no-op for torch tensors. This is unsafe: the
# consumer has no guarantee that the producer's work is
# visible. We fix this by querying PyTorch's current CUDA
# stream via the AOTI stable C ABI and performing the same
# event-based stream ordering.
_get_tensor_bridge().sync_torch_stream(
buf.device_id, <intptr_t>(stream_ptr))

return buf

Expand Down
Loading