Skip to content

Commit cb454c8

Browse files
committed
compiler: Simplify handling of scalar aliases within guards
1 parent 903fdab commit cb454c8

3 files changed

Lines changed: 45 additions & 40 deletions

File tree

devito/passes/clusters/aliases.py

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
CustomDimension, Eq, Hyperplane, IncrDimension, Indexed, ModuloDimension, Size,
2626
StencilDimension, Symbol, Temp, TempArray, TempFunction
2727
)
28-
from devito.types.dimension import ConditionalDimension, SubsamplingFactor
2928
from devito.types.grid import MultiSubDimension
3029

3130
__all__ = ['cire']
@@ -149,7 +148,7 @@ def _aliases_from_clusters(self, cgroup, exclude, meta):
149148
# Schedule -> [Clusters]_k
150149
processed, subs = lower_schedule(schedule, meta, self.sregistry,
151150
self.opt_ftemps, self.opt_min_dtype,
152-
self.opt_minmem, nclusters=len(cgroup))
151+
self.opt_minmem)
153152

154153
# [Clusters]_k -> [Clusters]_k (optimization)
155154
if self.opt_multisubdomain:
@@ -281,6 +280,8 @@ def _do_generate(self, exprs, exclude, cbk_search, cbk_compose=None):
281280

282281
class CireInvariants(CireTransformerLegacy, Queue):
283282

283+
_q_guards_in_key = True
284+
284285
def __init__(self, sregistry, options, platform):
285286
super().__init__(sregistry, options, platform)
286287

@@ -854,7 +855,7 @@ def optimize_schedule_rotations(schedule, sregistry):
854855

855856

856857
def lower_schedule(schedule, meta, sregistry, opt_ftemps, opt_min_dtype,
857-
opt_minmem, nclusters=1):
858+
opt_minmem):
858859
"""
859860
Turn a Schedule into a sequence of Clusters.
860861
"""
@@ -925,26 +926,15 @@ def lower_schedule(schedule, meta, sregistry, opt_ftemps, opt_min_dtype,
925926
callback = lambda idx: obj[ # noqa: B023
926927
[i + s for i, s in zip(idx, shift, strict=True)] # noqa: B023
927928
]
928-
guards = meta.guards
929929
else:
930930
# Degenerate case: scalar expression
931931
assert writeto.size == 0
932932

933-
is_cond = any(isinstance(d, (SubsamplingFactor, ConditionalDimension))
934-
for d in pivot.free_symbols)
935-
if meta.guards and is_cond and nclusters > 1:
936-
# Scalar alias that depends on a guard, unsafe to lift out of the guard
937-
# Do not alias
938-
expression = None
939-
callback = lambda idx: uxreplace(pivot, subs) # noqa: B023
940-
else:
941-
dtype = sympy_dtype(pivot, base=meta.dtype, smin=opt_min_dtype)
942-
obj = Temp(name=name, dtype=dtype, is_const=True)
943-
expression = Eq(obj, uxreplace(pivot, subs))
933+
dtype = sympy_dtype(pivot, base=meta.dtype, smin=opt_min_dtype)
934+
obj = Temp(name=name, dtype=dtype, is_const=True)
935+
expression = Eq(obj, uxreplace(pivot, subs))
944936

945-
callback = lambda idx: obj # noqa: B023
946-
# Only keep the guard if there is no cross-cluster reuse of the scalar
947-
guards = meta.guards if nclusters == 1 else None
937+
callback = lambda idx: obj # noqa: B023
948938

949939
# Create the substitution rules for the aliasing expressions
950940
subs.update({
@@ -970,8 +960,7 @@ def lower_schedule(schedule, meta, sregistry, opt_ftemps, opt_min_dtype,
970960
properties[Hyperplane(writeto.itdims)] = {SEPARABLE}
971961

972962
# Finally, build the alias Cluster
973-
if expression is not None:
974-
clusters.append(Cluster(expression, ispace, guards, properties))
963+
clusters.append(Cluster(expression, ispace, meta.guards, properties))
975964

976965
return clusters, subs
977966

devito/passes/clusters/misc.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -97,19 +97,14 @@ def callback(self, clusters, prefix):
9797

9898
properties = c.properties.filter(key)
9999

100-
# Lifted scalar clusters cannot be guarded
101-
# as they would not be in the scope of the guarded clusters
102-
# unless the guard is for an outer dimension
103-
guards = {} if c.is_scalar and not (prefix[:-1] and c.guards) else c.guards
104-
105-
_lifted = c.rebuild(ispace=ispace, properties=properties, guards=guards)
106-
if guards and clusters[max(n-1, 0)].guards != guards and _lifted.is_scalar:
107-
# Heuristic: if the lifted Cluster has different guards than the
108-
# previous one, then we are likely to end up with a separate
109-
# Cluster, hence give up on lifting
110-
processed.append(_lifted)
100+
# If `c` is made of scalar expressions within guards, then we must keep
101+
# it close to the adjacent Clusters for correctness
102+
if c.is_scalar and c.guards:
103+
items = processed
111104
else:
112-
lifted.append(_lifted)
105+
items = lifted
106+
107+
items.append(c.rebuild(ispace=ispace, properties=properties))
113108

114109
return lifted + processed
115110

tests/test_dse.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2705,10 +2705,10 @@ def test_split_cond(self):
27052705

27062706
cond = FindNodes(Conditional).visit(op)
27072707
assert len(cond) == 3
2708-
# The alias should have been lifted out of the condition
2708+
# Each guard should have its own alias for cos(time)
27092709
assert 'float r0 = cos(time);' in str(body0(op))
27102710
scalars = [i for i in FindSymbols().visit(op) if isinstance(i, Temp)]
2711-
assert len(scalars) == 1
2711+
assert len(scalars) == 2
27122712

27132713
def test_split_cond_multi_alias(self):
27142714
grid = Grid((11, 11))
@@ -2728,10 +2728,11 @@ def test_split_cond_multi_alias(self):
27282728

27292729
cond = FindNodes(Conditional).visit(op)
27302730
assert len(cond) == 3
2731-
# The alias should have been lifted out of the condition
2732-
assert 'float r3 = cos(time);' in str(body0(op))
2731+
# Each guard should have its own aliases for cos(time) and sin(time)
2732+
assert 'const float r0 = sin(time) + cos(time)' in str(body0(op))
2733+
assert 'const float r1 = cos(time);' in str(body0(op))
27332734
scalars = [i for i in FindSymbols().visit(op) if isinstance(i, Temp)]
2734-
assert len(scalars) == 5
2735+
assert len(scalars) == 3
27352736

27362737
def test_multi_cond_no_split(self):
27372738
grid = Grid((11, 11))
@@ -2742,7 +2743,7 @@ def test_multi_cond_no_split(self):
27422743
ct2 = ConditionalDimension(name='ct2', parent=time, factor=4)
27432744

27442745
eq0 = Eq(u.forward, u + cos(time), implicit_dims=ct)
2745-
# Not hoisting the inite would have this equation split the time
2746+
# Not hoisting the init would have this equation split the time
27462747
# loop to initialize the alias for sin(time)
27472748
eq1 = Eq(u.forward, u.forward + sin(time), implicit_dims=ct2)
27482749
eq2 = Eq(u.forward, u.forward - sin(time), implicit_dims=ct)
@@ -2778,9 +2779,29 @@ def test_alias_with_conditional(self):
27782779
cond = FindNodes(Conditional).visit(op)
27792780
assert len(cond) == 3
27802781

2781-
# The alias should have been lifted out of the condition
2782+
# # Each guard should have its own alias for cos(time/ctf)
27822783
scalars = [i for i in FindSymbols().visit(op) if isinstance(i, Temp)]
2783-
assert not scalars
2784+
assert len(scalars) == 2
2785+
2786+
def test_scalar_alias_interp(self):
2787+
grid = Grid(shape=(11, 11))
2788+
time = grid.time_dim
2789+
2790+
t_sub = ConditionalDimension(name='t_sub', parent=time, factor=3)
2791+
2792+
f = TimeFunction(name='f', grid=grid, space_order=4)
2793+
s = SparseTimeFunction(name='s', grid=grid, npoint=1, nt=100)
2794+
2795+
eq = Eq(f.forward, f.laplace + .002)
2796+
2797+
rec = s.interpolate(expr=f, implicit_dims=t_sub)
2798+
2799+
op = Operator(rec + [eq])
2800+
2801+
op.apply(time_M=3)
2802+
2803+
assert np.isclose(norm(f), 254292.75, atol=1e-1, rtol=0)
2804+
assert np.isclose(norm(s), 191.44644, atol=1e-1, rtol=0)
27842805

27852806

27862807
class TestIsoAcoustic:

0 commit comments

Comments
 (0)