Skip to content

Commit d6e80be

Browse files
committed
compiler: Tweak use of degenerating_indices when determining iteration directions
1 parent 805394b commit d6e80be

4 files changed

Lines changed: 90 additions & 19 deletions

File tree

devito/ir/clusters/algorithms.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,9 +170,16 @@ def callback(self, clusters, prefix, backlog=None, known_break=None):
170170
candidates | known_break)
171171

172172
# Compute iteration direction
173-
idir = {d: Backward for d in candidates if d.root in scope.d_anti.cause}
173+
# When checking for iteration direction, the user may have specified an LHS
174+
# preceding the RHS, implying backward iteration, even if there is no strict
175+
# reason that this iteration would need to run backward. Check if there is a
176+
# user-specified backward iteration before defaulting to forward to avoid a
177+
# gotcha by using the logical d_anti here.
178+
idir = {d: Backward for d in candidates
179+
if d.root in scope.d_anti_logical.cause_logical}
174180
if maybe_break:
175181
idir.update({d: Forward for d in candidates if d.root in scope.d_flow.cause})
182+
# Default to forward for remaining dimensions
176183
idir.update({d: Forward for d in candidates if d not in idir})
177184

178185
# Enforce iteration direction on each Cluster

devito/ir/support/basic.py

Lines changed: 52 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -320,14 +320,17 @@ def lex_le(self, other):
320320
def lex_lt(self, other):
321321
return self.timestamp < other.timestamp
322322

323-
def distance(self, other):
323+
def distance(self, other, logical=False):
324324
"""
325325
Compute the distance from ``self`` to ``other``.
326326
327327
Parameters
328328
----------
329329
other : TimedAccess
330330
The TimedAccess w.r.t. which the distance is computed.
331+
logical : bool
332+
Compute a logical distance rather than true distance (i.e. ignoring
333+
degenerating indices created by size 1 buffers etc).
331334
"""
332335
if isinstance(self.access, ComponentAccess) and \
333336
isinstance(other.access, ComponentAccess) and \
@@ -392,7 +395,7 @@ def distance(self, other):
392395
# objects falls back to zero, as any other value would be
393396
# nonsensical
394397
ret.append(S.Zero)
395-
elif degenerating_indices(self[n], other[n], self.function):
398+
elif degenerating_indices(self[n], other[n], self.function, logical=logical):
396399
# Special case: `sai` and `oai` may be different symbolic objects
397400
# but they can be proved to systematically generate the same value
398401
ret.append(S.Zero)
@@ -566,6 +569,10 @@ def timestamp(self):
566569
def distance(self):
567570
return self.source.distance(self.sink)
568571

572+
@cached_property
573+
def distance_logical(self):
574+
return self.source.distance(self.sink, logical=True)
575+
569576
@cached_property
570577
def _defined_findices(self):
571578
return frozenset(flatten(i._defines for i in self.findices))
@@ -656,6 +663,19 @@ def cause(self):
656663
return i._defines
657664
return frozenset()
658665

666+
# TODO: Refactor this
667+
@cached_property
668+
def cause_logical(self):
669+
"""Return the findex causing the dependence."""
670+
for i, j in zip(self.findices, self.distance_logical, strict=False):
671+
try:
672+
if j > 0:
673+
return i._defines
674+
except TypeError:
675+
# Conservatively assume this is an offending dimension
676+
return i._defines
677+
return frozenset()
678+
659679
@cached_property
660680
def read(self):
661681
if self.is_flow:
@@ -792,6 +812,10 @@ class DependenceGroup(set):
792812
def cause(self):
793813
return frozenset().union(*[i.cause for i in self])
794814

815+
@cached_property
816+
def cause_logical(self):
817+
return frozenset().union(*[i.cause_logical for i in self])
818+
795819
@cached_property
796820
def functions(self):
797821
"""Return the DiscreteFunctions inducing a dependence."""
@@ -1111,7 +1135,7 @@ def d_flow(self):
11111135
return DependenceGroup(self.d_flow_gen())
11121136

11131137
@memoized_generator
1114-
def d_anti_gen(self):
1138+
def d_anti_gen(self, logical=False):
11151139
"""Generate the anti (or "write-after-read") dependences."""
11161140
for k, v in self.writes.items():
11171141
for w in v:
@@ -1124,7 +1148,9 @@ def d_anti_gen(self):
11241148
if dependence.is_imaginary:
11251149
continue
11261150

1127-
distance = dependence.distance
1151+
distance = dependence.distance_logical \
1152+
if logical else dependence.distance
1153+
11281154
try:
11291155
is_anti = distance > 0 or (r.lex_lt(w) and distance == 0)
11301156
except TypeError:
@@ -1140,6 +1166,14 @@ def d_anti(self):
11401166
"""Anti (or "write-after-read") dependences."""
11411167
return DependenceGroup(self.d_anti_gen())
11421168

1169+
@cached_property
1170+
def d_anti_logical(self):
1171+
"""
1172+
Anti (or "write-after-read") dependences using logical rather than true
1173+
distances.
1174+
"""
1175+
return DependenceGroup(self.d_anti_gen(logical=True))
1176+
11431177
@memoized_generator
11441178
def d_output_gen(self):
11451179
"""Generate the output (or "write-after-write") dependences."""
@@ -1425,7 +1459,7 @@ def disjoint_test(e0, e1, d, it):
14251459
return not bool(i0.intersect(i1))
14261460

14271461

1428-
def degenerating_indices(i0, i1, function):
1462+
def degenerating_indices(i0, i1, function, logical=False):
14291463
"""
14301464
True if `i0` and `i1` are indices that are possibly symbolically
14311465
different, but they can be proved to systematically degenerate to the
@@ -1440,17 +1474,19 @@ def degenerating_indices(i0, i1, function):
14401474

14411475
# Case 2: SteppingDimension corresponding to buffer of size 1
14421476
# Extract dimension from both IndexAccessFunctions -> d0, d1
1443-
try:
1444-
d0 = i0.d
1445-
except AttributeError:
1446-
d0 = i0
1447-
try:
1448-
d1 = i1.d
1449-
except AttributeError:
1450-
d1 = i1
1477+
# Skipped if doing a purely logical check
1478+
if not logical:
1479+
try:
1480+
d0 = i0.d
1481+
except AttributeError:
1482+
d0 = i0
1483+
try:
1484+
d1 = i1.d
1485+
except AttributeError:
1486+
d1 = i1
14511487

1452-
with suppress(AttributeError):
1453-
if d0 is d1 and d0.is_Stepping and function._size_domain[d0] == 1:
1454-
return True
1488+
with suppress(AttributeError):
1489+
if d0 is d1 and d0.is_Stepping and function._size_domain[d0] == 1:
1490+
return True
14551491

14561492
return False

tests/test_dimension.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from devito.ir.iet import (
1818
Conditional, Expression, FindNodes, FindSymbols, Iteration, retrieve_iteration_tree
1919
)
20+
from devito.ir.support.space import Backward, Forward
2021
from devito.symbolics import INT, IntDiv, indexify, retrieve_functions
2122
from devito.types import Array, StencilDimension, Symbol
2223
from devito.types.basic import Scalar
@@ -235,6 +236,26 @@ def test_degenerate_to_zero(self):
235236

236237
assert np.all(u.data == 10)
237238

239+
@pytest.mark.parametrize('direction', ['fwd', 'bwd'])
240+
def test_buffer1_direction(self, direction):
241+
grid = Grid(shape=(10, 10))
242+
243+
u = TimeFunction(name='u', grid=grid, save=Buffer(1))
244+
245+
# Equations technically have no implied time direction as u.forward and u refer
246+
# to the same buffer slot. However, user usage of u.forward and u.backward should
247+
# be picked up by the compiler
248+
if direction == 'fwd':
249+
op = Operator(Eq(u.forward, u + 1))
250+
else:
251+
op = Operator(Eq(u.backward, u + 1))
252+
253+
# Check for time loop direction
254+
trees = retrieve_iteration_tree(op)
255+
direction = Forward if direction == 'fwd' else Backward
256+
for tree in trees:
257+
assert tree[0].direction == direction
258+
238259

239260
class TestSubDimension:
240261

tests/test_mpi.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from devito.ir.iet import (
1818
Call, Conditional, FindNodes, FindSymbols, Iteration, retrieve_iteration_tree
1919
)
20+
from devito.ir.support.space import Backward, Forward
2021
from devito.mpi import MPI
2122
from devito.mpi.distributed import CustomTopology
2223
from devito.mpi.routines import ComputeCall, HaloUpdateCall, HaloUpdateList, MPICall
@@ -1916,8 +1917,8 @@ def test_haloupdate_buffer1(self, mode):
19161917
(2, True, 'Eq(v3.forward, v2.forward.laplace + 1)', 3, 2, ('v1', 'v2')),
19171918
(1, False, 'rec.interpolate(v2)', 3, 2, ('v1', 'v2')),
19181919
(1, False, 'Eq(v3.backward, v2.laplace + 1)', 3, 2, ('v1', 'v2')),
1919-
(1, False, 'Eq(v3.backward, v2.backward.laplace + 1)', 3, 3, ('v2', 'v1', 'v2')),
1920-
(2, False, 'Eq(v3.backward, v2.backward.laplace + 1)', 3, 3, ('v2', 'v1', 'v2')),
1920+
(1, False, 'Eq(v3.backward, v2.backward.laplace + 1)', 3, 2, ('v1', 'v2')),
1921+
(2, False, 'Eq(v3.backward, v2.backward.laplace + 1)', 3, 2, ('v1', 'v2')),
19211922
])
19221923
def test_haloupdate_buffer_cases(self, sz, fwd, expr, exp0, exp1, args, mode):
19231924
grid = Grid((65, 65, 65), topology=('*', 1, '*'))
@@ -1943,6 +1944,12 @@ def test_haloupdate_buffer_cases(self, sz, fwd, expr, exp0, exp1, args, mode):
19431944
op = Operator(eqns)
19441945
_ = op.cfunction
19451946

1947+
# Check for time loop direction
1948+
trees = retrieve_iteration_tree(op)
1949+
direction = Forward if fwd else Backward
1950+
for tree in trees:
1951+
assert tree[0].direction == direction
1952+
19461953
calls, _ = check_halo_exchanges(op, exp0, exp1)
19471954
for i, v in enumerate(args):
19481955
assert calls[i].arguments[0] is eval(v)

0 commit comments

Comments
 (0)