|
1 | | -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
2 | 2 | # |
3 | 3 | # SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE |
4 | 4 |
|
5 | 5 | import math |
6 | 6 |
|
| 7 | +# TODO: replace optional imports with pytest.importorskip |
7 | 8 | try: |
8 | 9 | import cupy as cp |
9 | 10 | except ImportError: |
|
12 | 13 | from numba import cuda as numba_cuda |
13 | 14 | except ImportError: |
14 | 15 | numba_cuda = None |
| 16 | +try: |
| 17 | + import torch |
| 18 | +except ImportError: |
| 19 | + torch = None |
15 | 20 | import cuda.core |
| 21 | +import ml_dtypes |
16 | 22 | import numpy as np |
17 | 23 | import pytest |
18 | 24 | from cuda.core import Device |
|
21 | 27 | from pytest import param |
22 | 28 |
|
23 | 29 |
|
def _get_cupy_version_major() -> int | None:
    """Return CuPy's major version number, or ``None`` when CuPy is not installed."""
    if cp is None:
        return None
    major, _sep, _rest = cp.__version__.partition(".")
    return int(major)
| 34 | + |
| 35 | + |
24 | 36 | def test_cast_to_3_tuple_success(): |
25 | 37 | c3t = cuda.core._utils.cuda_utils.cast_to_3_tuple |
26 | 38 | assert c3t("", ()) == (1, 1, 1) |
@@ -524,3 +536,85 @@ def test_from_array_interface_unsupported_strides(init_cuda): |
524 | 536 | with pytest.raises(ValueError, match="strides must be divisible by itemsize"): |
525 | 537 | # TODO: ideally this would raise on construction |
526 | 538 | smv.strides # noqa: B018 |
| 539 | + |
| 540 | + |
@pytest.mark.parametrize(
    "slices",
    [
        param((slice(None), slice(None)), id="contiguous"),
        param((slice(None, None, 2), slice(1, None, 2)), id="strided"),
    ],
)
@pytest.mark.skipif(cp is None, reason="CuPy is not installed")
@pytest.mark.skipif(cp is not None and _get_cupy_version_major() < 14, reason="CuPy version is less than 14.0.0")
def test_ml_dtypes_bfloat16_dlpack(init_cuda, slices):
    """Export a (possibly strided) bfloat16 CuPy array via DLPack and verify the view metadata."""
    src = cp.array([1, 2, 3, 4, 5, 6], dtype=ml_dtypes.bfloat16).reshape(2, 3)[slices]
    view = StridedMemoryView.from_dlpack(src, stream_ptr=0)

    # The view must mirror the exporting array exactly.
    assert view.exporting_obj is src
    assert view.size == src.size
    assert view.shape == src.shape
    assert view.ptr == src.data.ptr
    # bfloat16 is spelled both by name and via the ml_dtypes scalar type.
    assert view.dtype == np.dtype("bfloat16")
    assert view.dtype == np.dtype(ml_dtypes.bfloat16)
    assert view.device_id == init_cuda.device_id
    assert view.is_device_accessible is True
    assert view.readonly is src.__cuda_array_interface__["data"][1]

    # Strides come back in element counts; a C-contiguous exporter may report None instead.
    expected_strides = convert_strides_to_counts(src.strides, src.dtype.itemsize)
    if not src.flags["C_CONTIGUOUS"]:
        assert view.strides == expected_strides
    else:
        assert view.strides in (None, expected_strides)
| 569 | + |
| 570 | + |
@pytest.mark.parametrize(
    "slices",
    [
        param((slice(None), slice(None)), id="contiguous"),
        param((slice(None, None, 2), slice(1, None, 2)), id="strided"),
    ],
)
@pytest.mark.skipif(torch is None, reason="PyTorch is not installed")
def test_ml_dtypes_bfloat16_torch_dlpack(init_cuda, slices):
    """Export a (possibly strided) bfloat16 torch tensor via DLPack and verify the view metadata."""
    tensor = torch.tensor([1, 2, 3, 4, 5, 6], dtype=torch.bfloat16, device="cuda").reshape(2, 3)[slices]
    view = StridedMemoryView.from_dlpack(tensor, stream_ptr=0)

    # The view must mirror the exporting tensor exactly.
    assert view.exporting_obj is tensor
    assert view.size == tensor.numel()
    assert view.shape == tuple(tensor.shape)
    assert view.ptr == tensor.data_ptr()
    # bfloat16 is spelled both by name and via the ml_dtypes scalar type.
    assert view.dtype == np.dtype("bfloat16")
    assert view.dtype == np.dtype(ml_dtypes.bfloat16)
    assert view.device_id == init_cuda.device_id
    assert view.is_device_accessible is True

    # torch.Tensor.stride() is in elements; scale to bytes before converting to counts.
    byte_strides = tuple(stride * tensor.element_size() for stride in tensor.stride())
    expected_strides = convert_strides_to_counts(byte_strides, tensor.element_size())
    if not tensor.is_contiguous():
        assert view.strides == expected_strides
    else:
        assert view.strides in (None, expected_strides)
| 599 | + |
| 600 | + |
@pytest.fixture
def no_ml_dtypes(monkeypatch):
    """Simulate an environment where ml_dtypes is unavailable.

    Patches the module-level ``bfloat16`` symbol in ``cuda.core._memoryview``
    to ``None``; ``monkeypatch`` restores it automatically at teardown, so no
    ``yield``/cleanup code is needed (a trailing bare ``yield`` with no
    teardown is a useless-yield fixture, ruff PT022).
    """
    monkeypatch.setattr("cuda.core._memoryview.bfloat16", None)
| 605 | + |
| 606 | + |
@pytest.mark.parametrize(
    "api",
    [
        param(StridedMemoryView.from_dlpack, id="from_dlpack"),
        param(StridedMemoryView.from_any_interface, id="from_any_interface"),
    ],
)
@pytest.mark.skipif(cp is None, reason="CuPy is not installed")
@pytest.mark.skipif(cp is not None and _get_cupy_version_major() < 14, reason="CuPy version is less than 14.0.0")
def test_ml_dtypes_bfloat16_dlpack_requires_ml_dtypes(init_cuda, no_ml_dtypes, api):
    """With ml_dtypes patched out, resolving a bfloat16 view's dtype must raise."""
    source = cp.array([1, 2, 3], dtype="bfloat16")
    view = api(source, stream_ptr=0)
    # Constructing the view succeeds; only the dtype lookup needs ml_dtypes.
    with pytest.raises(NotImplementedError, match=r"requires `ml_dtypes`"):
        view.dtype  # noqa: B018
0 commit comments