Skip to content

Commit 045a3b6

Browse files
committed
misc: Cleanup + flake8 fixes
1 parent 7fb406f commit 045a3b6

2 files changed

Lines changed: 10 additions & 21 deletions

File tree

devito/ir/iet/visitors.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -68,27 +68,21 @@ class LazyVisitor(GenericVisitor, Generic[ResultType]):
6868
A generic visitor that lazily yields results instead of flattening results
6969
from children at every step.
7070
71-
Subclass-defined visit methods (and default_retval) should be generators.
71+
Subclass-defined visit methods should be generators.
7272
"""
7373

74-
@classmethod
75-
def default_retval(cls) -> Iterator[Any]:
76-
yield from ()
77-
7874
def lookup_method(self, instance) -> Callable[..., Iterator[Any]]:
7975
return super().lookup_method(instance)
8076

8177
def _visit(self, o, *args, **kwargs) -> Iterator[Any]:
82-
"""Visit `o`."""
8378
meth = self.lookup_method(o)
8479
yield from meth(o, *args, **kwargs)
8580

8681
def _post_visit(self, ret: Iterator[Any]) -> ResultType:
87-
"""Postprocess the visitor output before returning it to the caller."""
8882
return list(ret)
8983

9084
def visit_object(self, o: object, **kwargs) -> Iterator[Any]:
91-
yield from self.default_retval()
85+
yield from ()
9286

9387
def visit_Node(self, o: Node, **kwargs) -> Iterator[Any]:
9488
yield from self._visit(o.children, **kwargs)
@@ -1053,6 +1047,7 @@ class FindSymbols(LazyVisitor[list[Any]]):
10531047
- `defines-aliases`: Collect all defined objects and their aliases
10541048
"""
10551049

1050+
@staticmethod
10561051
def _defines_aliases(n):
10571052
for i in n.defines:
10581053
f = i.function

devito/symbolics/search.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,25 +17,19 @@
1717
Expression = sympy.Basic | np.number | int | float
1818

1919

20-
class Set(set[Expression]):
21-
@staticmethod
22-
def wrap(obj: Expression) -> set[Expression]:
23-
return {obj}
24-
25-
2620
class List(list[Expression]):
27-
@staticmethod
28-
def wrap(obj: Expression) -> list[Expression]:
29-
return [obj]
21+
"""
22+
A list that aliases `extend` to `update` to mirror the `set` interface.
23+
"""
3024

3125
def update(self, obj: Iterable[Expression]) -> None:
3226
self.extend(obj)
3327

3428

3529
Mode = Literal['all', 'unique']
36-
modes: dict[Mode, type[List] | type[Set]] = {
30+
modes: dict[Mode, type[List] | type[set[Expression]]] = {
3731
'all': List,
38-
'unique': Set
32+
'unique': set
3933
}
4034

4135

@@ -97,7 +91,7 @@ def search(exprs: Expression | Iterable[Expression],
9791
query: type | Callable[[Any], bool],
9892
mode: Mode = 'unique',
9993
visit: Literal['dfs', 'bfs', 'bfs_first_hit'] = 'dfs',
100-
deep: bool = False) -> List | Set:
94+
deep: bool = False) -> List | set[Expression]:
10195
"""Interface to Search."""
10296

10397
assert mode in ('all', 'unique'), "Unknown mode"
@@ -118,7 +112,7 @@ def search(exprs: Expression | Iterable[Expression],
118112
_search = searcher.visit_preorder_first_hit
119113
else:
120114
raise ValueError(f"Unknown visit mode '{visit}'")
121-
115+
122116
exprs = filter(lambda e: isinstance(e, sympy.Basic), as_tuple(exprs))
123117
found = modes[mode](chain(*map(_search, exprs)))
124118

0 commit comments

Comments
 (0)