Skip to content

Commit 408fb20

Browse files
feat: add orchestration engine for GraphGen and tests
1 parent 74252ab commit 408fb20

3 files changed

Lines changed: 163 additions & 0 deletions

File tree

graphgen/engine.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
"""
2+
orchestration engine for GraphGen
3+
"""
4+
5+
import threading
6+
from typing import Any, Callable, List
7+
8+
9+
class Context(dict):
10+
_lock = threading.Lock()
11+
12+
def set(self, k, v):
13+
with self._lock:
14+
self[k] = v
15+
16+
def get(self, k, default=None):
17+
with self._lock:
18+
return super().get(k, default)
19+
20+
21+
class OpNode:
22+
def __init__(
23+
self, name: str, deps: List[str], func: Callable[["OpNode", Context], Any]
24+
):
25+
self.name, self.deps, self.func = name, deps, func
26+
27+
28+
def op(name: str, deps: List[str] = None):
29+
def decorator(f: Callable[["OpNode", Context], Any]):
30+
return OpNode(name, deps or [], f)
31+
32+
return decorator
33+
34+
35+
class Engine:
36+
def __init__(self, max_workers: int = 4):
37+
self.max_workers = max_workers
38+
39+
def run(self, ops: List[OpNode], ctx: Context):
40+
name2op = {operation.name: operation for operation in ops}
41+
42+
# topological sort
43+
graph = {n: set(name2op[n].deps) for n in name2op}
44+
topo = []
45+
q = [n for n, d in graph.items() if not d]
46+
while q:
47+
cur = q.pop(0)
48+
topo.append(cur)
49+
for child in [c for c, d in graph.items() if cur in d]:
50+
graph[child].remove(cur)
51+
if not graph[child]:
52+
q.append(child)
53+
54+
if len(topo) != len(ops):
55+
raise ValueError(
56+
"Cyclic dependencies detected among operations."
57+
"Please check your configuration."
58+
)
59+
60+
# semaphore for max_workers
61+
sem = threading.Semaphore(self.max_workers)
62+
done = {n: threading.Event() for n in name2op}
63+
exc = {}
64+
65+
def _exec(n: str):
66+
with sem:
67+
for d in name2op[n].deps:
68+
done[d].wait()
69+
if any(d in exc for d in name2op[n].deps):
70+
exc[n] = Exception("Skipped due to failed dependencies")
71+
done[n].set()
72+
return
73+
try:
74+
name2op[n].func(name2op[n], ctx)
75+
except Exception as e: # pylint: disable=broad-except
76+
exc[n] = e
77+
done[n].set()
78+
79+
ts = [threading.Thread(target=_exec, args=(n,), daemon=True) for n in topo]
80+
for t in ts:
81+
t.start()
82+
for t in ts:
83+
t.join()
84+
if exc:
85+
raise RuntimeError(f"Some operations failed: {exc}")
File renamed without changes.
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
import pytest
2+
3+
from graphgen.engine import Context, Engine, op
4+
5+
engine = Engine(max_workers=2)
6+
7+
8+
def test_simple_dag(capsys):
9+
"""Verify the DAG A->B/C->D execution results and print order."""
10+
ctx = Context()
11+
12+
@op("A")
13+
def op_a(self, ctx):
14+
print("Running A")
15+
ctx.set("A", 1)
16+
17+
@op("B", deps=["A"])
18+
def op_b(self, ctx):
19+
print("Running B")
20+
ctx.set("B", ctx.get("A") + 1)
21+
22+
@op("C", deps=["A"])
23+
def op_c(self, ctx):
24+
print("Running C")
25+
ctx.set("C", ctx.get("A") + 2)
26+
27+
@op("D", deps=["B", "C"])
28+
def op_d(self, ctx):
29+
print("Running D")
30+
ctx.set("D", ctx.get("B") + ctx.get("C"))
31+
32+
# Explicitly list the nodes to run; avoid relying on globals().
33+
ops = [op_a, op_b, op_c, op_d]
34+
engine.run(ops, ctx)
35+
36+
# Assert final results.
37+
assert ctx["A"] == 1
38+
assert ctx["B"] == 2
39+
assert ctx["C"] == 3
40+
assert ctx["D"] == 5
41+
42+
# Assert print order: A must run before B and C; D must run after B and C.
43+
captured = capsys.readouterr().out.strip().splitlines()
44+
assert "Running A" in captured
45+
assert "Running B" in captured
46+
assert "Running C" in captured
47+
assert "Running D" in captured
48+
49+
a_idx = next(i for i, line in enumerate(captured) if "Running A" in line)
50+
b_idx = next(i for i, line in enumerate(captured) if "Running B" in line)
51+
c_idx = next(i for i, line in enumerate(captured) if "Running C" in line)
52+
d_idx = next(i for i, line in enumerate(captured) if "Running D" in line)
53+
54+
assert a_idx < b_idx
55+
assert a_idx < c_idx
56+
assert d_idx > b_idx
57+
assert d_idx > c_idx
58+
59+
60+
def test_cyclic_detection():
61+
"""A cyclic dependency should raise ValueError."""
62+
ctx = Context()
63+
64+
@op("X", deps=["Y"])
65+
def op_x(self, ctx):
66+
pass
67+
68+
@op("Y", deps=["X"])
69+
def op_y(self, ctx):
70+
pass
71+
72+
ops = [op_x, op_y]
73+
with pytest.raises(ValueError, match="Cyclic dependencies"):
74+
engine.run(ops, ctx)
75+
76+
77+
if __name__ == "__main__":
78+
pytest.main([__file__, "-v"])

0 commit comments

Comments
 (0)