22# SPDX-License-Identifier: Apache-2.0
33
44import gc
5- import os
5+ import importlib . util
66import sys
7+ from pathlib import Path
78
89import pytest
910
@@ -12,24 +13,38 @@ class SampleTestError(Exception):
1213 pass
1314
1415
15- def parse_python_script (filepath ):
16- if not filepath .endswith (".py" ):
17- raise ValueError (f"{ filepath } not supported" )
18- with open (filepath , encoding = "utf-8" ) as f :
19- script = f .read ()
20- return script
16+ def run_example (parent_dir : str , rel_path_to_example : str , env = None ) -> None :
17+ fullpath = Path (parent_dir ) / rel_path_to_example
18+ module_name = fullpath .stem
2119
20+ old_sys_path = sys .path .copy ()
21+ old_argv = sys .argv
2222
23- def run_example (samples_path , filename , env = None ):
24- fullpath = os .path .join (samples_path , filename )
25- script = parse_python_script (fullpath )
2623 try :
27- old_argv = sys .argv
28- sys .argv = [fullpath ]
29- old_sys_path = sys .path .copy ()
30- sys .path .append (samples_path )
31- # TODO: Refactor the examples to give them a common callable `main()` to avoid needing to use exec here?
32- exec (script , env if env else {}) # noqa: S102
24+ sys .path .append (parent_dir )
25+ sys .argv = [str (fullpath )]
26+
27+ # Collect metadata for file 'module_name' located at 'fullpath'.
28+ # CASE: file does not exist -> spec is none.
29+ # CASE: file is not .py -> spec is none.
30+ # CASE: file does not have proper loader (module.spec.__loader__) -> spec.loader is none.
31+ spec = importlib .util .spec_from_file_location (module_name , fullpath )
32+
33+ if spec is None or spec .loader is None :
34+ raise ImportError (f"Failed to load spec for { rel_path_to_example } " )
35+
36+ # Otherwise convert the spec to a module, then run the module.
37+ module = importlib .util .module_from_spec (spec )
38+ sys .modules [module_name ] = module
39+
40+ # This runs top-level code.
41+ # CASE: exec() -> top-level code is implicitly run.
42+ spec .loader .exec_module (module )
43+
44+ # CASE: main() -> we find main and call it below.
45+ if hasattr (module , "main" ):
46+ module .main ()
47+
3348 except ImportError as e :
3449 # for samples requiring any of optional dependencies
3550 for m in ("cupy" , "torch" ):
@@ -40,14 +55,16 @@ def run_example(samples_path, filename, env=None):
4055 raise
4156 except SystemExit :
4257 # for samples that early return due to any missing requirements
43- pytest .skip (f"skip { filename } " )
58+ pytest .skip (f"skip { rel_path_to_example } " )
4459 except Exception as e :
4560 msg = "\n "
46- msg += f"Got error ({ filename } ):\n "
61+ msg += f"Got error ({ rel_path_to_example } ):\n "
4762 msg += str (e )
4863 raise SampleTestError (msg ) from e
4964 finally :
5065 sys .path = old_sys_path
5166 sys .argv = old_argv
67+
5268 # further reduce the memory watermark
69+ sys .modules .pop (module_name , None )
5370 gc .collect ()
0 commit comments