Skip to content

Commit 7fb406f

Browse files
committed
compiler: Lazy search
1 parent fed3537 commit 7fb406f

1 file changed

Lines changed: 67 additions & 78 deletions

File tree

devito/symbolics/search.py

Lines changed: 67 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
from collections.abc import Callable, Iterable, Iterator
2+
from itertools import chain
3+
from typing import Any, Literal
4+
5+
import numpy as np
16
import sympy
27

38
from devito.symbolics.queries import (q_indexed, q_function, q_terminal, q_leaf,
@@ -9,129 +14,113 @@
914
'retrieve_derivatives', 'search']
1015

1116

12-
class Search:
17+
Expression = sympy.Basic | np.number | int | float
18+
19+
20+
class Set(set[Expression]):
21+
@staticmethod
22+
def wrap(obj: Expression) -> set[Expression]:
23+
return {obj}
1324

14-
class Set(set):
1525

16-
@staticmethod
17-
def wrap(obj):
18-
return {obj}
26+
class List(list[Expression]):
27+
@staticmethod
28+
def wrap(obj: Expression) -> list[Expression]:
29+
return [obj]
1930

20-
class List(list):
31+
def update(self, obj: Iterable[Expression]) -> None:
32+
self.extend(obj)
2133

22-
@staticmethod
23-
def wrap(obj):
24-
return [obj]
2534

26-
def update(self, obj):
27-
return self.extend(obj)
35+
Mode = Literal['all', 'unique']
36+
modes: dict[Mode, type[List] | type[Set]] = {
37+
'all': List,
38+
'unique': Set
39+
}
2840

29-
modes = {
30-
'unique': Set,
31-
'all': List
32-
}
3341

34-
def __init__(self, query, mode, deep=False):
42+
class Search:
43+
def __init__(self, query: Callable[[Expression], bool], deep: bool = False) -> None:
3544
"""
36-
Search objects in an expression. This is much quicker than the more
37-
general SymPy's find.
45+
Search objects in an expression. This is much quicker than the more general
46+
SymPy's find.
3847
3948
Parameters
4049
----------
4150
query
4251
Any query from :mod:`queries`.
43-
mode : str
44-
Either 'unique' or 'all' (catch all instances).
4552
deep : bool, optional
4653
If True, propagate the search within an Indexed's indices. Defaults to False.
4754
"""
4855
self.query = query
49-
self.collection = self.modes[mode]
5056
self.deep = deep
5157

52-
def _next(self, expr):
58+
def _next(self, expr: Expression) -> Iterable[Expression]:
5359
if self.deep and expr.is_Indexed:
5460
return expr.indices
5561
elif q_leaf(expr):
5662
return ()
57-
else:
58-
return expr.args
63+
return expr.args
5964

60-
def dfs(self, expr):
65+
def visit_postorder(self, expr: Expression) -> Iterator[Expression]:
6166
"""
62-
Perform a DFS search.
63-
64-
Parameters
65-
----------
66-
expr : expr-like
67-
The searched expression.
67+
Visit the expression with a postorder traversal, yielding all hits.
6868
"""
69-
found = self.collection()
70-
for a in self._next(expr):
71-
found.update(self.dfs(a))
69+
for i in self._next(expr):
70+
yield from self.visit_postorder(i)
7271
if self.query(expr):
73-
found.update(self.collection.wrap(expr))
74-
return found
72+
yield expr
7573

76-
def bfs(self, expr):
74+
def visit_preorder(self, expr: Expression) -> Iterator[Expression]:
7775
"""
78-
Perform a BFS search.
79-
80-
Parameters
81-
----------
82-
expr : expr-like
83-
The searched expression.
76+
Visit the expression with a preorder traversal, yielding all hits.
8477
"""
85-
found = self.collection()
8678
if self.query(expr):
87-
found.update(self.collection.wrap(expr))
88-
for a in self._next(expr):
89-
found.update(self.bfs(a))
90-
return found
79+
yield expr
80+
for i in self._next(expr):
81+
yield from self.visit_preorder(i)
9182

92-
def bfs_first_hit(self, expr):
83+
def visit_preorder_first_hit(self, expr: Expression) -> Iterator[Expression]:
9384
"""
94-
Perform a BFS search, returning immediately when a node matches the query.
95-
96-
Parameters
97-
----------
98-
expr : expr-like
99-
The searched expression.
85+
Visit the expression in preorder and return a tuple containing the first hit,
86+
if any. This can return more than a single result, as it looks for the first
87+
hit from any branch but may find a hit in multiple branches.
10088
"""
101-
found = self.collection()
10289
if self.query(expr):
103-
found.update(self.collection.wrap(expr))
104-
return found
105-
for a in self._next(expr):
106-
found.update(self.bfs_first_hit(a))
107-
return found
90+
yield expr
91+
return
92+
for i in self._next(expr):
93+
yield from self.visit_preorder_first_hit(i)
10894

10995

110-
def search(exprs, query, mode='unique', visit='dfs', deep=False):
96+
def search(exprs: Expression | Iterable[Expression],
97+
query: type | Callable[[Any], bool],
98+
mode: Mode = 'unique',
99+
visit: Literal['dfs', 'bfs', 'bfs_first_hit'] = 'dfs',
100+
deep: bool = False) -> List | Set:
111101
"""Interface to Search."""
112102

113-
assert mode in Search.modes, "Unknown mode"
103+
assert mode in ('all', 'unique'), "Unknown mode"
114104

115105
if isinstance(query, type):
116106
Q = lambda obj: isinstance(obj, query)
117107
else:
118108
Q = query
119109

120-
searcher = Search(Q, mode, deep)
121-
122-
found = Search.modes[mode]()
123-
for e in as_tuple(exprs):
124-
if not isinstance(e, sympy.Basic):
125-
continue
126-
127-
if visit == 'dfs':
128-
found.update(searcher.dfs(e))
129-
elif visit == 'bfs':
130-
found.update(searcher.bfs(e))
131-
elif visit == "bfs_first_hit":
132-
found.update(searcher.bfs_first_hit(e))
133-
else:
134-
raise ValueError("Unknown visit type `%s`" % visit)
110+
# Search doesn't actually use a BFS (rather, a preorder DFS), but the terminology
111+
# is retained in this function's parameters for backwards compatibility
112+
searcher = Search(Q, deep)
113+
if visit == 'dfs':
114+
_search = searcher.visit_postorder
115+
elif visit == 'bfs':
116+
_search = searcher.visit_preorder
117+
elif visit == 'bfs_first_hit':
118+
_search = searcher.visit_preorder_first_hit
119+
else:
120+
raise ValueError(f"Unknown visit mode '{visit}'")
121+
122+
exprs = filter(lambda e: isinstance(e, sympy.Basic), as_tuple(exprs))
123+
found = modes[mode](chain(*map(_search, exprs)))
135124

136125
return found
137126

0 commit comments

Comments
 (0)