66# --------------------------------------------------------------------
77
88import symtable
9- from collections .abc import Generator
9+ from collections .abc import Callable , Generator
1010from typing import Any , NamedTuple
1111
1212import src .api .check as chk
1717from src .api import errmsg
1818from src .api .config import OPTIONS
1919from src .api .constants import CLASS , CONVENTION , SCOPE , TYPE
20- from src .api .debug import __DEBUG__
2120from src .api .errmsg import warning_not_used
2221from src .ast import Ast , NodeVisitor
2322from src .symbols import sym as symbols
2423from src .symbols .id_ import ref
2524
2625
27- class ToVisit (NamedTuple ):
28- """Used just to signal an object to be
29- traversed.
30- """
31-
32- obj : symbols .SYMBOL
33-
34-
3526class GenericVisitor (NodeVisitor ):
3627 """A slightly different visitor, that just traverses an AST, but does not return
3728 a translation of it. Used to examine the AST or do transformations
3829 """
3930
40- node_type = ToVisit
41-
4231 @property
4332 def O_LEVEL (self ):
4433 return OPTIONS .optimization_level
@@ -58,36 +47,42 @@ def TYPE(type_):
5847 assert TYPE .is_valid (type_ )
5948 return gl .SYMBOL_TABLE .basic_types [type_ ]
6049
61- def visit (self , node ):
62- return super ().visit (ToVisit (node ))
63-
64- def _visit (self , node : ToVisit ):
65- if node .obj is None :
66- return None
67-
68- __DEBUG__ (f"Optimizer: Visiting node { node .obj !s} [{ node .obj .token } ]" , 1 )
69- meth = getattr (self , f"visit_{ node .obj .token } " , self .generic_visit )
70- return meth (node .obj )
71-
72- def generic_visit (self , node : Ast ) -> Generator [Ast | None , Any , None ]:
73- for i , child in enumerate (node .children ):
74- node .children [i ] = yield self .visit (child )
75-
76- yield node
77-
7850
7951class UniqueVisitor (GenericVisitor ):
8052 def __init__ (self ):
8153 super ().__init__ ()
8254 self .visited = set ()
8355
84- def _visit (self , node : ToVisit ):
85- if node . obj in self .visited :
86- return node . obj
56+ def _visit (self , node : Ast ):
57+ if node in self .visited :
58+ return node
8759
88- self .visited .add (node . obj )
60+ self .visited .add (node )
8961 return super ()._visit (node )
9062
63+ def filter_inorder (
64+ self ,
65+ node ,
66+ filter_func : Callable [[Any ], bool ],
67+ child_selector : Callable [[Ast ], bool ] = lambda x : True ,
68+ ) -> Generator [Ast , None , None ]:
69+ """Visit the tree inorder, but only those that return true for filter_func and visiting children which
70+ return true for child_selector.
71+ """
72+ visited = set ()
73+ stack = [node ]
74+ while stack :
75+ node = stack .pop ()
76+ if node in visited :
77+ continue
78+
79+ visited .add (node )
80+ if filter_func (node ):
81+ yield self .visit (node )
82+
83+ if isinstance (node , Ast ) and child_selector (node ):
84+ stack .extend (node .children [::- 1 ])
85+
9186
9287class UnreachableCodeVisitor (UniqueVisitor ):
9388 """Visitor to optimize unreachable code (and prune it)."""
@@ -107,7 +102,7 @@ def visit_FUNCTION(self, node: symbols.ID):
107102 if type_ is not None and type_ == self .TYPE (TYPE .string ):
108103 node .body .append (symbols .ASM ("\n ld hl, 0\n " , lineno , node .filename , is_sentinel = True ))
109104
110- yield ( yield self .generic_visit (node ) )
105+ yield self .generic_visit (node )
111106
112107 def visit_BLOCK (self , node ):
113108 # Remove CHKBREAK after labels
@@ -155,7 +150,7 @@ def visit_BLOCK(self, node):
155150 yield self .NOP
156151 return
157152
158- yield ( yield self .generic_visit (node ) )
153+ yield self .generic_visit (node )
159154
160155
161156class FunctionGraphVisitor (UniqueVisitor ):
@@ -165,6 +160,7 @@ def _get_calls_from_children(self, node: symtable.Symbol):
165160 return list (self .filter_inorder (node , lambda x : x .token in ("CALL" , "FUNCCALL" )))
166161
167162 def _set_children_as_accessed (self , node : symbols .SYMBOL ):
163+ """ "Traverse only those"""
168164 parent = node .get_parent (symbols .FUNCDECL )
169165 if parent is None : # Global scope?
170166 for symbol in self ._get_calls_from_children (node ):
@@ -314,7 +310,7 @@ def visit_FUNCDECL(self, node):
314310 if self .O_LEVEL > 1 and node .params_size == node .locals_size == 0 :
315311 node .entry .ref .convention = CONVENTION .fastcall
316312
317- node .children [1 ] = yield ToVisit (node .entry )
313+ node .children [1 ] = yield self . visit (node .entry )
318314 yield node
319315
320316 def visit_LET (self , node ):
@@ -370,19 +366,20 @@ def visit_RETURN(self, node):
370366 might cause infinite recursion.
371367 """
372368 if len (node .children ) == 2 :
373- node .children [1 ] = yield ToVisit (node .children [1 ])
369+ node .children [1 ] = yield self .visit (node .children [1 ])
370+
374371 yield node
375372
376373 def visit_UNARY (self , node ):
377374 if node .operator == "ADDRESS" :
378- yield ( yield self .visit_ADDRESS (node ) )
375+ yield self .visit_ADDRESS (node )
379376 else :
380- yield ( yield self .generic_visit (node ) )
377+ yield self .generic_visit (node )
381378
382379 def visit_IF (self , node ):
383- expr_ = yield ToVisit (node .children [0 ])
384- then_ = yield ToVisit (node .children [1 ])
385- else_ = (yield ToVisit (node .children [2 ])) if len (node .children ) == 3 else self .NOP
380+ expr_ = yield self . visit (node .children [0 ])
381+ then_ = yield self . visit (node .children [1 ])
382+ else_ = (yield self . visit (node .children [2 ])) if len (node .children ) == 3 else self .NOP
386383
387384 if self .O_LEVEL >= 1 :
388385 if chk .is_null (then_ , else_ ):
@@ -405,6 +402,7 @@ def visit_IF(self, node):
405402
406403 for i in range (len (node .children )):
407404 node .children [i ] = (expr_ , then_ , else_ )[i ]
405+
408406 yield node
409407
410408 def visit_WHILE (self , node ):
@@ -419,6 +417,7 @@ def visit_WHILE(self, node):
419417
420418 for i , child in enumerate ((expr_ , body_ )):
421419 node .children [i ] = child
420+
422421 yield node
423422
424423 def visit_FOR (self , node ):
@@ -433,6 +432,7 @@ def visit_FOR(self, node):
433432 if from_ .value > to_ .value and step_ .value > 0 :
434433 yield self .NOP
435434 return
435+
436436 if from_ .value < to_ .value and step_ .value < 0 :
437437 yield self .NOP
438438 return
@@ -446,12 +446,6 @@ def _visit_LABEL(self, node):
446446 else :
447447 yield node
448448
449- def generic_visit (self , node : Ast ):
450- for i , child in enumerate (node .children ):
451- node .children [i ] = yield ToVisit (child )
452-
453- yield node
454-
455449 def _check_if_any_arg_is_an_array_and_needs_lbound_or_ubound (
456450 self , params : symbols .PARAMLIST , args : symbols .ARGLIST
457451 ):
@@ -502,10 +496,7 @@ class VariableVisitor(GenericVisitor):
502496 def generic_visit (self , node : Ast ):
503497 if node not in VariableVisitor ._visited :
504498 VariableVisitor ._visited .add (node )
505- for i in range (len (node .children )):
506- node .children [i ] = yield ToVisit (node .children [i ])
507-
508- yield node
499+ yield super ().generic_visit (node )
509500
510501 def has_circular_dependency (self , var_dependency : VarDependency ) -> bool :
511502 if var_dependency .dependency == VariableVisitor ._original_variable :
@@ -532,7 +523,7 @@ def visit_var(entry):
532523 if entry .token != "VAR" :
533524 for child in entry .children :
534525 visit_var (child )
535- if child .token in ( "FUNCTION" , "LABEL" , "VAR" , "VARARRAY" ) :
526+ if child .token in { "FUNCTION" , "LABEL" , "VAR" , "VARARRAY" } :
536527 result .add (VarDependency (parent = VariableVisitor ._parent_variable , dependency = child ))
537528 return
538529
0 commit comments