55"""
66
77from collections import OrderedDict
8- from collections .abc import Callable , Iterable , Iterator , Sequence
8+ from collections .abc import Callable , Generator , Iterable , Iterator , Sequence
99from itertools import chain , groupby
1010from typing import Any , Generic , TypeVar
1111import ctypes
@@ -59,37 +59,51 @@ def always_rebuild(self, o, *args, **kwargs):
5959 return o ._rebuild (* new_ops , ** okwargs )
6060
6161
62- ResultType = TypeVar ('ResultType' )
62+ # Type variables for LazyVisitor
63+ YieldType = TypeVar ('YieldType' , covariant = True )
64+ FlagType = TypeVar ('FlagType' , covariant = True )
65+ ResultType = TypeVar ('ResultType' , covariant = True )
6366
67+ # Describes the return type of a LazyVisitor visit method which yields objects of
68+ # type YieldType and returns a FlagType (or NoneType)
69+ LazyVisit = Generator [YieldType , None , FlagType ]
6470
65- class LazyVisitor (GenericVisitor , Generic [ResultType ]):
71+
72+ class LazyVisitor (GenericVisitor , Generic [YieldType , ResultType , FlagType ]):
6673
6774 """
6875 A generic visitor that lazily yields results instead of flattening results
69- from children at every step.
76+ from children at every step. Intermediate visit methods may return a flag
77+ of type FlagType in addition to yielding results; by default, the last flag
78+ returned by a child is the one propagated.
7079
7180 Subclass-defined visit methods should be generators.
7281 """
7382
74- def lookup_method (self , instance ) -> Callable [..., Iterator [Any ]]:
83+ def lookup_method (self , instance ) \
84+ -> Callable [..., LazyVisit [YieldType , FlagType ]]:
7585 return super ().lookup_method (instance )
7686
77- def _visit (self , o , * args , ** kwargs ) -> Iterator [ Any ]:
87+ def _visit (self , o , * args , ** kwargs ) -> LazyVisit [ YieldType , FlagType ]:
7888 meth = self .lookup_method (o )
79- yield from meth (o , * args , ** kwargs )
89+ flag = yield from meth (o , * args , ** kwargs )
90+ return flag
8091
81- def _post_visit (self , ret : Iterator [ Any ]) -> ResultType :
92+ def _post_visit (self , ret : LazyVisit [ YieldType , FlagType ]) -> ResultType :
8293 return list (ret )
8394
84- def visit_object (self , o : object , ** kwargs ) -> Iterator [ Any ]:
95+ def visit_object (self , o : object , ** kwargs ) -> LazyVisit [ YieldType , FlagType ]:
8596 yield from ()
8697
87- def visit_Node (self , o : Node , ** kwargs ) -> Iterator [Any ]:
88- yield from self ._visit (o .children , ** kwargs )
98+ def visit_Node (self , o : Node , ** kwargs ) -> LazyVisit [YieldType , FlagType ]:
99+ flag = yield from self ._visit (o .children , ** kwargs )
100+ return flag
89101
90- def visit_tuple (self , o : Sequence [Any ]) -> Iterator [Any ]:
102+ def visit_tuple (self , o : Sequence [Any ], ** kwargs ) -> LazyVisit [YieldType , FlagType ]:
103+ flag : FlagType = None
91104 for i in o :
92- yield from self ._visit (i )
105+ flag = yield from self ._visit (i , ** kwargs )
106+ return flag
93107
94108 visit_list = visit_tuple
95109
@@ -1028,7 +1042,7 @@ def visit_Node(self, o, ret=None, parents=None, in_parent=False):
10281042 return ret
10291043
10301044
1031- class FindSymbols (LazyVisitor [list [Any ]]):
1045+ class FindSymbols (LazyVisitor [Any , list [Any ], None ]):
10321046
10331047 """
10341048 Find symbols in an Iteration/Expression tree.
@@ -1102,7 +1116,7 @@ def visit_Operator(self, o) -> Iterator[Any]:
11021116 yield from self ._visit (i )
11031117
11041118
1105- class FindNodes (LazyVisitor [list [Node ]]):
1119+ class FindNodes (LazyVisitor [Node , list [Node ], None ]):
11061120
11071121 """
11081122 Find all instances of given type.
@@ -1124,65 +1138,70 @@ class FindNodes(LazyVisitor[list[Node]]):
11241138 'scope' : lambda match , o : match in flatten (o .children )
11251139 }
11261140
1127- def __init__ (self , match : type , mode : str = 'type' ):
1141+ def __init__ (self , match : type , mode : str = 'type' ) -> None :
11281142 super ().__init__ ()
11291143 self .match = match
11301144 self .rule = self .rules [mode ]
11311145
1132- def visit_Node (self , o : Node ) -> Iterator [Any ]:
1146+ def visit_Node (self , o : Node , ** kwargs ) -> Iterator [Node ]:
11331147 if self .rule (self .match , o ):
11341148 yield o
11351149 for i in o .children :
1136- yield from self ._visit (i )
1150+ yield from self ._visit (i , ** kwargs )
11371151
11381152
1139- class FindWithin (FindNodes ):
1140-
1141- @classmethod
1142- def default_retval (cls ):
1143- return [], False
1153+ class FindWithin (FindNodes , LazyVisitor [Node , list [Node ], bool ]):
11441154
11451155 """
11461156 Like FindNodes, but given an additional parameter `within=(start, stop)`,
11471157 it starts collecting matching nodes only after `start` is found, and stops
11481158 collecting matching nodes after `stop` is found.
11491159 """
11501160
1151- def __init__ (self , match , start , stop = None ):
1161+ def __init__ (self , match : type , start : Node , stop : Node | None = None ) -> None :
11521162 super ().__init__ (match )
11531163 self .start = start
11541164 self .stop = stop
11551165
1156- def visit (self , o , ret = None ) :
1157- found , _ = self . _visit ( o , ret = ret )
1158- return found
1166+ def visit_object (self , o : object , flag : bool = False ) -> LazyVisit [ Node , bool ] :
1167+ yield from ( )
1168+ return flag
11591169
1160- def visit_Node (self , o , ret = None ):
1161- if ret is None :
1162- ret = self .default_retval ()
1163- found , flag = ret
1170+ def visit_tuple (self , o : Sequence [Any ], flag : bool = False ) -> LazyVisit [Node , bool ]:
1171+ for el in o :
1172+ # Yield results from visiting this element, and update the flag
1173+ flag = yield from self ._visit (el , flag = flag )
1174+
1175+ return flag
1176+
1177+ visit_list = visit_tuple
11641178
1165- if o is self . start :
1166- flag = True
1179+ def visit_Node ( self , o : Node , flag : bool = False ) -> LazyVisit [ Node , bool ] :
1180+ flag = flag or ( o is self . start )
11671181
11681182 if flag and self .rule (self .match , o ):
1169- found .append (o )
1170- for i in o .children :
1171- found , newflag = self ._visit (i , ret = (found , flag ))
1172- if flag and not newflag :
1173- return found , newflag
1174- flag = newflag
1183+ yield o
11751184
1176- if o is self .stop :
1177- flag = False
1185+ for child in o .children :
1186+ # Yield results from this child and retrieve its flag
1187+ nflag = yield from self ._visit (child , flag = flag )
11781188
1179- return found , flag
1189+ # If we started collecting outside of here and the child found a stop,
1190+ # don't visit the rest of the children
1191+ if flag and not nflag :
1192+ return False
1193+ flag = nflag
1194+
1195+ # Update the flag if we found a stop
1196+ flag &= (o is not self .stop )
1197+ return flag
11801198
11811199
11821200ApplicationType = TypeVar ('ApplicationType' )
11831201
11841202
1185- class FindApplications (LazyVisitor [set [ApplicationType ]]):
1203+ class FindApplications (LazyVisitor [ApplicationType , set [ApplicationType ], None ]):
1204+
11861205 """
11871206 Find all SymPy applied functions (aka, `Application`s). The user may refine
11881207 the search by supplying a different target class.
0 commit comments