Skip to content

Commit ae802f6

Browse files
committed
misc: LazyVisitor tweaks + lazy FindWithin
1 parent 045a3b6 commit ae802f6

1 file changed

Lines changed: 62 additions & 43 deletions

File tree

devito/ir/iet/visitors.py

Lines changed: 62 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
"""
66

77
from collections import OrderedDict
8-
from collections.abc import Callable, Iterable, Iterator, Sequence
8+
from collections.abc import Callable, Generator, Iterable, Iterator, Sequence
99
from itertools import chain, groupby
1010
from typing import Any, Generic, TypeVar
1111
import ctypes
@@ -59,37 +59,51 @@ def always_rebuild(self, o, *args, **kwargs):
5959
return o._rebuild(*new_ops, **okwargs)
6060

6161

62-
ResultType = TypeVar('ResultType')
62+
# Type variables for LazyVisitor
63+
YieldType = TypeVar('YieldType', covariant=True)
64+
FlagType = TypeVar('FlagType', covariant=True)
65+
ResultType = TypeVar('ResultType', covariant=True)
6366

67+
# Describes the return type of a LazyVisitor visit method which yields objects of
68+
# type YieldType and returns a FlagType (or NoneType)
69+
LazyVisit = Generator[YieldType, None, FlagType]
6470

65-
class LazyVisitor(GenericVisitor, Generic[ResultType]):
71+
72+
class LazyVisitor(GenericVisitor, Generic[YieldType, ResultType, FlagType]):
6673

6774
"""
6875
A generic visitor that lazily yields results instead of flattening results
69-
from children at every step.
76+
from children at every step. Intermediate visit methods may return a flag
77+
of type FlagType in addition to yielding results; by default, the last flag
78+
returned by a child is the one propagated.
7079
7180
Subclass-defined visit methods should be generators.
7281
"""
7382

74-
def lookup_method(self, instance) -> Callable[..., Iterator[Any]]:
83+
def lookup_method(self, instance) \
84+
-> Callable[..., LazyVisit[YieldType, FlagType]]:
7585
return super().lookup_method(instance)
7686

77-
def _visit(self, o, *args, **kwargs) -> Iterator[Any]:
87+
def _visit(self, o, *args, **kwargs) -> LazyVisit[YieldType, FlagType]:
7888
meth = self.lookup_method(o)
79-
yield from meth(o, *args, **kwargs)
89+
flag = yield from meth(o, *args, **kwargs)
90+
return flag
8091

81-
def _post_visit(self, ret: Iterator[Any]) -> ResultType:
92+
def _post_visit(self, ret: LazyVisit[YieldType, FlagType]) -> ResultType:
8293
return list(ret)
8394

84-
def visit_object(self, o: object, **kwargs) -> Iterator[Any]:
95+
def visit_object(self, o: object, **kwargs) -> LazyVisit[YieldType, FlagType]:
8596
yield from ()
8697

87-
def visit_Node(self, o: Node, **kwargs) -> Iterator[Any]:
88-
yield from self._visit(o.children, **kwargs)
98+
def visit_Node(self, o: Node, **kwargs) -> LazyVisit[YieldType, FlagType]:
99+
flag = yield from self._visit(o.children, **kwargs)
100+
return flag
89101

90-
def visit_tuple(self, o: Sequence[Any]) -> Iterator[Any]:
102+
def visit_tuple(self, o: Sequence[Any], **kwargs) -> LazyVisit[YieldType, FlagType]:
103+
flag: FlagType = None
91104
for i in o:
92-
yield from self._visit(i)
105+
flag = yield from self._visit(i, **kwargs)
106+
return flag
93107

94108
visit_list = visit_tuple
95109

@@ -1028,7 +1042,7 @@ def visit_Node(self, o, ret=None, parents=None, in_parent=False):
10281042
return ret
10291043

10301044

1031-
class FindSymbols(LazyVisitor[list[Any]]):
1045+
class FindSymbols(LazyVisitor[Any, list[Any], None]):
10321046

10331047
"""
10341048
Find symbols in an Iteration/Expression tree.
@@ -1102,7 +1116,7 @@ def visit_Operator(self, o) -> Iterator[Any]:
11021116
yield from self._visit(i)
11031117

11041118

1105-
class FindNodes(LazyVisitor[list[Node]]):
1119+
class FindNodes(LazyVisitor[Node, list[Node], None]):
11061120

11071121
"""
11081122
Find all instances of given type.
@@ -1124,65 +1138,70 @@ class FindNodes(LazyVisitor[list[Node]]):
11241138
'scope': lambda match, o: match in flatten(o.children)
11251139
}
11261140

1127-
def __init__(self, match: type, mode: str = 'type'):
1141+
def __init__(self, match: type, mode: str = 'type') -> None:
11281142
super().__init__()
11291143
self.match = match
11301144
self.rule = self.rules[mode]
11311145

1132-
def visit_Node(self, o: Node) -> Iterator[Any]:
1146+
def visit_Node(self, o: Node, **kwargs) -> Iterator[Node]:
11331147
if self.rule(self.match, o):
11341148
yield o
11351149
for i in o.children:
1136-
yield from self._visit(i)
1150+
yield from self._visit(i, **kwargs)
11371151

11381152

1139-
class FindWithin(FindNodes):
1140-
1141-
@classmethod
1142-
def default_retval(cls):
1143-
return [], False
1153+
class FindWithin(FindNodes, LazyVisitor[Node, list[Node], bool]):
11441154

11451155
"""
11461156
Like FindNodes, but given an additional parameter `within=(start, stop)`,
11471157
it starts collecting matching nodes only after `start` is found, and stops
11481158
collecting matching nodes after `stop` is found.
11491159
"""
11501160

1151-
def __init__(self, match, start, stop=None):
1161+
def __init__(self, match: type, start: Node, stop: Node | None = None) -> None:
11521162
super().__init__(match)
11531163
self.start = start
11541164
self.stop = stop
11551165

1156-
def visit(self, o, ret=None):
1157-
found, _ = self._visit(o, ret=ret)
1158-
return found
1166+
def visit_object(self, o: object, flag: bool = False) -> LazyVisit[Node, bool]:
1167+
yield from ()
1168+
return flag
11591169

1160-
def visit_Node(self, o, ret=None):
1161-
if ret is None:
1162-
ret = self.default_retval()
1163-
found, flag = ret
1170+
def visit_tuple(self, o: Sequence[Any], flag: bool = False) -> LazyVisit[Node, bool]:
1171+
for el in o:
1172+
# Yield results from visiting this element, and update the flag
1173+
flag = yield from self._visit(el, flag=flag)
1174+
1175+
return flag
1176+
1177+
visit_list = visit_tuple
11641178

1165-
if o is self.start:
1166-
flag = True
1179+
def visit_Node(self, o: Node, flag: bool = False) -> LazyVisit[Node, bool]:
1180+
flag = flag or (o is self.start)
11671181

11681182
if flag and self.rule(self.match, o):
1169-
found.append(o)
1170-
for i in o.children:
1171-
found, newflag = self._visit(i, ret=(found, flag))
1172-
if flag and not newflag:
1173-
return found, newflag
1174-
flag = newflag
1183+
yield o
11751184

1176-
if o is self.stop:
1177-
flag = False
1185+
for child in o.children:
1186+
# Yield results from this child and retrieve its flag
1187+
nflag = yield from self._visit(child, flag=flag)
11781188

1179-
return found, flag
1189+
# If we started collecting outside of here and the child found a stop,
1190+
# don't visit the rest of the children
1191+
if flag and not nflag:
1192+
return False
1193+
flag = nflag
1194+
1195+
# Update the flag if we found a stop
1196+
flag &= (o is not self.stop)
1197+
return flag
11801198

11811199

11821200
ApplicationType = TypeVar('ApplicationType')
11831201

11841202

1185-
class FindApplications(LazyVisitor[set[ApplicationType]]):
1203+
class FindApplications(LazyVisitor[ApplicationType, set[ApplicationType], None]):
1204+
11861205
"""
11871206
Find all SymPy applied functions (aka, `Application`s). The user may refine
11881207
the search by supplying a different target class.

0 commit comments

Comments
 (0)