Skip to content

Commit 534aa0a

Browse files
committed
modernize examples test harness
1 parent 463ed53 commit 534aa0a

2 files changed

Lines changed: 46 additions & 28 deletions

File tree

cuda_core/tests/example_tests/test_basic_examples.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,22 @@
33

44
# If we have subcategories of examples in the future, this file can be split along those lines
55

6-
import glob
7-
import os
6+
from pathlib import Path
87

98
import pytest
10-
from cuda.core import Device
119

1210
from .utils import run_example
1311

14-
samples_path = os.path.join(os.path.dirname(__file__), "..", "..", "examples")
15-
sample_files = glob.glob(samples_path + "**/*.py", recursive=True)
12+
# not dividing, but navigating into the "examples" directory.
13+
EXAMPLES_DIR = Path(__file__).resolve().parent.parent.parent / "examples"
1614

15+
# recursively glob for test files in examples directory, sort for deterministic
16+
# test runs. Relative paths offer cleaner output when tests fail.
17+
SAMPLE_FILES = sorted([str(p.relative_to(EXAMPLES_DIR)) for p in EXAMPLES_DIR.glob("**/*.py")])
1718

18-
@pytest.mark.parametrize("example", sample_files)
19+
20+
@pytest.mark.parametrize("example_rel_path", SAMPLE_FILES)
1921
class TestExamples:
20-
def test_example(self, example, deinit_cuda):
21-
run_example(samples_path, example)
22-
if Device().device_id != 0:
23-
Device(0).set_current()
22+
# deinit_cuda is defined in conftest.py and pops the cuda context automatically.
23+
def test_example(self, example_rel_path: str, deinit_cuda) -> None:
24+
run_example(str(EXAMPLES_DIR), example_rel_path)

cuda_core/tests/example_tests/utils.py

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
# SPDX-License-Identifier: Apache-2.0
33

44
import gc
5-
import os
5+
import importlib.util
66
import sys
7+
from pathlib import Path
78

89
import pytest
910

@@ -12,24 +13,38 @@ class SampleTestError(Exception):
1213
pass
1314

1415

15-
def parse_python_script(filepath):
16-
if not filepath.endswith(".py"):
17-
raise ValueError(f"{filepath} not supported")
18-
with open(filepath, encoding="utf-8") as f:
19-
script = f.read()
20-
return script
16+
def run_example(parent_dir: str, rel_path_to_example: str, env=None) -> None:
17+
fullpath = Path(parent_dir) / rel_path_to_example
18+
module_name = fullpath.stem
2119

20+
old_sys_path = sys.path.copy()
21+
old_argv = sys.argv
2222

23-
def run_example(samples_path, filename, env=None):
24-
fullpath = os.path.join(samples_path, filename)
25-
script = parse_python_script(fullpath)
2623
try:
27-
old_argv = sys.argv
28-
sys.argv = [fullpath]
29-
old_sys_path = sys.path.copy()
30-
sys.path.append(samples_path)
31-
# TODO: Refactor the examples to give them a common callable `main()` to avoid needing to use exec here?
32-
exec(script, env if env else {}) # noqa: S102
24+
sys.path.append(parent_dir)
25+
sys.argv = [str(fullpath)]
26+
27+
# Collect metadata for file 'module_name' located at 'fullpath'.
28+
# CASE: file does not exist -> spec is none.
29+
# CASE: file is not .py -> spec is none.
30+
# CASE: file does not have proper loader (module.spec.__loader__) -> spec.loader is none.
31+
spec = importlib.util.spec_from_file_location(module_name, fullpath)
32+
33+
if spec is None or spec.loader is None:
34+
raise ImportError(f"Failed to load spec for {rel_path_to_example}")
35+
36+
# Otherwise convert the spec to a module, then run the module.
37+
module = importlib.util.module_from_spec(spec)
38+
sys.modules[module_name] = module
39+
40+
# This runs top-level code.
41+
# CASE: exec() -> top-level code is implicitly run.
42+
spec.loader.exec_module(module)
43+
44+
# CASE: main() -> we find main and call it below.
45+
if hasattr(module, "main"):
46+
module.main()
47+
3348
except ImportError as e:
3449
# for samples requiring any of optional dependencies
3550
for m in ("cupy", "torch"):
@@ -40,14 +55,16 @@ def run_example(samples_path, filename, env=None):
4055
raise
4156
except SystemExit:
4257
# for samples that early return due to any missing requirements
43-
pytest.skip(f"skip {filename}")
58+
pytest.skip(f"skip {rel_path_to_example}")
4459
except Exception as e:
4560
msg = "\n"
46-
msg += f"Got error ({filename}):\n"
61+
msg += f"Got error ({rel_path_to_example}):\n"
4762
msg += str(e)
4863
raise SampleTestError(msg) from e
4964
finally:
5065
sys.path = old_sys_path
5166
sys.argv = old_argv
67+
5268
# further reduce the memory watermark
69+
sys.modules.pop(module_name, None)
5370
gc.collect()

0 commit comments

Comments
 (0)