55"""
66
77from collections import OrderedDict
8- from collections .abc import Iterable
8+ from collections .abc import Callable , Iterable , Iterator , Sequence
99from itertools import chain , groupby
10+ from typing import Any , Generic , TypeVar
1011import ctypes
1112
1213import cgen as c
@@ -58,6 +59,47 @@ def always_rebuild(self, o, *args, **kwargs):
5859 return o ._rebuild (* new_ops , ** okwargs )
5960
6061
62+ ResultType = TypeVar ('ResultType' )
63+
64+
65+ class LazyVisitor (GenericVisitor , Generic [ResultType ]):
66+
67+ """
68+ A generic visitor that lazily yields results instead of flattening results
69+ from children at every step.
70+
71+ Subclass-defined visit methods (and default_retval) should be generators.
72+ """
73+
74+ @classmethod
75+ def default_retval (cls ) -> Iterator [Any ]:
76+ yield from ()
77+
78+ def lookup_method (self , instance ) -> Callable [..., Iterator [Any ]]:
79+ return super ().lookup_method (instance )
80+
81+ def _visit (self , o , * args , ** kwargs ) -> Iterator [Any ]:
82+ """Visit `o`."""
83+ meth = self .lookup_method (o )
84+ yield from meth (o , * args , ** kwargs )
85+
86+ def _post_visit (self , ret : Iterator [Any ]) -> ResultType :
87+ """Postprocess the visitor output before returning it to the caller."""
88+ return list (ret )
89+
90+ def visit_object (self , o : object , ** kwargs ) -> Iterator [Any ]:
91+ yield from self .default_retval ()
92+
93+ def visit_Node (self , o : Node , ** kwargs ) -> Iterator [Any ]:
94+ yield from self ._visit (o .children , ** kwargs )
95+
96+ def visit_tuple (self , o : Sequence [Any ]) -> Iterator [Any ]:
97+ for i in o :
98+ yield from self ._visit (i )
99+
100+ visit_list = visit_tuple
101+
102+
61103class PrintAST (Visitor ):
62104
63105 _depth = 0
@@ -992,16 +1034,7 @@ def visit_Node(self, o, ret=None, parents=None, in_parent=False):
9921034 return ret
9931035
9941036
995- class FindSymbols (Visitor ):
996-
997- class Retval (list ):
998- def __init__ (self , * retvals ):
999- elements = filter_ordered (flatten (retvals ), key = id )
1000- super ().__init__ (elements )
1001-
1002- @classmethod
1003- def default_retval (cls ):
1004- return cls .Retval ()
1037+ class FindSymbols (LazyVisitor [list [Any ]]):
10051038
10061039 """
10071040 Find symbols in an Iteration/Expression tree.
@@ -1021,31 +1054,30 @@ def default_retval(cls):
10211054 """
10221055
10231056 def _defines_aliases (n ):
1024- retval = []
10251057 for i in n .defines :
10261058 f = i .function
10271059 if f .is_ArrayBasic :
1028- retval . extend ([ f , f .indexed ] )
1060+ yield from ( f , f .indexed )
10291061 else :
1030- retval .append (i )
1031- return tuple (retval )
1062+ yield i
10321063
1033- rules = {
1064+ RulesDict = dict [str , Callable [[Node ], Iterator [Any ]]]
1065+ rules : RulesDict = {
10341066 'symbolics' : lambda n : n .functions ,
1035- 'basics' : lambda n : [ i for i in n .expr_symbols if isinstance (i , Basic )] ,
1036- 'symbols' : lambda n : [ i for i in n .expr_symbols
1037- if isinstance (i , AbstractSymbol )] ,
1038- 'dimensions' : lambda n : [ i for i in n .expr_symbols if isinstance (i , Dimension )] ,
1039- 'indexeds' : lambda n : [ i for i in n .expr_symbols if i .is_Indexed ] ,
1040- 'indexedbases' : lambda n : [ i for i in n .expr_symbols
1041- if isinstance (i , IndexedBase )] ,
1067+ 'basics' : lambda n : ( i for i in n .expr_symbols if isinstance (i , Basic )) ,
1068+ 'symbols' : lambda n : ( i for i in n .expr_symbols
1069+ if isinstance (i , AbstractSymbol )) ,
1070+ 'dimensions' : lambda n : ( i for i in n .expr_symbols if isinstance (i , Dimension )) ,
1071+ 'indexeds' : lambda n : ( i for i in n .expr_symbols if i .is_Indexed ) ,
1072+ 'indexedbases' : lambda n : ( i for i in n .expr_symbols
1073+ if isinstance (i , IndexedBase )) ,
10421074 'writes' : lambda n : as_tuple (n .writes ),
10431075 'defines' : lambda n : as_tuple (n .defines ),
1044- 'globals' : lambda n : [ f .base for f in n .functions if f ._mem_global ] ,
1076+ 'globals' : lambda n : ( f .base for f in n .functions if f ._mem_global ) ,
10451077 'defines-aliases' : _defines_aliases
10461078 }
10471079
1048- def __init__ (self , mode = 'symbolics' ):
1080+ def __init__ (self , mode : str = 'symbolics' ) -> None :
10491081 super ().__init__ ()
10501082
10511083 modes = mode .split ('|' )
@@ -1055,33 +1087,27 @@ def __init__(self, mode='symbolics'):
10551087 self .rule = lambda n : chain (* [self .rules [mode ](n ) for mode in modes ])
10561088
10571089 def _post_visit (self , ret ):
1058- return sorted (ret , key = lambda i : str ( i ) )
1090+ return sorted (filter_ordered ( ret , key = id ), key = str )
10591091
1060- def visit_tuple (self , o ):
1061- return self .Retval (* [self ._visit (i ) for i in o ])
1062-
1063- visit_list = visit_tuple
1092+ def visit_Node (self , o : Node ) -> Iterator [Any ]:
1093+ yield from self ._visit (o .children )
1094+ yield from self .rule (o )
10641095
1065- def visit_Node (self , o ):
1066- return self .Retval (self ._visit (o .children ), self .rule (o ))
1067-
1068- def visit_ThreadedProdder (self , o ):
1096+ def visit_ThreadedProdder (self , o ) -> Iterator [Any ]:
10691097 # TODO: this handle required because ThreadedProdder suffers from the
10701098 # long-standing issue affecting all Node subclasses which rely on
10711099 # multiple inheritance
1072- return self .Retval (self ._visit (o .then_body ), self .rule (o ))
1073-
1074- def visit_Operator (self , o ):
1075- ret = self ._visit (o .body )
1076- ret .extend (flatten (self ._visit (v ) for v in o ._func_table .values ()))
1077- return self .Retval (ret , self .rule (o ))
1100+ yield from self ._visit (o .then_body )
1101+ yield from self .rule (o )
10781102
1103+ def visit_Operator (self , o ) -> Iterator [Any ]:
1104+ yield from self ._visit (o .body )
1105+ yield from self .rule (o )
1106+ for i in o ._func_table .values ():
1107+ yield from self ._visit (i )
10791108
1080- class FindNodes (Visitor ):
10811109
1082- @classmethod
1083- def default_retval (cls ):
1084- return []
1110+ class FindNodes (LazyVisitor [list [Node ]]):
10851111
10861112 """
10871113 Find all instances of given type.
@@ -1097,34 +1123,22 @@ def default_retval(cls):
10971123 appears.
10981124 """
10991125
1100- rules = {
1126+ RulesDict = dict [str , Callable [[type , Node ], bool ]]
1127+ rules : RulesDict = {
11011128 'type' : lambda match , o : isinstance (o , match ),
11021129 'scope' : lambda match , o : match in flatten (o .children )
11031130 }
11041131
1105- def __init__ (self , match , mode = 'type' ):
1132+ def __init__ (self , match : type , mode : str = 'type' ):
11061133 super ().__init__ ()
11071134 self .match = match
11081135 self .rule = self .rules [mode ]
11091136
1110- def visit_object (self , o , ret = None ):
1111- return ret
1112-
1113- def visit_tuple (self , o , ret = None ):
1114- for i in o :
1115- ret = self ._visit (i , ret = ret )
1116- return ret
1117-
1118- visit_list = visit_tuple
1119-
1120- def visit_Node (self , o , ret = None ):
1121- if ret is None :
1122- ret = self .default_retval ()
1137+ def visit_Node (self , o : Node ) -> Iterator [Any ]:
11231138 if self .rule (self .match , o ):
1124- ret . append ( o )
1139+ yield o
11251140 for i in o .children :
1126- ret = self ._visit (i , ret = ret )
1127- return ret
1141+ yield from self ._visit (i )
11281142
11291143
11301144class FindWithin (FindNodes ):
@@ -1170,53 +1184,36 @@ def visit_Node(self, o, ret=None):
11701184 return found , flag
11711185
11721186
1173- class FindApplications (Visitor ):
1187+ ApplicationType = TypeVar ('ApplicationType' )
1188+
11741189
1190+ class FindApplications (LazyVisitor [set [ApplicationType ]]):
11751191 """
11761192 Find all SymPy applied functions (aka, `Application`s). The user may refine
11771193 the search by supplying a different target class.
11781194 """
11791195
1180- def __init__ (self , cls = Application ):
1196+ def __init__ (self , cls : type [ ApplicationType ] = Application ):
11811197 super ().__init__ ()
11821198 self .match = lambda i : isinstance (i , cls ) and not isinstance (i , Basic )
11831199
1184- @classmethod
1185- def default_retval (cls ):
1186- return set ()
1187-
1188- def visit_object (self , o , ** kwargs ):
1189- return self .default_retval ()
1190-
1191- def visit_tuple (self , o , ret = None ):
1192- ret = ret or self .default_retval ()
1193- for i in o :
1194- ret .update (self ._visit (i , ret = ret ))
1195- return ret
1196-
1197- def visit_Node (self , o , ret = None ):
1198- ret = ret or self .default_retval ()
1199- for i in o .children :
1200- ret .update (self ._visit (i , ret = ret ))
1201- return ret
1200+ def _post_visit (self , ret ):
1201+ return set (ret )
12021202
1203- def visit_Expression (self , o , ** kwargs ):
1204- return o .expr .find (self .match )
1203+ def visit_Expression (self , o : Expression , ** kwargs ) -> Iterator [ ApplicationType ] :
1204+ yield from o .expr .find (self .match )
12051205
1206- def visit_Iteration (self , o , ** kwargs ):
1207- ret = self ._visit (o .children ) or self .default_retval ()
1208- ret .update (o .symbolic_min .find (self .match ))
1209- ret .update (o .symbolic_max .find (self .match ))
1210- return ret
1206+ def visit_Iteration (self , o : Iteration , ** kwargs ) -> Iterator [ApplicationType ]:
1207+ yield from self ._visit (o .children )
1208+ yield from o .symbolic_min .find (self .match )
1209+ yield from o .symbolic_max .find (self .match )
12111210
1212- def visit_Call (self , o , ** kwargs ):
1213- ret = self .default_retval ()
1211+ def visit_Call (self , o : Call , ** kwargs ) -> Iterator [ApplicationType ]:
12141212 for i in o .arguments :
12151213 try :
1216- ret . update ( i .find (self .match ) )
1214+ yield from i .find (self .match )
12171215 except (AttributeError , TypeError ):
1218- ret .update (self ._visit (i , ret = ret ))
1219- return ret
1216+ yield from self ._visit (i )
12201217
12211218
12221219class IsPerfectIteration (Visitor ):
0 commit comments