Skip to content

Commit 8ac4975

Browse files
authored
Merge pull request #1051 from boriel-basic/refact/use_common_visitor_for_all
Refact/use common visitor for all
2 parents c479ebf + 63f87e4 commit 8ac4975

23 files changed

Lines changed: 4095 additions & 242 deletions

src/api/optimize.py

Lines changed: 44 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
# --------------------------------------------------------------------
77

88
import symtable
9-
from collections.abc import Generator
9+
from collections.abc import Callable, Generator
1010
from typing import Any, NamedTuple
1111

1212
import src.api.check as chk
@@ -17,28 +17,17 @@
1717
from src.api import errmsg
1818
from src.api.config import OPTIONS
1919
from src.api.constants import CLASS, CONVENTION, SCOPE, TYPE
20-
from src.api.debug import __DEBUG__
2120
from src.api.errmsg import warning_not_used
2221
from src.ast import Ast, NodeVisitor
2322
from src.symbols import sym as symbols
2423
from src.symbols.id_ import ref
2524

2625

27-
class ToVisit(NamedTuple):
28-
"""Used just to signal an object to be
29-
traversed.
30-
"""
31-
32-
obj: symbols.SYMBOL
33-
34-
3526
class GenericVisitor(NodeVisitor):
3627
"""A slightly different visitor, that just traverses an AST, but does not return
3728
a translation of it. Used to examine the AST or do transformations
3829
"""
3930

40-
node_type = ToVisit
41-
4231
@property
4332
def O_LEVEL(self):
4433
return OPTIONS.optimization_level
@@ -58,36 +47,42 @@ def TYPE(type_):
5847
assert TYPE.is_valid(type_)
5948
return gl.SYMBOL_TABLE.basic_types[type_]
6049

61-
def visit(self, node):
62-
return super().visit(ToVisit(node))
63-
64-
def _visit(self, node: ToVisit):
65-
if node.obj is None:
66-
return None
67-
68-
__DEBUG__(f"Optimizer: Visiting node {node.obj!s}[{node.obj.token}]", 1)
69-
meth = getattr(self, f"visit_{node.obj.token}", self.generic_visit)
70-
return meth(node.obj)
71-
72-
def generic_visit(self, node: Ast) -> Generator[Ast | None, Any, None]:
73-
for i, child in enumerate(node.children):
74-
node.children[i] = yield self.visit(child)
75-
76-
yield node
77-
7850

7951
class UniqueVisitor(GenericVisitor):
8052
def __init__(self):
8153
super().__init__()
8254
self.visited = set()
8355

84-
def _visit(self, node: ToVisit):
85-
if node.obj in self.visited:
86-
return node.obj
56+
def _visit(self, node: Ast):
57+
if node in self.visited:
58+
return node
8759

88-
self.visited.add(node.obj)
60+
self.visited.add(node)
8961
return super()._visit(node)
9062

63+
def filter_inorder(
64+
self,
65+
node,
66+
filter_func: Callable[[Any], bool],
67+
child_selector: Callable[[Ast], bool] = lambda x: True,
68+
) -> Generator[Ast, None, None]:
69+
"""Visit the tree inorder, but only those that return true for filter_func and visiting children which
70+
return true for child_selector.
71+
"""
72+
visited = set()
73+
stack = [node]
74+
while stack:
75+
node = stack.pop()
76+
if node in visited:
77+
continue
78+
79+
visited.add(node)
80+
if filter_func(node):
81+
yield self.visit(node)
82+
83+
if isinstance(node, Ast) and child_selector(node):
84+
stack.extend(node.children[::-1])
85+
9186

9287
class UnreachableCodeVisitor(UniqueVisitor):
9388
"""Visitor to optimize unreachable code (and prune it)."""
@@ -107,7 +102,7 @@ def visit_FUNCTION(self, node: symbols.ID):
107102
if type_ is not None and type_ == self.TYPE(TYPE.string):
108103
node.body.append(symbols.ASM("\nld hl, 0\n", lineno, node.filename, is_sentinel=True))
109104

110-
yield (yield self.generic_visit(node))
105+
yield self.generic_visit(node)
111106

112107
def visit_BLOCK(self, node):
113108
# Remove CHKBREAK after labels
@@ -155,7 +150,7 @@ def visit_BLOCK(self, node):
155150
yield self.NOP
156151
return
157152

158-
yield (yield self.generic_visit(node))
153+
yield self.generic_visit(node)
159154

160155

161156
class FunctionGraphVisitor(UniqueVisitor):
@@ -165,6 +160,7 @@ def _get_calls_from_children(self, node: symtable.Symbol):
165160
return list(self.filter_inorder(node, lambda x: x.token in ("CALL", "FUNCCALL")))
166161

167162
def _set_children_as_accessed(self, node: symbols.SYMBOL):
163+
""" "Traverse only those"""
168164
parent = node.get_parent(symbols.FUNCDECL)
169165
if parent is None: # Global scope?
170166
for symbol in self._get_calls_from_children(node):
@@ -314,7 +310,7 @@ def visit_FUNCDECL(self, node):
314310
if self.O_LEVEL > 1 and node.params_size == node.locals_size == 0:
315311
node.entry.ref.convention = CONVENTION.fastcall
316312

317-
node.children[1] = yield ToVisit(node.entry)
313+
node.children[1] = yield self.visit(node.entry)
318314
yield node
319315

320316
def visit_LET(self, node):
@@ -370,19 +366,20 @@ def visit_RETURN(self, node):
370366
might cause infinite recursion.
371367
"""
372368
if len(node.children) == 2:
373-
node.children[1] = yield ToVisit(node.children[1])
369+
node.children[1] = yield self.visit(node.children[1])
370+
374371
yield node
375372

376373
def visit_UNARY(self, node):
377374
if node.operator == "ADDRESS":
378-
yield (yield self.visit_ADDRESS(node))
375+
yield self.visit_ADDRESS(node)
379376
else:
380-
yield (yield self.generic_visit(node))
377+
yield self.generic_visit(node)
381378

382379
def visit_IF(self, node):
383-
expr_ = yield ToVisit(node.children[0])
384-
then_ = yield ToVisit(node.children[1])
385-
else_ = (yield ToVisit(node.children[2])) if len(node.children) == 3 else self.NOP
380+
expr_ = yield self.visit(node.children[0])
381+
then_ = yield self.visit(node.children[1])
382+
else_ = (yield self.visit(node.children[2])) if len(node.children) == 3 else self.NOP
386383

387384
if self.O_LEVEL >= 1:
388385
if chk.is_null(then_, else_):
@@ -405,6 +402,7 @@ def visit_IF(self, node):
405402

406403
for i in range(len(node.children)):
407404
node.children[i] = (expr_, then_, else_)[i]
405+
408406
yield node
409407

410408
def visit_WHILE(self, node):
@@ -419,6 +417,7 @@ def visit_WHILE(self, node):
419417

420418
for i, child in enumerate((expr_, body_)):
421419
node.children[i] = child
420+
422421
yield node
423422

424423
def visit_FOR(self, node):
@@ -433,6 +432,7 @@ def visit_FOR(self, node):
433432
if from_.value > to_.value and step_.value > 0:
434433
yield self.NOP
435434
return
435+
436436
if from_.value < to_.value and step_.value < 0:
437437
yield self.NOP
438438
return
@@ -446,12 +446,6 @@ def _visit_LABEL(self, node):
446446
else:
447447
yield node
448448

449-
def generic_visit(self, node: Ast):
450-
for i, child in enumerate(node.children):
451-
node.children[i] = yield ToVisit(child)
452-
453-
yield node
454-
455449
def _check_if_any_arg_is_an_array_and_needs_lbound_or_ubound(
456450
self, params: symbols.PARAMLIST, args: symbols.ARGLIST
457451
):
@@ -502,10 +496,7 @@ class VariableVisitor(GenericVisitor):
502496
def generic_visit(self, node: Ast):
503497
if node not in VariableVisitor._visited:
504498
VariableVisitor._visited.add(node)
505-
for i in range(len(node.children)):
506-
node.children[i] = yield ToVisit(node.children[i])
507-
508-
yield node
499+
yield super().generic_visit(node)
509500

510501
def has_circular_dependency(self, var_dependency: VarDependency) -> bool:
511502
if var_dependency.dependency == VariableVisitor._original_variable:
@@ -532,7 +523,7 @@ def visit_var(entry):
532523
if entry.token != "VAR":
533524
for child in entry.children:
534525
visit_var(child)
535-
if child.token in ("FUNCTION", "LABEL", "VAR", "VARARRAY"):
526+
if child.token in {"FUNCTION", "LABEL", "VAR", "VARARRAY"}:
536527
result.add(VarDependency(parent=VariableVisitor._parent_variable, dependency=child))
537528
return
538529

src/arch/z80/visitor/builtin_translator.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,13 @@ class BuiltinTranslator(TranslatorVisitor):
1919

2020
REQUIRES = backend.REQUIRES
2121

22+
def __init__(self, backend: backend.Backend, parent_visitor: TranslatorVisitor):
23+
super().__init__(backend)
24+
self.parent_visitor = parent_visitor
25+
26+
def visit(self, node):
27+
return self.parent_visitor.visit(node)
28+
2229
# region STRING Functions
2330
def visit_INKEY(self, node):
2431
self.runtime_call(RuntimeLabel.INKEY, Type.string.size)
@@ -125,7 +132,7 @@ def visit_SQR(self, node):
125132
# endregion
126133

127134
def visit_LBOUND(self, node):
128-
yield node.operands[1]
135+
yield self.visit(node.operands[1])
129136
self.ic_param(gl.BOUND_TYPE, node.operands[1].t)
130137
entry = node.operands[0]
131138
if entry.scope == SCOPE.global_:
@@ -141,7 +148,7 @@ def visit_LBOUND(self, node):
141148
self.runtime_call(RuntimeLabel.LBOUND, self.TYPE(gl.BOUND_TYPE).size)
142149

143150
def visit_UBOUND(self, node):
144-
yield node.operands[1]
151+
yield self.visit(node.operands[1])
145152
self.ic_param(gl.BOUND_TYPE, node.operands[1].t)
146153
entry = node.operands[0]
147154
if entry.scope == SCOPE.global_:

src/arch/z80/visitor/function_translator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@ class FunctionTranslator(Translator):
2424
REQUIRES = backend.REQUIRES
2525

2626
def __init__(self, backend: Backend, function_list: list[symbols.ID]):
27+
super().__init__(backend)
2728
if function_list is None:
2829
function_list = []
29-
super().__init__(backend)
3030

3131
assert isinstance(function_list, list)
3232
assert all(x.token == "FUNCTION" for x in function_list)
@@ -115,7 +115,7 @@ def visit_FUNCTION(self, node):
115115
self.ic_lvard(local_var.offset, q)
116116

117117
for i in node.ref.body:
118-
yield i
118+
yield self.visit(i)
119119

120120
self.ic_label("%s__leave" % node.mangled)
121121

0 commit comments

Comments
 (0)