Skip to content

Commit 9ff8d0f

Browse files
committed
Teach the TMA example where to find libcudacxx headers.
Use the toolkit include and optional cccl include roots when compiling the wrapper-based example so NVRTC can resolve cuda/barrier outside the test harness. Made-with: Cursor
1 parent e67e9d3 commit 9ff8d0f

1 file changed

Lines changed: 22 additions & 1 deletion

File tree

cuda_core/examples/tma_tensor_map.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#
2323
# ################################################################################
2424

25+
import os
2526
import sys
2627

2728
import cupy as cp
@@ -100,6 +101,25 @@
100101
}
101102
"""
102103

104+
105+
def _get_cccl_include_paths():
    """Return the include directories used to compile the example kernel.

    The CUDA toolkit root is taken from the ``CUDA_PATH`` environment
    variable, falling back to ``CUDA_HOME``. A variable that is set but
    empty is treated the same as an unset one, so an empty ``CUDA_PATH``
    does not mask a valid ``CUDA_HOME`` (and does not silently resolve to
    a relative ``include`` directory in the current working directory).

    Returns:
        list[str]: ``<toolkit>/include``, preceded by
        ``<toolkit>/include/cccl`` when that directory exists.

    Exits:
        Prints a message to stderr and calls ``sys.exit(1)`` when neither
        environment variable names a toolkit, or when the toolkit has no
        ``include`` directory.
    """
    # `or` (rather than a nested .get() default) so that an empty-string
    # CUDA_PATH falls through to CUDA_HOME instead of being used as the root.
    cuda_path = os.environ.get("CUDA_PATH") or os.environ.get("CUDA_HOME")
    if not cuda_path:
        print("This example requires CUDA_PATH or CUDA_HOME to point to a CUDA toolkit.", file=sys.stderr)
        sys.exit(1)

    cuda_include = os.path.join(cuda_path, "include")
    if not os.path.isdir(cuda_include):
        print(f"CUDA include directory not found: {cuda_include}", file=sys.stderr)
        sys.exit(1)

    include_path = [cuda_include]
    # The cccl subdirectory is optional; when present it is listed ahead of
    # the plain toolkit include directory.
    cccl_include = os.path.join(cuda_include, "cccl")
    if os.path.isdir(cccl_include):
        include_path.insert(0, cccl_include)
    return include_path
121+
122+
103123
def main():
104124
# -----------------------------------------------------------------------
105125
# Check for Hopper+ GPU
@@ -113,14 +133,15 @@ def main():
113133
)
114134
sys.exit(0)
115135
dev.set_current()
136+
include_path = _get_cccl_include_paths()
116137

117138
# -----------------------------------------------------------------------
118139
# Compile the kernel
119140
# -----------------------------------------------------------------------
120141
prog = Program(
121142
code,
122143
code_type="c++",
123-
options=ProgramOptions(std="c++17", arch=f"sm_{dev.arch}"),
144+
options=ProgramOptions(std="c++17", arch=f"sm_{dev.arch}", include_path=include_path),
124145
)
125146
mod = prog.compile("cubin")
126147
ker = mod.get_kernel("tma_copy")

0 commit comments

Comments
 (0)