55"""
66
77from collections import OrderedDict
8- from collections .abc import Iterable
8+ from collections .abc import Callable , Generator , Iterable , Iterator , Sequence
99from itertools import chain , groupby
10+ from typing import Any , Generic , TypeVar
1011import ctypes
1112
1213import cgen as c
@@ -58,6 +59,55 @@ def always_rebuild(self, o, *args, **kwargs):
5859 return o ._rebuild (* new_ops , ** okwargs )
5960
6061
62+ # Type variables for LazyVisitor
63+ YieldType = TypeVar ('YieldType' , covariant = True )
64+ FlagType = TypeVar ('FlagType' , covariant = True )
65+ ResultType = TypeVar ('ResultType' , covariant = True )
66+
67+ # Describes the return type of a LazyVisitor visit method which yields objects of
68+ # type YieldType and returns a FlagType (or NoneType)
69+ LazyVisit = Generator [YieldType , None , FlagType ]
70+
71+
72+ class LazyVisitor (GenericVisitor , Generic [YieldType , ResultType , FlagType ]):
73+
74+ """
75+ A generic visitor that lazily yields results instead of flattening results
76+ from children at every step. Intermediate visit methods may return a flag
77+ of type FlagType in addition to yielding results; by default, the last flag
78+ returned by a child is the one propagated.
79+
80+ Subclass-defined visit methods should be generators.
81+ """
82+
83+ def lookup_method (self , instance ) \
84+ -> Callable [..., LazyVisit [YieldType , FlagType ]]:
85+ return super ().lookup_method (instance )
86+
87+ def _visit (self , o , * args , ** kwargs ) -> LazyVisit [YieldType , FlagType ]:
88+ meth = self .lookup_method (o )
89+ flag = yield from meth (o , * args , ** kwargs )
90+ return flag
91+
92+ def _post_visit (self , ret : LazyVisit [YieldType , FlagType ]) -> ResultType :
93+ return list (ret )
94+
95+ def visit_object (self , o : object , ** kwargs ) -> LazyVisit [YieldType , FlagType ]:
96+ yield from ()
97+
98+ def visit_Node (self , o : Node , ** kwargs ) -> LazyVisit [YieldType , FlagType ]:
99+ flag = yield from self ._visit (o .children , ** kwargs )
100+ return flag
101+
102+ def visit_tuple (self , o : Sequence [Any ], ** kwargs ) -> LazyVisit [YieldType , FlagType ]:
103+ flag : FlagType = None
104+ for i in o :
105+ flag = yield from self ._visit (i , ** kwargs )
106+ return flag
107+
108+ visit_list = visit_tuple
109+
110+
61111class PrintAST (Visitor ):
62112
63113 _depth = 0
@@ -992,16 +1042,7 @@ def visit_Node(self, o, ret=None, parents=None, in_parent=False):
9921042 return ret
9931043
9941044
995- class FindSymbols (Visitor ):
996-
997- class Retval (list ):
998- def __init__ (self , * retvals ):
999- elements = filter_ordered (flatten (retvals ), key = id )
1000- super ().__init__ (elements )
1001-
1002- @classmethod
1003- def default_retval (cls ):
1004- return cls .Retval ()
1045+ class FindSymbols (LazyVisitor [Any , list [Any ], None ]):
10051046
10061047 """
10071048 Find symbols in an Iteration/Expression tree.
@@ -1020,32 +1061,32 @@ def default_retval(cls):
10201061 - `defines-aliases`: Collect all defined objects and their aliases
10211062 """
10221063
1064+ @staticmethod
10231065 def _defines_aliases (n ):
1024- retval = []
10251066 for i in n .defines :
10261067 f = i .function
10271068 if f .is_ArrayBasic :
1028- retval . extend ([ f , f .indexed ] )
1069+ yield from ( f , f .indexed )
10291070 else :
1030- retval .append (i )
1031- return tuple (retval )
1071+ yield i
10321072
1033- rules = {
1073+ RulesDict = dict [str , Callable [[Node ], Iterator [Any ]]]
1074+ rules : RulesDict = {
10341075 'symbolics' : lambda n : n .functions ,
1035- 'basics' : lambda n : [ i for i in n .expr_symbols if isinstance (i , Basic )] ,
1036- 'symbols' : lambda n : [ i for i in n .expr_symbols
1037- if isinstance (i , AbstractSymbol )] ,
1038- 'dimensions' : lambda n : [ i for i in n .expr_symbols if isinstance (i , Dimension )] ,
1039- 'indexeds' : lambda n : [ i for i in n .expr_symbols if i .is_Indexed ] ,
1040- 'indexedbases' : lambda n : [ i for i in n .expr_symbols
1041- if isinstance (i , IndexedBase )] ,
1076+ 'basics' : lambda n : ( i for i in n .expr_symbols if isinstance (i , Basic )) ,
1077+ 'symbols' : lambda n : ( i for i in n .expr_symbols
1078+ if isinstance (i , AbstractSymbol )) ,
1079+ 'dimensions' : lambda n : ( i for i in n .expr_symbols if isinstance (i , Dimension )) ,
1080+ 'indexeds' : lambda n : ( i for i in n .expr_symbols if i .is_Indexed ) ,
1081+ 'indexedbases' : lambda n : ( i for i in n .expr_symbols
1082+ if isinstance (i , IndexedBase )) ,
10421083 'writes' : lambda n : as_tuple (n .writes ),
10431084 'defines' : lambda n : as_tuple (n .defines ),
1044- 'globals' : lambda n : [ f .base for f in n .functions if f ._mem_global ] ,
1085+ 'globals' : lambda n : ( f .base for f in n .functions if f ._mem_global ) ,
10451086 'defines-aliases' : _defines_aliases
10461087 }
10471088
1048- def __init__ (self , mode = 'symbolics' ):
1089+ def __init__ (self , mode : str = 'symbolics' ) -> None :
10491090 super ().__init__ ()
10501091
10511092 modes = mode .split ('|' )
@@ -1055,33 +1096,27 @@ def __init__(self, mode='symbolics'):
10551096 self .rule = lambda n : chain (* [self .rules [mode ](n ) for mode in modes ])
10561097
10571098 def _post_visit (self , ret ):
1058- return sorted (ret , key = lambda i : str ( i ) )
1099+ return sorted (filter_ordered ( ret , key = id ), key = str )
10591100
1060- def visit_tuple (self , o ):
1061- return self .Retval (* [self ._visit (i ) for i in o ])
1062-
1063- visit_list = visit_tuple
1101+ def visit_Node (self , o : Node ) -> Iterator [Any ]:
1102+ yield from self ._visit (o .children )
1103+ yield from self .rule (o )
10641104
1065- def visit_Node (self , o ):
1066- return self .Retval (self ._visit (o .children ), self .rule (o ))
1067-
1068- def visit_ThreadedProdder (self , o ):
1105+ def visit_ThreadedProdder (self , o ) -> Iterator [Any ]:
10691106 # TODO: this handle required because ThreadedProdder suffers from the
10701107 # long-standing issue affecting all Node subclasses which rely on
10711108 # multiple inheritance
1072- return self .Retval (self ._visit (o .then_body ), self .rule (o ))
1073-
1074- def visit_Operator (self , o ):
1075- ret = self ._visit (o .body )
1076- ret .extend (flatten (self ._visit (v ) for v in o ._func_table .values ()))
1077- return self .Retval (ret , self .rule (o ))
1109+ yield from self ._visit (o .then_body )
1110+ yield from self .rule (o )
10781111
1112+ def visit_Operator (self , o ) -> Iterator [Any ]:
1113+ yield from self ._visit (o .body )
1114+ yield from self .rule (o )
1115+ for i in o ._func_table .values ():
1116+ yield from self ._visit (i )
10791117
1080- class FindNodes (Visitor ):
10811118
1082- @classmethod
1083- def default_retval (cls ):
1084- return []
1119+ class FindNodes (LazyVisitor [Node , list [Node ], None ]):
10851120
10861121 """
10871122 Find all instances of given type.
@@ -1097,126 +1132,103 @@ def default_retval(cls):
10971132 appears.
10981133 """
10991134
1100- rules = {
1135+ RulesDict = dict [str , Callable [[type , Node ], bool ]]
1136+ rules : RulesDict = {
11011137 'type' : lambda match , o : isinstance (o , match ),
11021138 'scope' : lambda match , o : match in flatten (o .children )
11031139 }
11041140
1105- def __init__ (self , match , mode = 'type' ):
1141+ def __init__ (self , match : type , mode : str = 'type' ) -> None :
11061142 super ().__init__ ()
11071143 self .match = match
11081144 self .rule = self .rules [mode ]
11091145
1110- def visit_object (self , o , ret = None ):
1111- return ret
1112-
1113- def visit_tuple (self , o , ret = None ):
1114- for i in o :
1115- ret = self ._visit (i , ret = ret )
1116- return ret
1117-
1118- visit_list = visit_tuple
1119-
1120- def visit_Node (self , o , ret = None ):
1121- if ret is None :
1122- ret = self .default_retval ()
1146+ def visit_Node (self , o : Node , ** kwargs ) -> Iterator [Node ]:
11231147 if self .rule (self .match , o ):
1124- ret . append ( o )
1148+ yield o
11251149 for i in o .children :
1126- ret = self ._visit (i , ret = ret )
1127- return ret
1150+ yield from self ._visit (i , ** kwargs )
11281151
11291152
1130- class FindWithin (FindNodes ):
1131-
1132- @classmethod
1133- def default_retval (cls ):
1134- return [], False
1153+ class FindWithin (FindNodes , LazyVisitor [Node , list [Node ], bool ]):
11351154
11361155 """
11371156 Like FindNodes, but given an additional parameter `within=(start, stop)`,
11381157 it starts collecting matching nodes only after `start` is found, and stops
11391158 collecting matching nodes after `stop` is found.
11401159 """
11411160
1142- def __init__ (self , match , start , stop = None ):
1161+ def __init__ (self , match : type , start : Node , stop : Node | None = None ) -> None :
11431162 super ().__init__ (match )
11441163 self .start = start
11451164 self .stop = stop
11461165
1147- def visit (self , o , ret = None ) :
1148- found , _ = self . _visit ( o , ret = ret )
1149- return found
1166+ def visit_object (self , o : object , flag : bool = False ) -> LazyVisit [ Node , bool ] :
1167+ yield from ( )
1168+ return flag
11501169
1151- def visit_Node (self , o , ret = None ) :
1152- if ret is None :
1153- ret = self . default_retval ()
1154- found , flag = ret
1170+ def visit_tuple (self , o : Sequence [ Any ], flag : bool = False ) -> LazyVisit [ Node , bool ] :
1171+ for el in o :
1172+ # Yield results from visiting this element, and update the flag
1173+ flag = yield from self . _visit ( el , flag = flag )
11551174
1156- if o is self .start :
1157- flag = True
1175+ return flag
1176+
1177+ visit_list = visit_tuple
1178+
1179+ def visit_Node (self , o : Node , flag : bool = False ) -> LazyVisit [Node , bool ]:
1180+ flag = flag or (o is self .start )
11581181
11591182 if flag and self .rule (self .match , o ):
1160- found .append (o )
1161- for i in o .children :
1162- found , newflag = self ._visit (i , ret = (found , flag ))
1163- if flag and not newflag :
1164- return found , newflag
1165- flag = newflag
1183+ yield o
11661184
1167- if o is self .stop :
1168- flag = False
1185+ for child in o .children :
1186+ # Yield results from this child and retrieve its flag
1187+ nflag = yield from self ._visit (child , flag = flag )
11691188
1170- return found , flag
1189+ # If we started collecting outside of here and the child found a stop,
1190+ # don't visit the rest of the children
1191+ if flag and not nflag :
1192+ return False
1193+ flag = nflag
11711194
1195+ # Update the flag if we found a stop
1196+ flag &= (o is not self .stop )
11721197
1173- class FindApplications (Visitor ):
1198+ return flag
1199+
1200+
1201+ ApplicationType = TypeVar ('ApplicationType' )
1202+
1203+
1204+ class FindApplications (LazyVisitor [ApplicationType , set [ApplicationType ], None ]):
11741205
11751206 """
11761207 Find all SymPy applied functions (aka, `Application`s). The user may refine
11771208 the search by supplying a different target class.
11781209 """
11791210
1180- def __init__ (self , cls = Application ):
1211+ def __init__ (self , cls : type [ ApplicationType ] = Application ):
11811212 super ().__init__ ()
11821213 self .match = lambda i : isinstance (i , cls ) and not isinstance (i , Basic )
11831214
1184- @classmethod
1185- def default_retval (cls ):
1186- return set ()
1187-
1188- def visit_object (self , o , ** kwargs ):
1189- return self .default_retval ()
1190-
1191- def visit_tuple (self , o , ret = None ):
1192- ret = ret or self .default_retval ()
1193- for i in o :
1194- ret .update (self ._visit (i , ret = ret ))
1195- return ret
1196-
1197- def visit_Node (self , o , ret = None ):
1198- ret = ret or self .default_retval ()
1199- for i in o .children :
1200- ret .update (self ._visit (i , ret = ret ))
1201- return ret
1215+ def _post_visit (self , ret ):
1216+ return set (ret )
12021217
1203- def visit_Expression (self , o , ** kwargs ):
1204- return o .expr .find (self .match )
1218+ def visit_Expression (self , o : Expression , ** kwargs ) -> Iterator [ ApplicationType ] :
1219+ yield from o .expr .find (self .match )
12051220
1206- def visit_Iteration (self , o , ** kwargs ):
1207- ret = self ._visit (o .children ) or self .default_retval ()
1208- ret .update (o .symbolic_min .find (self .match ))
1209- ret .update (o .symbolic_max .find (self .match ))
1210- return ret
1221+ def visit_Iteration (self , o : Iteration , ** kwargs ) -> Iterator [ApplicationType ]:
1222+ yield from self ._visit (o .children )
1223+ yield from o .symbolic_min .find (self .match )
1224+ yield from o .symbolic_max .find (self .match )
12111225
1212- def visit_Call (self , o , ** kwargs ):
1213- ret = self .default_retval ()
1226+ def visit_Call (self , o : Call , ** kwargs ) -> Iterator [ApplicationType ]:
12141227 for i in o .arguments :
12151228 try :
1216- ret . update ( i .find (self .match ) )
1229+ yield from i .find (self .match )
12171230 except (AttributeError , TypeError ):
1218- ret .update (self ._visit (i , ret = ret ))
1219- return ret
1231+ yield from self ._visit (i )
12201232
12211233
12221234class IsPerfectIteration (Visitor ):
0 commit comments