Skip to content

Commit 5985ee1

Browse files
authored
Allow StridedMemoryView to be constructed from dlpacks type code (#1623)
* Add ml_dtypes.bfloat16 dlpack code parsing * add pytorch tensor tests * skip if cupy major version is less than 14 * add documentation * move ml_dtypes dependency to test-cuxx * trim doc for brevity and fix display error * skip test when cupy is not present; install ml-dtypes to ft test environment --------- Co-authored-by: Michael Wang <isVoid@users.noreply.github.com>
1 parent 9f4c750 commit 5985ee1

4 files changed

Lines changed: 130 additions & 9 deletions

File tree

cuda_core/cuda/core/_memoryview.pyx

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

@@ -27,6 +27,12 @@ from cuda.core._utils.cuda_utils cimport HANDLE_RETURN
2727

2828
from cuda.core._memory import Buffer
2929

30+
31+
try:
32+
from ml_dtypes import bfloat16
33+
except ImportError:
34+
bfloat16 = None
35+
3036
# TODO(leofang): support NumPy structured dtypes
3137

3238

@@ -332,6 +338,11 @@ cdef class StridedMemoryView:
332338
def dtype(self) -> numpy.dtype | None:
333339
"""
334340
Data type of the tensor.
341+
342+
Supports standard NumPy dtypes as well as narrow data types (e.g., ``bfloat16``)
343+
when the optional `ml_dtypes <https://github.com/jax-ml/ml_dtypes>`_ package is
344+
installed. If ``ml_dtypes`` is not available and such a tensor is encountered,
345+
a :obj:`NotImplementedError` will be raised.
335346
"""
336347
return self.get_dtype()
337348

@@ -555,8 +566,13 @@ cdef object dtype_dlpack_to_numpy(DLDataType* dtype):
555566
else:
556567
raise TypeError(f'{bits}-bit bool is not supported')
557568
elif dtype.code == kDLBfloat:
558-
# TODO(leofang): use ml_dtype.bfloat16?
559-
raise NotImplementedError('bfloat is not supported yet')
569+
if bfloat16 is not None:
570+
np_dtype = numpy.dtype("bfloat16")
571+
else:
572+
raise NotImplementedError(
573+
'Support for bfloat16 within cuda-core requires `ml_dtypes` '
574+
'to be installed.'
575+
)
560576
else:
561577
raise TypeError('Unsupported dtype. dtype code: {}'.format(dtype.code))
562578

cuda_core/docs/source/interoperability.rst

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
.. SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
.. SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
.. SPDX-License-Identifier: Apache-2.0
33
44
.. currentmodule:: cuda.core
@@ -79,6 +79,16 @@ array libraries.
7979
The :attr:`~utils.StridedMemoryView.is_device_accessible` attribute can be used to check
8080
whether or not the underlying buffer can be accessed on GPU.
8181

82+
The :class:`~utils.StridedMemoryView` class supports narrow data types (e.g., ``bfloat16``) when the optional
83+
`ml_dtypes <https://github.com/jax-ml/ml_dtypes>`_ package is installed. This enables interoperability with libraries that use
84+
narrow dtype tensors, such as PyTorch with ``torch.bfloat16`` or CuPy with ``"bfloat16"`` dtype.
85+
If ``ml_dtypes`` is not available and such a tensor is encountered, a
86+
:obj:`NotImplementedError` will be raised.
87+
88+
Currently supported narrow data types:
89+
90+
* ``bfloat16``
91+
8292
.. rubric:: Footnotes
8393

8494
.. [1] https://numba.readthedocs.io/en/stable/cuda/cuda_array_interface.html

cuda_core/pyproject.toml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,13 @@ cu13 = ["cuda-bindings[all]==13.*"]
5656

5757
[dependency-groups]
5858
test = ["cython>=3.2,<3.3", "setuptools", "pytest>=6.2.4", "pytest-randomly", "pytest-repeat"]
59-
test-cu12 = ["cuda-core[test]", "cupy-cuda12x; python_version < '3.14'", "cuda-toolkit[cudart]==12.*"] # runtime headers needed by CuPy
60-
test-cu13 = ["cuda-core[test]", "cupy-cuda13x; python_version < '3.14'", "cuda-toolkit[cudart]==13.*"] # runtime headers needed by CuPy
59+
ml-dtypes = ["ml-dtypes>=0.5.4,<0.6.0"]
60+
test-cu12 = [ {include-group = "ml-dtypes" }, "cuda-core[test]", "cupy-cuda12x; python_version < '3.14'", "cuda-toolkit[cudart]==12.*"] # runtime headers needed by CuPy
61+
test-cu13 = [ {include-group = "ml-dtypes" }, "cuda-core[test]", "cupy-cuda13x; python_version < '3.14'", "cuda-toolkit[cudart]==13.*"] # runtime headers needed by CuPy
6162
# free threaded build, cupy doesn't support free-threaded builds yet, so avoid installing it for now
6263
# TODO: cupy should support free threaded builds
63-
test-cu12-ft = ["cuda-core[test]", "cuda-toolkit[cudart]==12.*"]
64-
test-cu13-ft = ["cuda-core[test]", "cuda-toolkit[cudart]==13.*"]
64+
test-cu12-ft = [ {include-group = "ml-dtypes" }, "cuda-core[test]", "cuda-toolkit[cudart]==12.*"]
65+
test-cu13-ft = [ {include-group = "ml-dtypes" }, "cuda-core[test]", "cuda-toolkit[cudart]==13.*"]
6566

6667
[project.urls]
6768
homepage = "https://nvidia.github.io/cuda-python/"

cuda_core/tests/test_utils.py

Lines changed: 95 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
#
33
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
44

55
import math
66

7+
# TODO: replace optional imports with pytest.importorskip
78
try:
89
import cupy as cp
910
except ImportError:
@@ -12,7 +13,12 @@
1213
from numba import cuda as numba_cuda
1314
except ImportError:
1415
numba_cuda = None
16+
try:
17+
import torch
18+
except ImportError:
19+
torch = None
1520
import cuda.core
21+
import ml_dtypes
1622
import numpy as np
1723
import pytest
1824
from cuda.core import Device
@@ -21,6 +27,12 @@
2127
from pytest import param
2228

2329

30+
def _get_cupy_version_major() -> int | None:
31+
if cp is None:
32+
return None
33+
return int(cp.__version__.split(".")[0])
34+
35+
2436
def test_cast_to_3_tuple_success():
2537
c3t = cuda.core._utils.cuda_utils.cast_to_3_tuple
2638
assert c3t("", ()) == (1, 1, 1)
@@ -524,3 +536,85 @@ def test_from_array_interface_unsupported_strides(init_cuda):
524536
with pytest.raises(ValueError, match="strides must be divisible by itemsize"):
525537
# TODO: ideally this would raise on construction
526538
smv.strides # noqa: B018
539+
540+
541+
@pytest.mark.parametrize(
542+
"slices",
543+
[
544+
param((slice(None), slice(None)), id="contiguous"),
545+
param((slice(None, None, 2), slice(1, None, 2)), id="strided"),
546+
],
547+
)
548+
@pytest.mark.skipif(cp is None, reason="CuPy is not installed")
549+
@pytest.mark.skipif(cp is not None and _get_cupy_version_major() < 14, reason="CuPy version is less than 14.0.0")
550+
def test_ml_dtypes_bfloat16_dlpack(init_cuda, slices):
551+
a = cp.array([1, 2, 3, 4, 5, 6], dtype=ml_dtypes.bfloat16).reshape(2, 3)[slices]
552+
smv = StridedMemoryView.from_dlpack(a, stream_ptr=0)
553+
554+
assert smv.size == a.size
555+
assert smv.dtype == np.dtype("bfloat16")
556+
assert smv.dtype == np.dtype(ml_dtypes.bfloat16)
557+
assert smv.shape == a.shape
558+
assert smv.ptr == a.data.ptr
559+
assert smv.device_id == init_cuda.device_id
560+
assert smv.is_device_accessible is True
561+
assert smv.exporting_obj is a
562+
assert smv.readonly is a.__cuda_array_interface__["data"][1]
563+
564+
strides_in_counts = convert_strides_to_counts(a.strides, a.dtype.itemsize)
565+
if a.flags["C_CONTIGUOUS"]:
566+
assert smv.strides in (None, strides_in_counts)
567+
else:
568+
assert smv.strides == strides_in_counts
569+
570+
571+
@pytest.mark.parametrize(
572+
"slices",
573+
[
574+
param((slice(None), slice(None)), id="contiguous"),
575+
param((slice(None, None, 2), slice(1, None, 2)), id="strided"),
576+
],
577+
)
578+
@pytest.mark.skipif(torch is None, reason="PyTorch is not installed")
579+
def test_ml_dtypes_bfloat16_torch_dlpack(init_cuda, slices):
580+
a = torch.tensor([1, 2, 3, 4, 5, 6], dtype=torch.bfloat16, device="cuda").reshape(2, 3)[slices]
581+
smv = StridedMemoryView.from_dlpack(a, stream_ptr=0)
582+
583+
assert smv.size == a.numel()
584+
assert smv.dtype == np.dtype("bfloat16")
585+
assert smv.dtype == np.dtype(ml_dtypes.bfloat16)
586+
assert smv.shape == tuple(a.shape)
587+
assert smv.ptr == a.data_ptr()
588+
assert smv.device_id == init_cuda.device_id
589+
assert smv.is_device_accessible is True
590+
assert smv.exporting_obj is a
591+
592+
# PyTorch stride() returns strides in elements, convert to bytes first
593+
strides_in_bytes = tuple(s * a.element_size() for s in a.stride())
594+
strides_in_counts = convert_strides_to_counts(strides_in_bytes, a.element_size())
595+
if a.is_contiguous():
596+
assert smv.strides in (None, strides_in_counts)
597+
else:
598+
assert smv.strides == strides_in_counts
599+
600+
601+
@pytest.fixture
602+
def no_ml_dtypes(monkeypatch):
603+
monkeypatch.setattr("cuda.core._memoryview.bfloat16", None)
604+
yield
605+
606+
607+
@pytest.mark.parametrize(
608+
"api",
609+
[
610+
param(StridedMemoryView.from_dlpack, id="from_dlpack"),
611+
param(StridedMemoryView.from_any_interface, id="from_any_interface"),
612+
],
613+
)
614+
@pytest.mark.skipif(cp is None, reason="CuPy is not installed")
615+
@pytest.mark.skipif(cp is not None and _get_cupy_version_major() < 14, reason="CuPy version is less than 14.0.0")
616+
def test_ml_dtypes_bfloat16_dlpack_requires_ml_dtypes(init_cuda, no_ml_dtypes, api):
617+
a = cp.array([1, 2, 3], dtype="bfloat16")
618+
smv = api(a, stream_ptr=0)
619+
with pytest.raises(NotImplementedError, match=r"requires `ml_dtypes`"):
620+
smv.dtype # noqa: B018

0 commit comments

Comments
 (0)