Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions openhtf/core/phase_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ class PhaseOptions(object):
timeout will still apply when under the debugger.
phase_name_case: Case formatting options for phase name.
stop_on_measurement_fail: Whether to stop the test if any measurements fail.
prerequisites: List of phases that must be completed before this phase can
be run.
Example Usages: @PhaseOptions(timeout_s=1)
def PhaseFunc(test): pass @PhaseOptions(name='Phase({port})')
def PhaseFunc(test, port, other_info): pass
Expand All @@ -140,6 +142,7 @@ def PhaseFunc(test, port, other_info): pass
run_under_pdb = attr.ib(type=bool, default=False)
phase_name_case = attr.ib(type=PhaseNameCase, default=PhaseNameCase.KEEP)
stop_on_measurement_fail = attr.ib(type=bool, default=False)
prerequisites = attr.ib(type=Optional[List[Any]], default=None)

def format_strings(self, **kwargs: Any) -> 'PhaseOptions':
"""String substitution of name."""
Expand Down Expand Up @@ -173,6 +176,8 @@ def __call__(self, phase_func: PhaseT) -> 'PhaseDescriptor':
phase.options.stop_on_measurement_fail = self.stop_on_measurement_fail
if self.phase_name_case:
phase.options.phase_name_case = self.phase_name_case
if self.prerequisites is not None:
phase.options.prerequisites = self.prerequisites
return phase


Expand Down
9 changes: 7 additions & 2 deletions openhtf/core/phase_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,10 +170,12 @@ def __init__(self, phase_desc: phase_descriptor.PhaseDescriptor,
self._phase_desc = phase_desc
self._test_state = test_state
self._subtest_rec = subtest_rec
self._phase_state = test_state.running_phase_state
self._phase_execution_outcome = None # type: Optional[PhaseExecutionOutcome]

def _thread_proc(self) -> None:
"""Execute the encompassed phase and save the result."""
self._test_state.running_phase_state = self._phase_state
# Call the phase, save the return value, or default it to CONTINUE.
phase_return = self._phase_desc(self._test_state)
if phase_return is None:
Expand Down Expand Up @@ -320,9 +322,12 @@ def _execute_phase_once(
phase_desc.name)
return PhaseExecutionOutcome(phase_descriptor.PhaseResult.SKIP), None


override_result = None
with self.test_state.running_phase_context(phase_desc) as phase_state:
if id(phase_desc) in getattr(self.test_state, '_concurrent_nodes', set()):
ctx_mgr = self.test_state.concurrent_running_phase_context
else:
ctx_mgr = self.test_state.running_phase_context
with ctx_mgr(phase_desc) as phase_state:
if subtest_rec:
self.logger.debug('Executing phase %s under subtest %s (from %s)',
phase_desc.name, phase_desc.func_location,
Expand Down
156 changes: 156 additions & 0 deletions openhtf/core/phase_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
# Copyright 2026 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Phase Graph support for OpenHTF.

PhaseGraph is a PhaseCollectionNode that manages its contained phases via
a topological sort based on their explicit prerequisites.
"""

from typing import Any, Callable, Dict, Iterator, List, Optional, Text, Tuple, Type

import attr
from openhtf import util
from openhtf.core import base_plugs
from openhtf.core import phase_collections
from openhtf.core import phase_descriptor


class CyclicDependencyError(Exception):
"""Raised when PhaseGraph phases have cyclic dependencies."""


class MissingPrerequisiteError(Exception):
"""Raised when a prerequisite is not defined in the graph."""


@attr.s(slots=True, frozen=True, init=False)
class PhaseGraph(phase_collections.PhaseCollectionNode):
"""A phase collection whose execution order is defined by a DAG."""

nodes = attr.ib(type=Tuple[phase_descriptor.PhaseDescriptor, ...])
name = attr.ib(type=Optional[Text], default=None)

def __init__(
self,
*args: phase_descriptor.PhaseCallableOrNodeT,
name: Optional[Text] = None,
nodes: Optional[Tuple[phase_descriptor.PhaseDescriptor, ...]] = None,
):
super(PhaseGraph, self).__init__()
object.__setattr__(self, 'name', name)

if nodes is not None:
args = args + tuple(nodes)

flattened = list(phase_collections._recursive_flatten(args))
# Verify elements are PhaseDescriptor instances for prerequisite matching
ph_desc_list = []
for n in flattened:
if isinstance(n, phase_descriptor.PhaseDescriptor):
ph_desc_list.append(n)
else:
# Wrap or copy standard callables / nodes
ph_desc_list.append(phase_descriptor.PhaseDescriptor.wrap_or_copy(n))

topologically_sorted = self._validate_and_toposort(ph_desc_list)
object.__setattr__(self, 'nodes', tuple(topologically_sorted))

def _validate_and_toposort(
self, nodes: List[phase_descriptor.PhaseDescriptor]
) -> List[phase_descriptor.PhaseDescriptor]:
"""Validates the DAG structure and returns topologically sorted nodes."""
name_to_node = {n.name: n for n in nodes}

# Match prerequisites to actual nodes
adjacency = {n.name: set() for n in nodes}
for n in nodes:
if n.options.prerequisites is not None:
for pr in n.options.prerequisites:
pr_name = pr if isinstance(pr, str) else getattr(pr, 'name', None)
if not pr_name or pr_name not in name_to_node:
raise MissingPrerequisiteError(
f"Prerequisite '{pr_name}' for phase '{n.name}' not found in"
' PhaseGraph.'
)
adjacency[n.name].add(pr_name)

# Perform topological sort using Kahn's algorithm/DFS cycle detection
visited = set()
temp_marked = set()
sorted_names = []

def _visit(node_name: str):
if node_name in temp_marked:
raise CyclicDependencyError(f"Cycle detected involving '{node_name}'")
if node_name in visited:
return
temp_marked.add(node_name)
for prereq_name in adjacency[node_name]:
_visit(prereq_name)
temp_marked.remove(node_name)
visited.add(node_name)
sorted_names.append(node_name)

for n in nodes:
if n.name not in visited:
_visit(n.name)

return [name_to_node[name] for name in sorted_names]

def _asdict(self) -> Dict[Text, Any]:
return {
'name': self.name,
'nodes': [n._asdict() for n in self.nodes],
}

def with_args(self, **kwargs: Any) -> 'PhaseGraph':
return attr.evolve(
self,
nodes=tuple(n.with_args(**kwargs) for n in self.nodes),
name=util.format_string(self.name, kwargs),
)

def with_plugs(self, **subplugs: Type[base_plugs.BasePlug]) -> 'PhaseGraph':
return attr.evolve(
self,
nodes=tuple(n.with_plugs(**subplugs) for n in self.nodes),
name=util.format_string(self.name, subplugs),
)

def load_code_info(self) -> 'PhaseGraph':
return attr.evolve(
self,
nodes=tuple(n.load_code_info() for n in self.nodes),
name=self.name,
)

def apply_to_all_phases(
self,
func: Callable[
[phase_descriptor.PhaseDescriptor], phase_descriptor.PhaseDescriptor
],
) -> 'PhaseGraph':
return attr.evolve(
self,
nodes=tuple(n.apply_to_all_phases(func) for n in self.nodes),
name=self.name,
)

def filter_by_type(self, node_cls: Type[Any]) -> Iterator[Any]:
for node in self.nodes:
if isinstance(node, node_cls):
yield node
if isinstance(node, phase_collections.PhaseCollectionNode):
for sub_n in node.filter_by_type(node_cls):
yield sub_n
85 changes: 84 additions & 1 deletion openhtf/core/test_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,17 @@
# limitations under the License.
"""TestExecutor executes tests."""

import concurrent.futures
import contextlib
import enum
import logging
import multiprocessing
import pstats
import sys
import tempfile
import threading
import traceback
from typing import Iterator, List, Optional, Text, Type, TYPE_CHECKING
from typing import Iterator, List, Optional, TYPE_CHECKING, Text, Type

from openhtf import util
from openhtf.core import base_plugs
Expand All @@ -30,6 +32,7 @@
from openhtf.core import phase_collections
from openhtf.core import phase_descriptor
from openhtf.core import phase_executor
from openhtf.core import phase_graph
from openhtf.core import phase_group
from openhtf.core import phase_nodes
from openhtf.core import test_record
Expand Down Expand Up @@ -602,6 +605,84 @@ def _execute_phase_group(self, group: phase_group.PhaseGroup,
teardown_ret = _ExecutorReturn.CONTINUE
return _more_critical(main_ret, teardown_ret)

def _execute_phase_graph(
self,
graph: phase_graph.PhaseGraph,
subtest_rec: Optional[test_record.SubtestRecord],
in_teardown: bool,
) -> _ExecutorReturn:
"""Executes the phases in a phase graph concurrently according to DAG ordering."""

if graph.name:
self.logger.debug('Entering PhaseGraph %s', graph.name)

nodes = list(graph.nodes)

completed_phases = set()
failed_phases = set()
running_futures = {}

with concurrent.futures.ThreadPoolExecutor(
max_workers=min(32, (multiprocessing.cpu_count() or 1) + 4)
) as pool:
while len(completed_phases) + len(failed_phases) < len(nodes):
if self._abort.is_set() or self._full_abort.is_set():
for fut in running_futures:
fut.cancel()
return _ExecutorReturn.TERMINAL

# Submit any unblocked and un-scheduled phase
made_progress = False
for node in nodes:
if (
node.name in completed_phases
or node.name in failed_phases
or node.name in running_futures.values()
):
continue

prerequisite_satisfied = True
if node.options.prerequisites:
for pr in node.options.prerequisites:
pr_name = pr if isinstance(pr, str) else getattr(pr, 'name', None)
if pr_name not in completed_phases:
prerequisite_satisfied = False
break

if prerequisite_satisfied:
self.running_test_state.add_concurrent_node(id(node))
fut = pool.submit(
self._execute_phase, node, subtest_rec, in_teardown
)
running_futures[fut] = node.name
made_progress = True

if not running_futures and not made_progress:
# Blocked completely / cycle or missing prerequisites
return _ExecutorReturn.TERMINAL

# Wait for at least one currently running future to complete
done, _ = concurrent.futures.wait(
running_futures.keys(),
return_when=concurrent.futures.FIRST_COMPLETED,
)

for fut in done:
p_name = running_futures.pop(fut)
try:
res = fut.result()
if res == _ExecutorReturn.TERMINAL:
failed_phases.add(p_name)
return _ExecutorReturn.TERMINAL
else:
completed_phases.add(p_name)
except Exception: # pylint: disable=broad-except
self.logger.exception('Phase worker thread raised an exception.')
failed_phases.add(p_name)
return _ExecutorReturn.TERMINAL

return _ExecutorReturn.CONTINUE

def _execute_node(self, node: phase_nodes.PhaseNode,
subtest_rec: Optional[test_record.SubtestRecord],
in_teardown: bool) -> _ExecutorReturn:
Expand All @@ -613,6 +694,8 @@ def _execute_node(self, node: phase_nodes.PhaseNode,
return self._execute_sequence(node, subtest_rec, in_teardown)
if isinstance(node, phase_group.PhaseGroup):
return self._execute_phase_group(node, subtest_rec, in_teardown)
if isinstance(node, phase_graph.PhaseGraph):
return self._execute_phase_graph(node, subtest_rec, in_teardown)
if isinstance(node, phase_descriptor.PhaseDescriptor):
return self._execute_phase(node, subtest_rec, in_teardown)
if isinstance(node, phase_branches.Checkpoint):
Expand Down
Loading
Loading