Skip to content

Commit e147aaa

Browse files
authored
Merge pull request #2873 from devitocodes/hotfix-degenerating-indices
compiler: Tweak use of degenerating_indices when determining iteration directions
2 parents 805394b + 5326df0 commit e147aaa

5 files changed

Lines changed: 83 additions & 24 deletions

File tree

devito/ir/clusters/algorithms.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,9 +170,16 @@ def callback(self, clusters, prefix, backlog=None, known_break=None):
170170
candidates | known_break)
171171

172172
# Compute iteration direction
173-
idir = {d: Backward for d in candidates if d.root in scope.d_anti.cause}
173+
# When checking for iteration direction, the user may have specified an LHS
174+
# preceding the RHS, implying backward iteration, even if there is no strict
175+
# reason that this iteration would need to run backward. Check if there is a
176+
# user-specified backward iteration before defaulting to forward to avoid a
177+
# gotcha by using the logical d_anti here.
178+
idir = {d: Backward for d in candidates
179+
if d.root in scope.d_anti_logical.cause}
174180
if maybe_break:
175181
idir.update({d: Forward for d in candidates if d.root in scope.d_flow.cause})
182+
# Default to forward for remaining dimensions
176183
idir.update({d: Forward for d in candidates if d not in idir})
177184

178185
# Enforce iteration direction on each Cluster

devito/ir/support/basic.py

Lines changed: 37 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -320,14 +320,17 @@ def lex_le(self, other):
320320
def lex_lt(self, other):
321321
return self.timestamp < other.timestamp
322322

323-
def distance(self, other):
323+
def distance(self, other, logical=False):
324324
"""
325325
Compute the distance from ``self`` to ``other``.
326326
327327
Parameters
328328
----------
329329
other : TimedAccess
330330
The TimedAccess w.r.t. which the distance is computed.
331+
logical : bool
332+
Compute a logical distance rather than true distance (i.e. ignoring
333+
degenerating indices created by size 1 buffers etc).
331334
"""
332335
if isinstance(self.access, ComponentAccess) and \
333336
isinstance(other.access, ComponentAccess) and \
@@ -392,7 +395,7 @@ def distance(self, other):
392395
# objects falls back to zero, as any other value would be
393396
# nonsensical
394397
ret.append(S.Zero)
395-
elif degenerating_indices(self[n], other[n], self.function):
398+
elif degenerating_indices(self[n], other[n], self.function, logical=logical):
396399
# Special case: `sai` and `oai` may be different symbolic objects
397400
# but they can be proved to systematically generate the same value
398401
ret.append(S.Zero)
@@ -786,6 +789,13 @@ def is_storage_related(self, dims=None):
786789
return False
787790

788791

792+
class LogicalDependence(Dependence):
793+
794+
@cached_property
795+
def distance(self):
796+
return self.source.distance(self.sink, logical=True)
797+
798+
789799
class DependenceGroup(set):
790800

791801
@cached_property
@@ -1111,20 +1121,21 @@ def d_flow(self):
11111121
return DependenceGroup(self.d_flow_gen())
11121122

11131123
@memoized_generator
1114-
def d_anti_gen(self):
1124+
def d_anti_gen(self, depcls=Dependence):
11151125
"""Generate the anti (or "write-after-read") dependences."""
11161126
for k, v in self.writes.items():
11171127
for w in v:
11181128
for r in self.reads_smart_gen(k):
11191129
if any(not rule(r, w) for rule in self.rules):
11201130
continue
11211131

1122-
dependence = Dependence(r, w)
1132+
dependence = depcls(r, w)
11231133

11241134
if dependence.is_imaginary:
11251135
continue
11261136

11271137
distance = dependence.distance
1138+
11281139
try:
11291140
is_anti = distance > 0 or (r.lex_lt(w) and distance == 0)
11301141
except TypeError:
@@ -1140,6 +1151,14 @@ def d_anti(self):
11401151
"""Anti (or "write-after-read") dependences."""
11411152
return DependenceGroup(self.d_anti_gen())
11421153

1154+
@cached_property
1155+
def d_anti_logical(self):
1156+
"""
1157+
Anti (or "write-after-read") dependences using logical rather than true
1158+
distances.
1159+
"""
1160+
return DependenceGroup(self.d_anti_gen(depcls=LogicalDependence))
1161+
11431162
@memoized_generator
11441163
def d_output_gen(self):
11451164
"""Generate the output (or "write-after-write") dependences."""
@@ -1425,7 +1444,7 @@ def disjoint_test(e0, e1, d, it):
14251444
return not bool(i0.intersect(i1))
14261445

14271446

1428-
def degenerating_indices(i0, i1, function):
1447+
def degenerating_indices(i0, i1, function, logical=False):
14291448
"""
14301449
True if `i0` and `i1` are indices that are possibly symbolically
14311450
different, but they can be proved to systematically degenerate to the
@@ -1440,17 +1459,19 @@ def degenerating_indices(i0, i1, function):
14401459

14411460
# Case 2: SteppingDimension corresponding to buffer of size 1
14421461
# Extract dimension from both IndexAccessFunctions -> d0, d1
1443-
try:
1444-
d0 = i0.d
1445-
except AttributeError:
1446-
d0 = i0
1447-
try:
1448-
d1 = i1.d
1449-
except AttributeError:
1450-
d1 = i1
1462+
# Skipped if doing a purely logical check
1463+
if not logical:
1464+
try:
1465+
d0 = i0.d
1466+
except AttributeError:
1467+
d0 = i0
1468+
try:
1469+
d1 = i1.d
1470+
except AttributeError:
1471+
d1 = i1
14511472

1452-
with suppress(AttributeError):
1453-
if d0 is d1 and d0.is_Stepping and function._size_domain[d0] == 1:
1454-
return True
1473+
with suppress(AttributeError):
1474+
if d0 is d1 and d0.is_Stepping and function._size_domain[d0] == 1:
1475+
return True
14551476

14561477
return False

devito/tools/memoization.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -142,15 +142,18 @@ class CacheInstancesMeta(type):
142142
def __init__(cls: type[InstanceType], *args) -> None: # type: ignore
143143
super().__init__(*args)
144144

145-
# Register the cached type
145+
# Register the cached type and eagerly create its cache, bound to its
146+
# own constructor. Eager initialisation avoids a bug where a subclass
147+
# would inherit (and reuse) a parent's cache via MRO lookup if the
148+
# parent happened to be instantiated first.
146149
CacheInstancesMeta._cached_types.add(cls)
150+
maxsize = cls._instance_cache_size
151+
cls._instance_cache = lru_cache(maxsize=maxsize)(
152+
super().__call__
153+
)
147154

148155
def __call__(cls: type[InstanceType], # type: ignore
149156
*args, **kwargs) -> InstanceType:
150-
if cls._instance_cache is None:
151-
maxsize = cls._instance_cache_size
152-
cls._instance_cache = lru_cache(maxsize=maxsize)(super().__call__)
153-
154157
args, kwargs = cls._preprocess_args(*args, **kwargs)
155158
return cls._instance_cache(*args, **kwargs)
156159

tests/test_dimension.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from devito.ir.iet import (
1818
Conditional, Expression, FindNodes, FindSymbols, Iteration, retrieve_iteration_tree
1919
)
20+
from devito.ir.support.space import Backward, Forward
2021
from devito.symbolics import INT, IntDiv, indexify, retrieve_functions
2122
from devito.types import Array, StencilDimension, Symbol
2223
from devito.types.basic import Scalar
@@ -235,6 +236,26 @@ def test_degenerate_to_zero(self):
235236

236237
assert np.all(u.data == 10)
237238

239+
@pytest.mark.parametrize('direction', ['fwd', 'bwd'])
240+
def test_buffer1_direction(self, direction):
241+
grid = Grid(shape=(10, 10))
242+
243+
u = TimeFunction(name='u', grid=grid, save=Buffer(1))
244+
245+
# Equations technically have no implied time direction as u.forward and u refer
246+
# to the same buffer slot. However, user usage of u.forward and u.backward should
247+
# be picked up by the compiler
248+
if direction == 'fwd':
249+
op = Operator(Eq(u.forward, u + 1))
250+
else:
251+
op = Operator(Eq(u.backward, u + 1))
252+
253+
# Check for time loop direction
254+
trees = retrieve_iteration_tree(op)
255+
direction = Forward if direction == 'fwd' else Backward
256+
for tree in trees:
257+
assert tree[0].direction == direction
258+
238259

239260
class TestSubDimension:
240261

tests/test_mpi.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from devito.ir.iet import (
1818
Call, Conditional, FindNodes, FindSymbols, Iteration, retrieve_iteration_tree
1919
)
20+
from devito.ir.support.space import Backward, Forward
2021
from devito.mpi import MPI
2122
from devito.mpi.distributed import CustomTopology
2223
from devito.mpi.routines import ComputeCall, HaloUpdateCall, HaloUpdateList, MPICall
@@ -1916,8 +1917,8 @@ def test_haloupdate_buffer1(self, mode):
19161917
(2, True, 'Eq(v3.forward, v2.forward.laplace + 1)', 3, 2, ('v1', 'v2')),
19171918
(1, False, 'rec.interpolate(v2)', 3, 2, ('v1', 'v2')),
19181919
(1, False, 'Eq(v3.backward, v2.laplace + 1)', 3, 2, ('v1', 'v2')),
1919-
(1, False, 'Eq(v3.backward, v2.backward.laplace + 1)', 3, 3, ('v2', 'v1', 'v2')),
1920-
(2, False, 'Eq(v3.backward, v2.backward.laplace + 1)', 3, 3, ('v2', 'v1', 'v2')),
1920+
(1, False, 'Eq(v3.backward, v2.backward.laplace + 1)', 3, 2, ('v1', 'v2')),
1921+
(2, False, 'Eq(v3.backward, v2.backward.laplace + 1)', 3, 2, ('v1', 'v2')),
19211922
])
19221923
def test_haloupdate_buffer_cases(self, sz, fwd, expr, exp0, exp1, args, mode):
19231924
grid = Grid((65, 65, 65), topology=('*', 1, '*'))
@@ -1943,6 +1944,12 @@ def test_haloupdate_buffer_cases(self, sz, fwd, expr, exp0, exp1, args, mode):
19431944
op = Operator(eqns)
19441945
_ = op.cfunction
19451946

1947+
# Check for time loop direction
1948+
trees = retrieve_iteration_tree(op)
1949+
direction = Forward if fwd else Backward
1950+
for tree in trees:
1951+
assert tree[0].direction == direction
1952+
19461953
calls, _ = check_halo_exchanges(op, exp0, exp1)
19471954
for i, v in enumerate(args):
19481955
assert calls[i].arguments[0] is eval(v)

0 commit comments

Comments
 (0)