-
Notifications
You must be signed in to change notification settings - Fork 275
Expand file tree
/
Copy patharch_check.py
More file actions
68 lines (56 loc) · 2.46 KB
/
arch_check.py
File metadata and controls
68 lines (56 loc) · 2.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
from contextlib import contextmanager
from functools import cache
import pytest
from cuda.bindings import nvml
@cache
def hardware_supports_nvml():
    """
    Report whether the most basic NVML functionality works on this machine.

    NVML is initialized, a single trivial query
    (``nvml.system_get_driver_branch``) is issued, and NVML is shut down
    again. If that query raises ``nvml.NotSupportedError`` or
    ``nvml.UnknownError``, we are probably on a platform where NVML is not
    supported at all (e.g. Jetson Orin), and ``False`` is returned. Otherwise
    ``True`` is returned.

    Any exception raised by ``nvml.init_v2`` itself, or any other exception
    from the query, propagates to the caller. Because of ``@cache``, a
    successfully computed result is reused on later calls.
    """
    nvml.init_v2()
    supported = True
    try:
        nvml.system_get_driver_branch()
    except (nvml.NotSupportedError, nvml.UnknownError):
        # Even the simplest query is rejected: treat NVML as unavailable here.
        supported = False
    finally:
        # Balance the init_v2() above regardless of how the probe ended.
        nvml.shutdown()
    return supported
@contextmanager
def unsupported_before(device: int, expected_device_arch: nvml.DeviceArch | str | None):
    """
    Context manager that checks an NVML call against the architecture it is
    documented to require.

    The architecture of ``device`` (from ``nvml.device_get_architecture``) is
    compared with ``expected_device_arch``, giving one of three outcomes:

    * ``expected_device_arch`` is ``None`` or ``"HAS_INFOROM"``, or the
      device architecture is ``nvml.DeviceArch.UNKNOWN``: support cannot be
      decided up front. The wrapped code may succeed. If it raises
      ``nvml.NotSupportedError``, the test is skipped rather than failed.
    * The device architecture is older than ``expected_device_arch``: the
      wrapped code must raise ``nvml.NotSupportedError`` (the test fails if
      it does not), and the rest of the test is then skipped.
    * Otherwise the call is expected to work, and any exception it raises
      propagates and fails the test.

    Parameters
    ----------
    device:
        Device handle passed to ``nvml.device_get_architecture`` and
        ``nvml.device_get_name``.
    expected_device_arch:
        Oldest architecture on which the call is expected to be supported.
        A ``nvml.DeviceArch`` member is compared by its integer value. The
        string ``"FERMI"`` is treated as architecture value 1 (presumably
        because ``DeviceArch`` has no Fermi member; TODO confirm). ``None``
        and ``"HAS_INFOROM"`` mean "may or may not be supported". Any other
        string maps to value 0, i.e. the call is expected to work everywhere.
    """
    device_arch = nvml.device_get_architecture(device)

    if isinstance(expected_device_arch, nvml.DeviceArch):
        expected_device_arch_int = int(expected_device_arch)
    elif expected_device_arch == "FERMI":
        expected_device_arch_int = 1
    else:
        expected_device_arch_int = 0

    if (
        expected_device_arch is None
        or expected_device_arch == "HAS_INFOROM"
        or device_arch == nvml.DeviceArch.UNKNOWN
    ):
        # In this case, we don't /know/ if it will fail, but we are ok if it
        # does or does not.
        # TODO: There are APIs that are documented as supported only if the
        # device has an InfoROM, but I couldn't find a way to detect that. For
        # now, they are just handled as "possibly failing".
        try:
            yield
        except nvml.NotSupportedError:
            # The API call raised NotSupportedError, so we skip the test, but
            # don't fail it.
            pytest.skip(
                f"Unsupported call for device architecture {nvml.DeviceArch(device_arch).name} "
                f"on device '{nvml.device_get_name(device)}'"
            )
        # If the API call worked, just continue.
    elif int(device_arch) < expected_device_arch_int:
        # In this case, we /know/ it will fail, and we want to assert that it does.
        with pytest.raises(nvml.NotSupportedError):
            yield
        # String values such as "FERMI" have no ``.name`` attribute, so fall
        # back to the value itself instead of raising AttributeError.
        expected_device_arch_name = getattr(expected_device_arch, "name", expected_device_arch)
        # The above call was unsupported, so the rest of the test is skipped.
        pytest.skip(
            f"Unsupported before {expected_device_arch_name}, "
            f"got {nvml.device_get_name(device)}"
        )
    else:
        # In this case, we /know/ it should work, and if it fails, the test should fail.
        yield