diff --git a/openhtf/core/phase_descriptor.py b/openhtf/core/phase_descriptor.py index 7327c1a9..391949ab 100644 --- a/openhtf/core/phase_descriptor.py +++ b/openhtf/core/phase_descriptor.py @@ -124,6 +124,8 @@ class PhaseOptions(object): timeout will still apply when under the debugger. phase_name_case: Case formatting options for phase name. stop_on_measurement_fail: Whether to stop the test if any measurements fail. + prerequisites: List of phases that must be completed before this phase can + be run. Example Usages: @PhaseOptions(timeout_s=1) def PhaseFunc(test): pass @PhaseOptions(name='Phase({port})') def PhaseFunc(test, port, other_info): pass @@ -140,6 +142,7 @@ def PhaseFunc(test, port, other_info): pass run_under_pdb = attr.ib(type=bool, default=False) phase_name_case = attr.ib(type=PhaseNameCase, default=PhaseNameCase.KEEP) stop_on_measurement_fail = attr.ib(type=bool, default=False) + prerequisites = attr.ib(type=Optional[List[Any]], default=None) def format_strings(self, **kwargs: Any) -> 'PhaseOptions': """String substitution of name.""" @@ -173,6 +176,8 @@ def __call__(self, phase_func: PhaseT) -> 'PhaseDescriptor': phase.options.stop_on_measurement_fail = self.stop_on_measurement_fail if self.phase_name_case: phase.options.phase_name_case = self.phase_name_case + if self.prerequisites is not None: + phase.options.prerequisites = self.prerequisites return phase diff --git a/openhtf/core/phase_executor.py b/openhtf/core/phase_executor.py index 637dab47..0391f5a0 100644 --- a/openhtf/core/phase_executor.py +++ b/openhtf/core/phase_executor.py @@ -170,10 +170,12 @@ def __init__(self, phase_desc: phase_descriptor.PhaseDescriptor, self._phase_desc = phase_desc self._test_state = test_state self._subtest_rec = subtest_rec + self._phase_state = test_state.running_phase_state self._phase_execution_outcome = None # type: Optional[PhaseExecutionOutcome] def _thread_proc(self) -> None: """Execute the encompassed phase and save the result.""" + 
self._test_state.running_phase_state = self._phase_state # Call the phase, save the return value, or default it to CONTINUE. phase_return = self._phase_desc(self._test_state) if phase_return is None: @@ -320,9 +322,12 @@ def _execute_phase_once( phase_desc.name) return PhaseExecutionOutcome(phase_descriptor.PhaseResult.SKIP), None - override_result = None - with self.test_state.running_phase_context(phase_desc) as phase_state: + if id(phase_desc) in getattr(self.test_state, '_concurrent_nodes', set()): + ctx_mgr = self.test_state.concurrent_running_phase_context + else: + ctx_mgr = self.test_state.running_phase_context + with ctx_mgr(phase_desc) as phase_state: if subtest_rec: self.logger.debug('Executing phase %s under subtest %s (from %s)', phase_desc.name, phase_desc.func_location, diff --git a/openhtf/core/phase_graph.py b/openhtf/core/phase_graph.py new file mode 100644 index 00000000..13ce488c --- /dev/null +++ b/openhtf/core/phase_graph.py @@ -0,0 +1,156 @@ +# Copyright 2026 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Phase Graph support for OpenHTF. + +PhaseGraph is a PhaseCollectionNode that manages its contained phases via +a topological sort based on their explicit prerequisites. 
class CyclicDependencyError(Exception):
  """Raised when PhaseGraph phases have cyclic dependencies."""


class MissingPrerequisiteError(Exception):
  """Raised when a prerequisite is not defined in the graph."""


@attr.s(slots=True, frozen=True, init=False)
class PhaseGraph(phase_collections.PhaseCollectionNode):
  """A phase collection whose execution order is defined by a DAG.

  Each contained phase may list prerequisites in `options.prerequisites`,
  either as phase names (strings) or as objects exposing a `name` attribute.
  On construction the phases are validated and stored in `nodes` in a
  topological order, so every phase appears after all of its prerequisites.

  Phase names are the graph keys, so they must be unique within one graph.
  """

  # Phases in topological (prerequisites-first) order.
  nodes = attr.ib(type=Tuple[phase_descriptor.PhaseDescriptor, ...])
  # Optional display name of the graph.
  name = attr.ib(type=Optional[Text], default=None)

  def __init__(
      self,
      *args: phase_descriptor.PhaseCallableOrNodeT,
      name: Optional[Text] = None,
      nodes: Optional[Tuple[phase_descriptor.PhaseDescriptor, ...]] = None,
  ):
    """Builds the graph from phases given positionally and/or via `nodes`.

    Args:
      *args: Phases (or nested iterables of phases) to include.
      name: Optional name of the graph.
      nodes: Extra phases appended after `args`. attr.evolve passes the
        current nodes through this keyword when copying the graph.

    Raises:
      ValueError: If two phases share the same name.
      MissingPrerequisiteError: If a prerequisite is not part of the graph.
      CyclicDependencyError: If the prerequisites form a cycle.
    """
    super(PhaseGraph, self).__init__()
    # The class is frozen, so attributes are set through object.__setattr__.
    object.__setattr__(self, 'name', name)

    if nodes is not None:
      args = args + tuple(nodes)

    # Prerequisite matching needs PhaseDescriptor instances; anything else
    # coming out of the flatten helper is passed through wrap_or_copy.
    descriptors = [
        node
        if isinstance(node, phase_descriptor.PhaseDescriptor)
        else phase_descriptor.PhaseDescriptor.wrap_or_copy(node)
        for node in phase_collections._recursive_flatten(args)
    ]

    object.__setattr__(
        self, 'nodes', tuple(self._validate_and_toposort(descriptors)))

  def _validate_and_toposort(
      self, nodes: List[phase_descriptor.PhaseDescriptor]
  ) -> List[phase_descriptor.PhaseDescriptor]:
    """Validates the DAG structure and returns topologically sorted nodes.

    The result is deterministic: phases are visited in the order given, and
    the prerequisites of each phase in the order they were declared.

    Args:
      nodes: The phases of the graph, in declaration order.

    Returns:
      The same phases, ordered so that prerequisites come first.

    Raises:
      ValueError: If two phases share the same name.
      MissingPrerequisiteError: If a prerequisite is not part of the graph.
      CyclicDependencyError: If the prerequisites form a cycle.
    """
    name_to_node = {}
    for node in nodes:
      # A repeated name would silently replace the earlier phase in the
      # mapping and drop it from the graph, so reject it explicitly.
      if node.name in name_to_node:
        raise ValueError(
            f"Duplicate phase name '{node.name}' in PhaseGraph; phase names"
            ' must be unique.'
        )
      name_to_node[node.name] = node

    # Map each phase name to its prerequisite names. A dict is used as an
    # ordered set so duplicates collapse while declaration order is kept.
    prereqs_of = {}
    for node in nodes:
      ordered = {}
      for prereq in node.options.prerequisites or ():
        prereq_name = (
            prereq if isinstance(prereq, str)
            else getattr(prereq, 'name', None)
        )
        if not prereq_name or prereq_name not in name_to_node:
          # Fall back to the repr so an unnamed prerequisite is identifiable.
          shown = prereq_name if prereq_name else repr(prereq)
          raise MissingPrerequisiteError(
              f"Prerequisite '{shown}' for phase '{node.name}' not found in"
              ' PhaseGraph.'
          )
        ordered[prereq_name] = None
      prereqs_of[node.name] = tuple(ordered)

    # Depth-first topological sort with an explicit stack, so long
    # prerequisite chains cannot exceed the interpreter recursion limit.
    finished = set()
    in_progress = set()
    sorted_names = []
    for root in prereqs_of:
      if root in finished:
        continue
      in_progress.add(root)
      stack = [(root, iter(prereqs_of[root]))]
      while stack:
        current, pending = stack[-1]
        # Names are validated as non-empty above, so None is a safe sentinel.
        prereq_name = next(pending, None)
        if prereq_name is None:
          # All prerequisites of `current` are emitted; emit it too.
          stack.pop()
          in_progress.discard(current)
          finished.add(current)
          sorted_names.append(current)
        elif prereq_name in in_progress:
          # Reaching a phase that is still on the stack closes a cycle.
          raise CyclicDependencyError(
              f"Cycle detected involving '{prereq_name}'")
        elif prereq_name not in finished:
          in_progress.add(prereq_name)
          stack.append((prereq_name, iter(prereqs_of[prereq_name])))

    return [name_to_node[name] for name in sorted_names]

  def _asdict(self) -> Dict[Text, Any]:
    """Returns the graph name and the dict form of every node."""
    return {
        'name': self.name,
        'nodes': [n._asdict() for n in self.nodes],
    }

  def with_args(self, **kwargs: Any) -> 'PhaseGraph':
    """Returns a copy with `with_args(**kwargs)` applied to every node.

    The graph name is formatted with the same kwargs. The copy is built
    through __init__, so its nodes are validated and sorted again.
    NOTE(review): if with_args changes a phase name, string prerequisites
    that refer to the old name will no longer resolve -- confirm.
    """
    return attr.evolve(
        self,
        nodes=tuple(n.with_args(**kwargs) for n in self.nodes),
        name=util.format_string(self.name, kwargs),
    )

  def with_plugs(self, **subplugs: Type[base_plugs.BasePlug]) -> 'PhaseGraph':
    """Returns a copy with `with_plugs(**subplugs)` applied to every node."""
    return attr.evolve(
        self,
        nodes=tuple(n.with_plugs(**subplugs) for n in self.nodes),
        name=util.format_string(self.name, subplugs),
    )

  def load_code_info(self) -> 'PhaseGraph':
    """Returns a copy with `load_code_info()` applied to every node."""
    return attr.evolve(
        self,
        nodes=tuple(n.load_code_info() for n in self.nodes),
        name=self.name,
    )

  def apply_to_all_phases(
      self,
      func: Callable[
          [phase_descriptor.PhaseDescriptor], phase_descriptor.PhaseDescriptor
      ],
  ) -> 'PhaseGraph':
    """Returns a copy with `apply_to_all_phases(func)` applied to every node."""
    return attr.evolve(
        self,
        nodes=tuple(n.apply_to_all_phases(func) for n in self.nodes),
        name=self.name,
    )

  def filter_by_type(self, node_cls: Type[Any]) -> Iterator[Any]:
    """Yields each node that is an instance of node_cls, descending into
    any node that is itself a PhaseCollectionNode."""
    for node in self.nodes:
      if isinstance(node, node_cls):
        yield node
      if isinstance(node, phase_collections.PhaseCollectionNode):
        for sub_n in node.filter_by_type(node_cls):
          yield sub_n
def _execute_phase_graph(
    self,
    graph: phase_graph.PhaseGraph,
    subtest_rec: Optional[test_record.SubtestRecord],
    in_teardown: bool,
) -> _ExecutorReturn:
  """Executes the phases of a PhaseGraph concurrently in dependency order.

  A phase is submitted to a thread pool once every one of its prerequisites
  has completed. The loop then waits for at least one running phase to
  finish before looking for newly unblocked phases.

  NOTE(review): this runs self._execute_phase from several worker threads
  at once; whether _execute_phase is safe for that is not visible here.

  Args:
    graph: The phase graph to execute.
    subtest_rec: Subtest record passed through to each phase, if any.
    in_teardown: Passed through to each phase execution.

  Returns:
    TERMINAL if the test was aborted, a phase returned TERMINAL, a phase
    worker raised, or no phase could be scheduled; CONTINUE otherwise.
  """
  if graph.name:
    self.logger.debug('Entering PhaseGraph %s', graph.name)

  nodes = list(graph.nodes)
  test_state = self.running_test_state

  completed_phases = set()
  failed_phases = set()
  # Every name ever submitted; replaces scanning the running futures.
  scheduled_phases = set()
  # Future -> phase name, for phases that have not been collected yet.
  running_futures = {}
  # ids registered with the test state by this call, for cleanup on exit.
  registered_ids = set()

  def _is_ready(node) -> bool:
    """Returns True when every prerequisite of node has completed."""
    return all(
        (pr if isinstance(pr, str) else getattr(pr, 'name', None))
        in completed_phases
        for pr in node.options.prerequisites or ()
    )

  try:
    with concurrent.futures.ThreadPoolExecutor(
        max_workers=min(32, (multiprocessing.cpu_count() or 1) + 4)
    ) as pool:
      try:
        while len(completed_phases) + len(failed_phases) < len(nodes):
          if self._abort.is_set() or self._full_abort.is_set():
            return _ExecutorReturn.TERMINAL

          # Submit any unblocked phase that has not been scheduled yet.
          for node in nodes:
            if node.name in scheduled_phases or not _is_ready(node):
              continue
            registered_ids.add(id(node))
            test_state.add_concurrent_node(id(node))
            fut = pool.submit(
                self._execute_phase, node, subtest_rec, in_teardown
            )
            running_futures[fut] = node.name
            scheduled_phases.add(node.name)

          if not running_futures:
            # Nothing is running and nothing could be submitted: the
            # remaining phases are blocked (cycle or missing prerequisite).
            self.logger.error(
                'PhaseGraph %s is blocked; unscheduled phases: %s',
                graph.name,
                [n.name for n in nodes if n.name not in scheduled_phases],
            )
            return _ExecutorReturn.TERMINAL

          # Wait for at least one currently running future to complete.
          done, _ = concurrent.futures.wait(
              list(running_futures),
              return_when=concurrent.futures.FIRST_COMPLETED,
          )

          # Collect every finished future before deciding, so that no
          # worker exception goes unlogged when another phase is terminal.
          terminal = False
          for fut in done:
            p_name = running_futures.pop(fut)
            try:
              res = fut.result()
            except Exception:  # pylint: disable=broad-except
              self.logger.exception(
                  'Phase worker thread raised an exception.')
              res = _ExecutorReturn.TERMINAL
            if res == _ExecutorReturn.TERMINAL:
              failed_phases.add(p_name)
              terminal = True
            else:
              completed_phases.add(p_name)
          if terminal:
            return _ExecutorReturn.TERMINAL
      finally:
        # Leaving the pool's `with` block waits for submitted work. Cancel
        # whatever has not started so queued phases do not run after the
        # graph has already been abandoned; running phases are unaffected.
        for fut in running_futures:
          fut.cancel()
  finally:
    # All workers are done here. Drop this call's ids so a later object
    # that reuses one of them is not mistaken for a concurrent node.
    with test_state.state_lock:
      # pylint: disable=protected-access
      test_state._concurrent_nodes.difference_update(registered_ids)

  return _ExecutorReturn.CONTINUE
@property
def running_phase_state(self):
  """State of the running phase as seen from the calling thread.

  Returns the value this thread last assigned through the setter, or the
  shared fallback when this thread has never assigned one.
  """
  per_thread = self._thread_local
  fallback = self._base_running_phase_state
  return getattr(per_thread, 'phase_state', fallback)

@running_phase_state.setter
def running_phase_state(self, value):
  # Targets are assigned left to right: the calling thread's slot first,
  # then the shared fallback read by threads that never set their own.
  self._thread_local.phase_state = self._base_running_phase_state = value

def add_concurrent_node(self, node_id: int) -> None:
  """Records node_id in the set of concurrently running nodes."""
  with self.state_lock:
    self._concurrent_nodes.update((node_id,))
@contextlib.contextmanager
def concurrent_running_phase_context(
    self,
    phase_desc: phase_descriptor.PhaseDescriptor) -> Iterator['PhaseState']:
  """Creates the running-phase context for a concurrent PhaseGraph node.

  Builds a PhaseState for phase_desc, publishes it as the running phase
  state while holding state_lock, and yields it. On exit the phase state is
  finalized, its record is added to the test record, and the running phase
  state is cleared again, so a finished phase is never left visible as
  "running".

  Args:
    phase_desc: Descriptor of the phase about to run.

  Yields:
    The PhaseState created for phase_desc.
  """
  phase_logger = self.state_logger.getChild('phase.' + phase_desc.name)
  phase_state = PhaseState.from_descriptor(phase_desc, self, phase_logger)

  # Publish the new phase state and drop any cached test API under the
  # lock, since several graph nodes may enter this context at once.
  with self.state_lock:
    self.running_phase_state = phase_state
    self._running_test_api = None
    self.notify_update()

  try:
    yield phase_state
  finally:
    try:
      phase_state.finalize()
      with self.state_lock:
        self.test_record.add_phase_record(phase_state.phase_record)
    finally:
      with self.state_lock:
        # Clear this thread's view even if finalizing failed, so the next
        # phase run on the same thread does not see stale state.
        self._thread_local.phase_state = None
        self._thread_local.test_api = None
        # The fallback is shared by all threads; clear it only while it
        # still refers to this phase, so a sibling's state is kept.
        if self._base_running_phase_state is phase_state:
          self._base_running_phase_state = None
    self.notify_update()
'__main__': + unittest.main()