Skip to content

Commit ac5097e

Browse files
authored
Merge pull request #2621 from enwask/lazy-visitors
compiler: Lazy IET visitors + Search
2 parents 77ffae7 + c67bb40 commit ac5097e

2 files changed

Lines changed: 192 additions & 196 deletions

File tree

devito/ir/iet/visitors.py

Lines changed: 130 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
"""
66

77
from collections import OrderedDict
8-
from collections.abc import Iterable
8+
from collections.abc import Callable, Generator, Iterable, Iterator, Sequence
99
from itertools import chain, groupby
10+
from typing import Any, Generic, TypeVar
1011
import ctypes
1112

1213
import cgen as c
@@ -58,6 +59,55 @@ def always_rebuild(self, o, *args, **kwargs):
5859
return o._rebuild(*new_ops, **okwargs)
5960

6061

62+
# Type variables for LazyVisitor
63+
YieldType = TypeVar('YieldType', covariant=True)
64+
FlagType = TypeVar('FlagType', covariant=True)
65+
ResultType = TypeVar('ResultType', covariant=True)
66+
67+
# Describes the return type of a LazyVisitor visit method which yields objects of
68+
# type YieldType and returns a FlagType (or NoneType)
69+
LazyVisit = Generator[YieldType, None, FlagType]
70+
71+
72+
class LazyVisitor(GenericVisitor, Generic[YieldType, ResultType, FlagType]):
73+
74+
"""
75+
A generic visitor that lazily yields results instead of flattening results
76+
from children at every step. Intermediate visit methods may return a flag
77+
of type FlagType in addition to yielding results; by default, the last flag
78+
returned by a child is the one propagated.
79+
80+
Subclass-defined visit methods should be generators.
81+
"""
82+
83+
def lookup_method(self, instance) \
84+
-> Callable[..., LazyVisit[YieldType, FlagType]]:
85+
return super().lookup_method(instance)
86+
87+
def _visit(self, o, *args, **kwargs) -> LazyVisit[YieldType, FlagType]:
88+
meth = self.lookup_method(o)
89+
flag = yield from meth(o, *args, **kwargs)
90+
return flag
91+
92+
def _post_visit(self, ret: LazyVisit[YieldType, FlagType]) -> ResultType:
93+
return list(ret)
94+
95+
def visit_object(self, o: object, **kwargs) -> LazyVisit[YieldType, FlagType]:
96+
yield from ()
97+
98+
def visit_Node(self, o: Node, **kwargs) -> LazyVisit[YieldType, FlagType]:
99+
flag = yield from self._visit(o.children, **kwargs)
100+
return flag
101+
102+
def visit_tuple(self, o: Sequence[Any], **kwargs) -> LazyVisit[YieldType, FlagType]:
103+
flag: FlagType = None
104+
for i in o:
105+
flag = yield from self._visit(i, **kwargs)
106+
return flag
107+
108+
visit_list = visit_tuple
109+
110+
61111
class PrintAST(Visitor):
62112

63113
_depth = 0
@@ -992,16 +1042,7 @@ def visit_Node(self, o, ret=None, parents=None, in_parent=False):
9921042
return ret
9931043

9941044

995-
class FindSymbols(Visitor):
996-
997-
class Retval(list):
998-
def __init__(self, *retvals):
999-
elements = filter_ordered(flatten(retvals), key=id)
1000-
super().__init__(elements)
1001-
1002-
@classmethod
1003-
def default_retval(cls):
1004-
return cls.Retval()
1045+
class FindSymbols(LazyVisitor[Any, list[Any], None]):
10051046

10061047
"""
10071048
Find symbols in an Iteration/Expression tree.
@@ -1020,32 +1061,32 @@ def default_retval(cls):
10201061
- `defines-aliases`: Collect all defined objects and their aliases
10211062
"""
10221063

1064+
@staticmethod
10231065
def _defines_aliases(n):
1024-
retval = []
10251066
for i in n.defines:
10261067
f = i.function
10271068
if f.is_ArrayBasic:
1028-
retval.extend([f, f.indexed])
1069+
yield from (f, f.indexed)
10291070
else:
1030-
retval.append(i)
1031-
return tuple(retval)
1071+
yield i
10321072

1033-
rules = {
1073+
RulesDict = dict[str, Callable[[Node], Iterator[Any]]]
1074+
rules: RulesDict = {
10341075
'symbolics': lambda n: n.functions,
1035-
'basics': lambda n: [i for i in n.expr_symbols if isinstance(i, Basic)],
1036-
'symbols': lambda n: [i for i in n.expr_symbols
1037-
if isinstance(i, AbstractSymbol)],
1038-
'dimensions': lambda n: [i for i in n.expr_symbols if isinstance(i, Dimension)],
1039-
'indexeds': lambda n: [i for i in n.expr_symbols if i.is_Indexed],
1040-
'indexedbases': lambda n: [i for i in n.expr_symbols
1041-
if isinstance(i, IndexedBase)],
1076+
'basics': lambda n: (i for i in n.expr_symbols if isinstance(i, Basic)),
1077+
'symbols': lambda n: (i for i in n.expr_symbols
1078+
if isinstance(i, AbstractSymbol)),
1079+
'dimensions': lambda n: (i for i in n.expr_symbols if isinstance(i, Dimension)),
1080+
'indexeds': lambda n: (i for i in n.expr_symbols if i.is_Indexed),
1081+
'indexedbases': lambda n: (i for i in n.expr_symbols
1082+
if isinstance(i, IndexedBase)),
10421083
'writes': lambda n: as_tuple(n.writes),
10431084
'defines': lambda n: as_tuple(n.defines),
1044-
'globals': lambda n: [f.base for f in n.functions if f._mem_global],
1085+
'globals': lambda n: (f.base for f in n.functions if f._mem_global),
10451086
'defines-aliases': _defines_aliases
10461087
}
10471088

1048-
def __init__(self, mode='symbolics'):
1089+
def __init__(self, mode: str = 'symbolics') -> None:
10491090
super().__init__()
10501091

10511092
modes = mode.split('|')
@@ -1055,33 +1096,27 @@ def __init__(self, mode='symbolics'):
10551096
self.rule = lambda n: chain(*[self.rules[mode](n) for mode in modes])
10561097

10571098
def _post_visit(self, ret):
1058-
return sorted(ret, key=lambda i: str(i))
1099+
return sorted(filter_ordered(ret, key=id), key=str)
10591100

1060-
def visit_tuple(self, o):
1061-
return self.Retval(*[self._visit(i) for i in o])
1062-
1063-
visit_list = visit_tuple
1101+
def visit_Node(self, o: Node) -> Iterator[Any]:
1102+
yield from self._visit(o.children)
1103+
yield from self.rule(o)
10641104

1065-
def visit_Node(self, o):
1066-
return self.Retval(self._visit(o.children), self.rule(o))
1067-
1068-
def visit_ThreadedProdder(self, o):
1105+
def visit_ThreadedProdder(self, o) -> Iterator[Any]:
10691106
# TODO: this handler is required because ThreadedProdder suffers from the
10701107
# long-standing issue affecting all Node subclasses which rely on
10711108
# multiple inheritance
1072-
return self.Retval(self._visit(o.then_body), self.rule(o))
1073-
1074-
def visit_Operator(self, o):
1075-
ret = self._visit(o.body)
1076-
ret.extend(flatten(self._visit(v) for v in o._func_table.values()))
1077-
return self.Retval(ret, self.rule(o))
1109+
yield from self._visit(o.then_body)
1110+
yield from self.rule(o)
10781111

1112+
def visit_Operator(self, o) -> Iterator[Any]:
1113+
yield from self._visit(o.body)
1114+
yield from self.rule(o)
1115+
for i in o._func_table.values():
1116+
yield from self._visit(i)
10791117

1080-
class FindNodes(Visitor):
10811118

1082-
@classmethod
1083-
def default_retval(cls):
1084-
return []
1119+
class FindNodes(LazyVisitor[Node, list[Node], None]):
10851120

10861121
"""
10871122
Find all instances of given type.
@@ -1097,126 +1132,103 @@ def default_retval(cls):
10971132
appears.
10981133
"""
10991134

1100-
rules = {
1135+
RulesDict = dict[str, Callable[[type, Node], bool]]
1136+
rules: RulesDict = {
11011137
'type': lambda match, o: isinstance(o, match),
11021138
'scope': lambda match, o: match in flatten(o.children)
11031139
}
11041140

1105-
def __init__(self, match, mode='type'):
1141+
def __init__(self, match: type, mode: str = 'type') -> None:
11061142
super().__init__()
11071143
self.match = match
11081144
self.rule = self.rules[mode]
11091145

1110-
def visit_object(self, o, ret=None):
1111-
return ret
1112-
1113-
def visit_tuple(self, o, ret=None):
1114-
for i in o:
1115-
ret = self._visit(i, ret=ret)
1116-
return ret
1117-
1118-
visit_list = visit_tuple
1119-
1120-
def visit_Node(self, o, ret=None):
1121-
if ret is None:
1122-
ret = self.default_retval()
1146+
def visit_Node(self, o: Node, **kwargs) -> Iterator[Node]:
11231147
if self.rule(self.match, o):
1124-
ret.append(o)
1148+
yield o
11251149
for i in o.children:
1126-
ret = self._visit(i, ret=ret)
1127-
return ret
1150+
yield from self._visit(i, **kwargs)
11281151

11291152

1130-
class FindWithin(FindNodes):
1131-
1132-
@classmethod
1133-
def default_retval(cls):
1134-
return [], False
1153+
class FindWithin(FindNodes, LazyVisitor[Node, list[Node], bool]):
11351154

11361155
"""
11371156
Like FindNodes, but given the additional parameters `start` and `stop`,
11381157
it starts collecting matching nodes only after `start` is found, and stops
11391158
collecting matching nodes after `stop` is found.
11401159
"""
11411160

1142-
def __init__(self, match, start, stop=None):
1161+
def __init__(self, match: type, start: Node, stop: Node | None = None) -> None:
11431162
super().__init__(match)
11441163
self.start = start
11451164
self.stop = stop
11461165

1147-
def visit(self, o, ret=None):
1148-
found, _ = self._visit(o, ret=ret)
1149-
return found
1166+
def visit_object(self, o: object, flag: bool = False) -> LazyVisit[Node, bool]:
1167+
yield from ()
1168+
return flag
11501169

1151-
def visit_Node(self, o, ret=None):
1152-
if ret is None:
1153-
ret = self.default_retval()
1154-
found, flag = ret
1170+
def visit_tuple(self, o: Sequence[Any], flag: bool = False) -> LazyVisit[Node, bool]:
1171+
for el in o:
1172+
# Yield results from visiting this element, and update the flag
1173+
flag = yield from self._visit(el, flag=flag)
11551174

1156-
if o is self.start:
1157-
flag = True
1175+
return flag
1176+
1177+
visit_list = visit_tuple
1178+
1179+
def visit_Node(self, o: Node, flag: bool = False) -> LazyVisit[Node, bool]:
1180+
flag = flag or (o is self.start)
11581181

11591182
if flag and self.rule(self.match, o):
1160-
found.append(o)
1161-
for i in o.children:
1162-
found, newflag = self._visit(i, ret=(found, flag))
1163-
if flag and not newflag:
1164-
return found, newflag
1165-
flag = newflag
1183+
yield o
11661184

1167-
if o is self.stop:
1168-
flag = False
1185+
for child in o.children:
1186+
# Yield results from this child and retrieve its flag
1187+
nflag = yield from self._visit(child, flag=flag)
11691188

1170-
return found, flag
1189+
# If we started collecting outside of here and the child found a stop,
1190+
# don't visit the rest of the children
1191+
if flag and not nflag:
1192+
return False
1193+
flag = nflag
11711194

1195+
# Update the flag if we found a stop
1196+
flag &= (o is not self.stop)
11721197

1173-
class FindApplications(Visitor):
1198+
return flag
1199+
1200+
1201+
ApplicationType = TypeVar('ApplicationType')
1202+
1203+
1204+
class FindApplications(LazyVisitor[ApplicationType, set[ApplicationType], None]):
11741205

11751206
"""
11761207
Find all SymPy applied functions (aka, `Application`s). The user may refine
11771208
the search by supplying a different target class.
11781209
"""
11791210

1180-
def __init__(self, cls=Application):
1211+
def __init__(self, cls: type[ApplicationType] = Application):
11811212
super().__init__()
11821213
self.match = lambda i: isinstance(i, cls) and not isinstance(i, Basic)
11831214

1184-
@classmethod
1185-
def default_retval(cls):
1186-
return set()
1187-
1188-
def visit_object(self, o, **kwargs):
1189-
return self.default_retval()
1190-
1191-
def visit_tuple(self, o, ret=None):
1192-
ret = ret or self.default_retval()
1193-
for i in o:
1194-
ret.update(self._visit(i, ret=ret))
1195-
return ret
1196-
1197-
def visit_Node(self, o, ret=None):
1198-
ret = ret or self.default_retval()
1199-
for i in o.children:
1200-
ret.update(self._visit(i, ret=ret))
1201-
return ret
1215+
def _post_visit(self, ret):
1216+
return set(ret)
12021217

1203-
def visit_Expression(self, o, **kwargs):
1204-
return o.expr.find(self.match)
1218+
def visit_Expression(self, o: Expression, **kwargs) -> Iterator[ApplicationType]:
1219+
yield from o.expr.find(self.match)
12051220

1206-
def visit_Iteration(self, o, **kwargs):
1207-
ret = self._visit(o.children) or self.default_retval()
1208-
ret.update(o.symbolic_min.find(self.match))
1209-
ret.update(o.symbolic_max.find(self.match))
1210-
return ret
1221+
def visit_Iteration(self, o: Iteration, **kwargs) -> Iterator[ApplicationType]:
1222+
yield from self._visit(o.children)
1223+
yield from o.symbolic_min.find(self.match)
1224+
yield from o.symbolic_max.find(self.match)
12111225

1212-
def visit_Call(self, o, **kwargs):
1213-
ret = self.default_retval()
1226+
def visit_Call(self, o: Call, **kwargs) -> Iterator[ApplicationType]:
12141227
for i in o.arguments:
12151228
try:
1216-
ret.update(i.find(self.match))
1229+
yield from i.find(self.match)
12171230
except (AttributeError, TypeError):
1218-
ret.update(self._visit(i, ret=ret))
1219-
return ret
1231+
yield from self._visit(i)
12201232

12211233

12221234
class IsPerfectIteration(Visitor):

0 commit comments

Comments
 (0)