Skip to content

Commit fed3537

Browse files
committed
compiler: Add IET LazyVisitors
1 parent 77ffae7 commit fed3537

1 file changed

Lines changed: 91 additions & 94 deletions

File tree

devito/ir/iet/visitors.py

Lines changed: 91 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
"""
66

77
from collections import OrderedDict
8-
from collections.abc import Iterable
8+
from collections.abc import Callable, Iterable, Iterator, Sequence
99
from itertools import chain, groupby
10+
from typing import Any, Generic, TypeVar
1011
import ctypes
1112

1213
import cgen as c
@@ -58,6 +59,47 @@ def always_rebuild(self, o, *args, **kwargs):
5859
return o._rebuild(*new_ops, **okwargs)
5960

6061

62+
ResultType = TypeVar('ResultType')
63+
64+
65+
class LazyVisitor(GenericVisitor, Generic[ResultType]):
66+
67+
"""
68+
A generic visitor that lazily yields results instead of flattening results
69+
from children at every step.
70+
71+
Subclass-defined visit methods (and default_retval) should be generators.
72+
"""
73+
74+
@classmethod
75+
def default_retval(cls) -> Iterator[Any]:
76+
yield from ()
77+
78+
def lookup_method(self, instance) -> Callable[..., Iterator[Any]]:
79+
return super().lookup_method(instance)
80+
81+
def _visit(self, o, *args, **kwargs) -> Iterator[Any]:
82+
"""Visit `o`."""
83+
meth = self.lookup_method(o)
84+
yield from meth(o, *args, **kwargs)
85+
86+
def _post_visit(self, ret: Iterator[Any]) -> ResultType:
87+
"""Postprocess the visitor output before returning it to the caller."""
88+
return list(ret)
89+
90+
def visit_object(self, o: object, **kwargs) -> Iterator[Any]:
91+
yield from self.default_retval()
92+
93+
def visit_Node(self, o: Node, **kwargs) -> Iterator[Any]:
94+
yield from self._visit(o.children, **kwargs)
95+
96+
def visit_tuple(self, o: Sequence[Any]) -> Iterator[Any]:
97+
for i in o:
98+
yield from self._visit(i)
99+
100+
visit_list = visit_tuple
101+
102+
61103
class PrintAST(Visitor):
62104

63105
_depth = 0
@@ -992,16 +1034,7 @@ def visit_Node(self, o, ret=None, parents=None, in_parent=False):
9921034
return ret
9931035

9941036

995-
class FindSymbols(Visitor):
996-
997-
class Retval(list):
998-
def __init__(self, *retvals):
999-
elements = filter_ordered(flatten(retvals), key=id)
1000-
super().__init__(elements)
1001-
1002-
@classmethod
1003-
def default_retval(cls):
1004-
return cls.Retval()
1037+
class FindSymbols(LazyVisitor[list[Any]]):
10051038

10061039
"""
10071040
Find symbols in an Iteration/Expression tree.
@@ -1021,31 +1054,30 @@ def default_retval(cls):
10211054
"""
10221055

10231056
def _defines_aliases(n):
1024-
retval = []
10251057
for i in n.defines:
10261058
f = i.function
10271059
if f.is_ArrayBasic:
1028-
retval.extend([f, f.indexed])
1060+
yield from (f, f.indexed)
10291061
else:
1030-
retval.append(i)
1031-
return tuple(retval)
1062+
yield i
10321063

1033-
rules = {
1064+
RulesDict = dict[str, Callable[[Node], Iterator[Any]]]
1065+
rules: RulesDict = {
10341066
'symbolics': lambda n: n.functions,
1035-
'basics': lambda n: [i for i in n.expr_symbols if isinstance(i, Basic)],
1036-
'symbols': lambda n: [i for i in n.expr_symbols
1037-
if isinstance(i, AbstractSymbol)],
1038-
'dimensions': lambda n: [i for i in n.expr_symbols if isinstance(i, Dimension)],
1039-
'indexeds': lambda n: [i for i in n.expr_symbols if i.is_Indexed],
1040-
'indexedbases': lambda n: [i for i in n.expr_symbols
1041-
if isinstance(i, IndexedBase)],
1067+
'basics': lambda n: (i for i in n.expr_symbols if isinstance(i, Basic)),
1068+
'symbols': lambda n: (i for i in n.expr_symbols
1069+
if isinstance(i, AbstractSymbol)),
1070+
'dimensions': lambda n: (i for i in n.expr_symbols if isinstance(i, Dimension)),
1071+
'indexeds': lambda n: (i for i in n.expr_symbols if i.is_Indexed),
1072+
'indexedbases': lambda n: (i for i in n.expr_symbols
1073+
if isinstance(i, IndexedBase)),
10421074
'writes': lambda n: as_tuple(n.writes),
10431075
'defines': lambda n: as_tuple(n.defines),
1044-
'globals': lambda n: [f.base for f in n.functions if f._mem_global],
1076+
'globals': lambda n: (f.base for f in n.functions if f._mem_global),
10451077
'defines-aliases': _defines_aliases
10461078
}
10471079

1048-
def __init__(self, mode='symbolics'):
1080+
def __init__(self, mode: str = 'symbolics') -> None:
10491081
super().__init__()
10501082

10511083
modes = mode.split('|')
@@ -1055,33 +1087,27 @@ def __init__(self, mode='symbolics'):
10551087
self.rule = lambda n: chain(*[self.rules[mode](n) for mode in modes])
10561088

10571089
def _post_visit(self, ret):
1058-
return sorted(ret, key=lambda i: str(i))
1090+
return sorted(filter_ordered(ret, key=id), key=str)
10591091

1060-
def visit_tuple(self, o):
1061-
return self.Retval(*[self._visit(i) for i in o])
1062-
1063-
visit_list = visit_tuple
1092+
def visit_Node(self, o: Node) -> Iterator[Any]:
1093+
yield from self._visit(o.children)
1094+
yield from self.rule(o)
10641095

1065-
def visit_Node(self, o):
1066-
return self.Retval(self._visit(o.children), self.rule(o))
1067-
1068-
def visit_ThreadedProdder(self, o):
1096+
def visit_ThreadedProdder(self, o) -> Iterator[Any]:
10691097
# TODO: this handle required because ThreadedProdder suffers from the
10701098
# long-standing issue affecting all Node subclasses which rely on
10711099
# multiple inheritance
1072-
return self.Retval(self._visit(o.then_body), self.rule(o))
1073-
1074-
def visit_Operator(self, o):
1075-
ret = self._visit(o.body)
1076-
ret.extend(flatten(self._visit(v) for v in o._func_table.values()))
1077-
return self.Retval(ret, self.rule(o))
1100+
yield from self._visit(o.then_body)
1101+
yield from self.rule(o)
10781102

1103+
def visit_Operator(self, o) -> Iterator[Any]:
1104+
yield from self._visit(o.body)
1105+
yield from self.rule(o)
1106+
for i in o._func_table.values():
1107+
yield from self._visit(i)
10791108

1080-
class FindNodes(Visitor):
10811109

1082-
@classmethod
1083-
def default_retval(cls):
1084-
return []
1110+
class FindNodes(LazyVisitor[list[Node]]):
10851111

10861112
"""
10871113
Find all instances of given type.
@@ -1097,34 +1123,22 @@ def default_retval(cls):
10971123
appears.
10981124
"""
10991125

1100-
rules = {
1126+
RulesDict = dict[str, Callable[[type, Node], bool]]
1127+
rules: RulesDict = {
11011128
'type': lambda match, o: isinstance(o, match),
11021129
'scope': lambda match, o: match in flatten(o.children)
11031130
}
11041131

1105-
def __init__(self, match, mode='type'):
1132+
def __init__(self, match: type, mode: str = 'type'):
11061133
super().__init__()
11071134
self.match = match
11081135
self.rule = self.rules[mode]
11091136

1110-
def visit_object(self, o, ret=None):
1111-
return ret
1112-
1113-
def visit_tuple(self, o, ret=None):
1114-
for i in o:
1115-
ret = self._visit(i, ret=ret)
1116-
return ret
1117-
1118-
visit_list = visit_tuple
1119-
1120-
def visit_Node(self, o, ret=None):
1121-
if ret is None:
1122-
ret = self.default_retval()
1137+
def visit_Node(self, o: Node) -> Iterator[Any]:
11231138
if self.rule(self.match, o):
1124-
ret.append(o)
1139+
yield o
11251140
for i in o.children:
1126-
ret = self._visit(i, ret=ret)
1127-
return ret
1141+
yield from self._visit(i)
11281142

11291143

11301144
class FindWithin(FindNodes):
@@ -1170,53 +1184,36 @@ def visit_Node(self, o, ret=None):
11701184
return found, flag
11711185

11721186

1173-
class FindApplications(Visitor):
1187+
ApplicationType = TypeVar('ApplicationType')
1188+
11741189

1190+
class FindApplications(LazyVisitor[set[ApplicationType]]):
11751191
"""
11761192
Find all SymPy applied functions (aka, `Application`s). The user may refine
11771193
the search by supplying a different target class.
11781194
"""
11791195

1180-
def __init__(self, cls=Application):
1196+
def __init__(self, cls: type[ApplicationType] = Application):
11811197
super().__init__()
11821198
self.match = lambda i: isinstance(i, cls) and not isinstance(i, Basic)
11831199

1184-
@classmethod
1185-
def default_retval(cls):
1186-
return set()
1187-
1188-
def visit_object(self, o, **kwargs):
1189-
return self.default_retval()
1190-
1191-
def visit_tuple(self, o, ret=None):
1192-
ret = ret or self.default_retval()
1193-
for i in o:
1194-
ret.update(self._visit(i, ret=ret))
1195-
return ret
1196-
1197-
def visit_Node(self, o, ret=None):
1198-
ret = ret or self.default_retval()
1199-
for i in o.children:
1200-
ret.update(self._visit(i, ret=ret))
1201-
return ret
1200+
def _post_visit(self, ret):
1201+
return set(ret)
12021202

1203-
def visit_Expression(self, o, **kwargs):
1204-
return o.expr.find(self.match)
1203+
def visit_Expression(self, o: Expression, **kwargs) -> Iterator[ApplicationType]:
1204+
yield from o.expr.find(self.match)
12051205

1206-
def visit_Iteration(self, o, **kwargs):
1207-
ret = self._visit(o.children) or self.default_retval()
1208-
ret.update(o.symbolic_min.find(self.match))
1209-
ret.update(o.symbolic_max.find(self.match))
1210-
return ret
1206+
def visit_Iteration(self, o: Iteration, **kwargs) -> Iterator[ApplicationType]:
1207+
yield from self._visit(o.children)
1208+
yield from o.symbolic_min.find(self.match)
1209+
yield from o.symbolic_max.find(self.match)
12111210

1212-
def visit_Call(self, o, **kwargs):
1213-
ret = self.default_retval()
1211+
def visit_Call(self, o: Call, **kwargs) -> Iterator[ApplicationType]:
12141212
for i in o.arguments:
12151213
try:
1216-
ret.update(i.find(self.match))
1214+
yield from i.find(self.match)
12171215
except (AttributeError, TypeError):
1218-
ret.update(self._visit(i, ret=ret))
1219-
return ret
1216+
yield from self._visit(i)
12201217

12211218

12221219
class IsPerfectIteration(Visitor):

0 commit comments

Comments
 (0)