-
Notifications
You must be signed in to change notification settings - Fork 275
Expand file tree
/
Copy pathhelper_cuda.py
More file actions
48 lines (38 loc) · 1.53 KB
/
helper_cuda.py
File metadata and controls
48 lines (38 loc) · 1.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
# SPDX-FileCopyrightText: Copyright (c) 2021-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
from cuda.bindings import driver as cuda
from cuda.bindings import nvrtc
from cuda.bindings import runtime as cudart
from .helper_string import check_cmd_line_flag, get_cmd_line_argument_int
def _cuda_get_error_enum(error):
    """Translate a CUDA status enum into its symbolic name.

    Supports the three status types exposed by ``cuda.bindings``:
    driver ``CUresult``, runtime ``cudaError_t`` and ``nvrtcResult``.
    The name is returned exactly as the corresponding bindings call hands it
    back (presumably ``bytes`` — confirm against the bindings version in use).

    Args:
        error: a ``cuda.CUresult``, ``cudart.cudaError_t`` or
            ``nvrtc.nvrtcResult`` value.

    Returns:
        The error name. For a driver status whose name lookup itself fails,
        the placeholder string ``"<unknown>"``.

    Raises:
        RuntimeError: if ``error`` is none of the supported enum types.
    """
    if isinstance(error, cuda.CUresult):
        lookup_status, driver_name = cuda.cuGetErrorName(error)
        # Only the driver lookup reports its own status; guard against a
        # failed lookup instead of returning whatever came back.
        if lookup_status == cuda.CUresult.CUDA_SUCCESS:
            return driver_name
        return "<unknown>"

    if isinstance(error, cudart.cudaError_t):
        runtime_lookup = cudart.cudaGetErrorName(error)
        return runtime_lookup[1]

    if isinstance(error, nvrtc.nvrtcResult):
        nvrtc_lookup = nvrtc.nvrtcGetErrorString(error)
        return nvrtc_lookup[1]

    raise RuntimeError(f"Unknown error type: {error}")
def check_cuda_errors(result):
    """Unwrap the sequence returned by a ``cuda.bindings`` call, raising on failure.

    The first element of ``result`` is treated as a status enum: a nonzero
    ``.value`` is an error. Any remaining elements are the call's outputs.

    Args:
        result: sequence whose first item is a status enum exposing ``.value``,
            followed by zero or more output values.

    Returns:
        ``None`` when the call produced no outputs, the single output itself
        when there is exactly one, otherwise the slice ``result[1:]``.

    Raises:
        RuntimeError: if the status value is nonzero; the message includes the
            numeric code and its symbolic name.
    """
    status = result[0]
    if status.value:
        raise RuntimeError(f"CUDA error code={status.value}({_cuda_get_error_enum(status)})")

    outputs = result[1:]
    if not outputs:
        return None
    if len(outputs) == 1:
        # A lone output is returned bare rather than as a 1-element sequence.
        return outputs[0]
    return outputs
def find_cuda_device():
    """Select the active CUDA device through the runtime API.

    If ``check_cmd_line_flag("device=")`` reports the flag, the integer from
    ``get_cmd_line_argument_int("device=")`` is used as the device id.
    Otherwise device 0 is used. The choice is applied with
    ``cudart.cudaSetDevice``.

    Returns:
        The integer device id that was made current.

    Raises:
        RuntimeError: propagated from ``check_cuda_errors`` if
            ``cudaSetDevice`` fails.
    """
    dev_id = get_cmd_line_argument_int("device=") if check_cmd_line_flag("device=") else 0
    check_cuda_errors(cudart.cudaSetDevice(dev_id))
    return dev_id
def find_cuda_device_drv():
    """Initialise the CUDA driver API and return a handle to the chosen device.

    The device ordinal comes from the ``device=`` command-line flag: when
    ``check_cmd_line_flag`` reports it, ``get_cmd_line_argument_int`` supplies
    the value. Otherwise ordinal 0 is used. The driver is initialised with
    ``cuInit(0)`` before ``cuDeviceGet`` is called.

    Returns:
        The device handle produced by ``cuda.cuDeviceGet``, unwrapped by
        ``check_cuda_errors``.

    Raises:
        RuntimeError: propagated from ``check_cuda_errors`` if ``cuInit`` or
            ``cuDeviceGet`` fails.
    """
    if check_cmd_line_flag("device="):
        ordinal = get_cmd_line_argument_int("device=")
    else:
        ordinal = 0

    # The driver API requires cuInit before any other driver call.
    check_cuda_errors(cuda.cuInit(0))
    return check_cuda_errors(cuda.cuDeviceGet(ordinal))