Skip to content

Commit eb3b4ea

Browse files
committed
misc: add test and misc tweaks
1 parent cb454c8 commit eb3b4ea

3 files changed

Lines changed: 79 additions & 10 deletions

File tree

devito/passes/clusters/misc.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -99,12 +99,10 @@ def callback(self, clusters, prefix):
9999

100100
# If `c` is made of scalar expressions within guards, then we must keep
101101
# it close to the adjacent Clusters for correctness
102-
if c.is_scalar and c.guards:
103-
items = processed
102+
if c.is_scalar and c.guards and ispace:
103+
processed.append(c.rebuild(ispace=ispace, properties=properties))
104104
else:
105-
items = lifted
106-
107-
items.append(c.rebuild(ispace=ispace, properties=properties))
105+
lifted.append(c.rebuild(ispace=ispace, properties=properties))
108106

109107
return lifted + processed
110108

devito/symbolics/search.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,15 +96,18 @@ def visit_preorder_first_hit(self, expr: Expression) -> Iterator[Expression]:
9696

9797

9898
def search(exprs: Expression | Iterable[Expression],
99-
query: type | Callable[[Any], bool],
99+
query: type | tuple[type, ...] | Callable[[Any], bool],
100100
mode: Mode = 'unique',
101101
visit: Literal['dfs', 'bfs', 'bfs_first_hit'] = 'dfs',
102102
deep: bool = False) -> List | set[Expression]:
103103
"""Interface to Search."""
104104

105105
assert mode in ('all', 'unique'), "Unknown mode"
106106

107-
Q = (lambda obj: isinstance(obj, query)) if isinstance(query, type) else query
107+
if isinstance(query, (type, tuple)):
108+
Q = lambda obj: isinstance(obj, query)
109+
else:
110+
Q = query
108111

109112
# Search doesn't actually use a BFS (rather, a preorder DFS), but the terminology
110113
# is retained in this function's parameters for backwards compatibility

tests/test_dse.py

Lines changed: 71 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2779,7 +2779,7 @@ def test_alias_with_conditional(self):
27792779
cond = FindNodes(Conditional).visit(op)
27802780
assert len(cond) == 3
27812781

2782-
# # Each guard should have its own alias for cos(time/ctf)
2782+
# Each guard should have its own alias for cos(time/ctf)
27832783
scalars = [i for i in FindSymbols().visit(op) if isinstance(i, Temp)]
27842784
assert len(scalars) == 2
27852785

@@ -2800,8 +2800,76 @@ def test_scalar_alias_interp(self):
28002800

28012801
op.apply(time_M=3)
28022802

2803-
assert np.isclose(norm(f), 254292.75, atol=1e-1, rtol=0)
2804-
assert np.isclose(norm(s), 191.44644, atol=1e-1, rtol=0)
2803+
assert np.isclose(norm(f), 254292.75, atol=0, rtol=1e-5)
2804+
assert np.isclose(norm(s), 191.44644, atol=0, rtol=1e-4)
2805+
2806+
def test_scalar_with_cond_access(self):
2807+
grid = Grid((11, 11))
2808+
time = grid.time_dim
2809+
2810+
u = TimeFunction(name='u', grid=grid, time_order=2, space_order=2)
2811+
2812+
ct = ConditionalDimension(name='ct3', parent=time, condition=Ge(time, 2))
2813+
ct2 = ConditionalDimension(name='ct2', parent=time, factor=4)
2814+
2815+
f1 = TimeFunction(name='f1', grid=grid, save=10, time_order=0,
2816+
dimensions=(ct,), time_dim=ct, shape=(10,))
2817+
f1.data[:] = np.arange(10)
2818+
2819+
eq0 = Eq(u.forward, u + cos(f1))
2820+
eq1 = Eq(u.forward, u.forward + sin(time), implicit_dims=ct2)
2821+
eq2 = Eq(u.forward, u.forward - sin(f1))
2822+
2823+
op = Operator([eq0, eq1, eq2])
2824+
2825+
cond = FindNodes(Conditional).visit(op)
2826+
assert len(cond) == 3
2827+
2828+
# # Each guard should have its own alias for cos/sin(f1[time-2])
2829+
scalars = [i for i in FindSymbols().visit(op) if isinstance(i, Temp)]
2830+
assert len(scalars) == 3
2831+
2832+
assert_structure(
2833+
op,
2834+
['t', 't,x,y', 't,x,y', 't,x,y'],
2835+
'txyxyxy'
2836+
)
2837+
2838+
# This would segfault without the right placement of the alias
2839+
op.apply(time_M=12)
2840+
2841+
def test_scalar_with_cond_tinvariant(self):
2842+
grid = Grid((10, 10))
2843+
time = grid.time_dim
2844+
dt = time.spacing
2845+
2846+
u = TimeFunction(name='u', grid=grid, time_order=2, space_order=2)
2847+
2848+
ct = ConditionalDimension(name='ct', parent=time, factor=2)
2849+
2850+
eq0 = Eq(u.forward, u / dt + 1)
2851+
eq1 = Eq(u.forward, u.forward + 1/dt**2, implicit_dims=ct)
2852+
2853+
op = Operator([eq0, eq1])
2854+
op(time=5, dt=1)
2855+
2856+
cond = FindNodes(Conditional).visit(op)
2857+
assert len(cond) == 1
2858+
# One for each 1/dt 1/dt**2
2859+
scalars = [i for i in FindSymbols().visit(op) if isinstance(i, Temp)]
2860+
assert len(scalars) == 2
2861+
2862+
assert_structure(
2863+
op,
2864+
['t,x,y', 't', 't,x,y'],
2865+
'txyxy'
2866+
)
2867+
2868+
# Both aliases should be hoisted outside the time loop
2869+
assert str(body0(op).body[0]) == 'const float r0 = 1.0F/dt;'
2870+
assert not body0(op).body[0].ispace
2871+
assert str(body0(op).body[1]) == 'const float r1 = 1.0F/(dt*dt);'
2872+
assert not body0(op).body[1].ispace
28052873

28062874

28072875
class TestIsoAcoustic:

0 commit comments

Comments
 (0)