Skip to content

Commit 2506149

Browse files
authored
Merge pull request #2645 from devitocodes/patch-gpuvect-mpi-again
compiler: Patch merge and reduction HaloScheme passes
2 parents bca0b10 + fabd033 commit 2506149

3 files changed

Lines changed: 83 additions & 15 deletions

File tree

devito/mpi/halo_scheme.py

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,17 +47,18 @@ def __hash__(self):
4747
def loc_values(self):
4848
return frozenset(self.loc_indices.values())
4949

50-
def union(self, other):
50+
def merge(self, other):
5151
"""
52-
Return a new HaloSchemeEntry that is the union of this and `other`.
53-
The `loc_indices` and `loc_dirs` must be the same, otherwise an
54-
exception is raised.
52+
Return a new HaloSchemeEntry that is the result of merging `other`
53+
into `self`, irrespective of whether the `loc_indices` and `loc_dirs`
54+
are the same or not. Thus, the returned HaloSchemeEntry will have
55+
in general more halo exchanges than `self`, or exactly the same halo
56+
exchanges in the worst case).
5557
"""
56-
if self.loc_indices != other.loc_indices or \
57-
self.loc_dirs != other.loc_dirs or \
58+
if self.loc_dirs != other.loc_dirs or \
5859
self.bundle is not other.bundle:
5960
raise HaloSchemeException(
60-
"Inconsistency found while building a HaloScheme"
61+
"Inconsistency found while merging HaloSchemeEntries"
6162
)
6263

6364
halos = self.halos | other.halos
@@ -66,6 +67,22 @@ def union(self, other):
6667
return HaloSchemeEntry(self.loc_indices, self.loc_dirs, halos, dims,
6768
bundle=self.bundle, getters=self.getters)
6869

70+
def union(self, other):
71+
"""
72+
Return a new HaloSchemeEntry that is the union of this and `other`.
73+
The `loc_indices` and `loc_dirs` must be the same, otherwise an
74+
exception is raised. This is a more restrictive version of `merge`,
75+
which is used when we want to ensure that the halo exchanges are
76+
performed at the same time index, i.e., the same `loc_indices` are
77+
are expected.
78+
"""
79+
if self.loc_indices != other.loc_indices:
80+
raise HaloSchemeException(
81+
"Inconsistency found while taking the union of HaloSchemeEntries"
82+
)
83+
84+
return self.merge(other)
85+
6986

7087
Halo = namedtuple('Halo', 'dim side')
7188

@@ -481,6 +498,15 @@ def add(self, f, hse):
481498
fmapper[f] = hse
482499
return HaloScheme.build(fmapper, self.honored)
483500

501+
def merge(self, hs):
502+
"""
503+
Create a new HaloScheme that is the result of merging `hs` into `self`.
504+
"""
505+
fmapper = dict(self.fmapper)
506+
for f, hse in hs.fmapper.items():
507+
fmapper[f] = fmapper.get(f, hse).merge(hse)
508+
return HaloScheme.build(fmapper, self.honored)
509+
484510

485511
def classify(exprs, ispace):
486512
"""

devito/passes/iet/mpi.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
FindWithin, MapNodes, MapHaloSpots, Transformer,
88
retrieve_iteration_tree)
99
from devito.ir.support import PARALLEL, Scope
10-
from devito.mpi.halo_scheme import HaloScheme
1110
from devito.mpi.reduction_scheme import DistReduce
1211
from devito.mpi.routines import HaloExchangeBuilder, ReductionBuilder
1312
from devito.passes.iet.engine import iet_pass
@@ -42,9 +41,13 @@ def _drop_reduction_halospots(iet):
4241
# If all HaloSpot reads pertain to reductions, then the HaloSpot is useless
4342
for hs, expressions in MapNodes(HaloSpot, Expression).visit(iet).items():
4443
scope = Scope(i.expr for i in expressions)
45-
for f, v in scope.reads.items():
46-
if f in hs.fmapper and all(i.is_reduction for i in v):
47-
mapper[hs].add(f)
44+
for k, v in hs.fmapper.items():
45+
f = v.bundle or k
46+
if f not in scope.reads:
47+
continue
48+
v = scope.reads[f]
49+
if all(i.is_reduction for i in v):
50+
mapper[hs].add(k)
4851

4952
# Transform the IET introducing the "reduced" HaloSpots
5053
mapper = {hs: hs._rebuild(halo_scheme=hs.halo_scheme.drop(mapper[hs]))
@@ -154,7 +157,7 @@ def _merge_halospots(iet):
154157

155158
# If the `loc_indices` differ, we rely on hoisting to optimize
156159
# `hsf1` out of `it`, otherwise we just drop it
157-
if hsf0.loc_values != hsf1.loc_values:
160+
if not _semantical_eq_loc_indices(hsf0, hsf1):
158161
continue
159162

160163
mapper.drop(hs1, f)
@@ -416,7 +419,7 @@ def add(self, node, hss):
416419
"""
417420
v = self.get(node)
418421
if isinstance(v, HaloSpot):
419-
hss = HaloScheme.union([v.halo_scheme, hss])
422+
hss = v.halo_scheme.merge(hss)
420423
hs = v._rebuild(halo_scheme=hss)
421424
else:
422425
hs = HaloSpot(v._rebuild(), hss)
@@ -515,3 +518,21 @@ def _is_mergeable(hsf0, hsf1, scope):
515518

516519
# Finally, check the data dependences would be satisfied
517520
return _is_iter_carried(hsf1, scope)
521+
522+
523+
def _semantical_eq_loc_indices(hsf0, hsf1):
524+
if hsf0.loc_indices != hsf1.loc_indices:
525+
return False
526+
527+
for v0, v1 in zip(hsf0.loc_values, hsf1.loc_values):
528+
if v0 is v1:
529+
continue
530+
531+
# Special case: they might be syntactically different, but semantically
532+
# equivalent, e.g., `t0` and `t1` with same modulus
533+
if v0.modulo == v1.modulo == 1:
534+
continue
535+
536+
return False
537+
538+
return True

tests/test_mpi.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1580,7 +1580,7 @@ def test_merge_and_hoist_haloupdate_if_diff_locindices(self, mode):
15801580
assert np.allclose(h.data_ro_domain[0, 5:], [4.4, 4.4, 4.4, 3.4, 3.1], rtol=R)
15811581

15821582
@pytest.mark.parallel(mode=1)
1583-
def test_merge_and_hoist_haloupdate_if_diff_locindices_v2(self, mode):
1583+
def test_merge_haloupdate_if_diff_but_equivalent_locindices(self, mode):
15841584
grid = Grid(shape=(65, 65, 65))
15851585

15861586
v1 = TimeFunction(name='v1', grid=grid, space_order=2, time_order=1,
@@ -1597,7 +1597,7 @@ def test_merge_and_hoist_haloupdate_if_diff_locindices_v2(self, mode):
15971597
op = Operator(eqns)
15981598
op.cfunction
15991599

1600-
calls, _ = check_halo_exchanges(op, 3, 2)
1600+
calls, _ = check_halo_exchanges(op, 2, 2)
16011601
for i, v in enumerate([v2, v1]):
16021602
assert calls[i].arguments[0] is v
16031603

@@ -2962,6 +2962,27 @@ def init(f, v=1):
29622962
assert np.isclose(norm(u1), 12445251.87, rtol=1e-7)
29632963
assert np.isclose(norm(v1), 147063.38, rtol=1e-7)
29642964

2965+
@pytest.mark.parallel(mode=1)
2966+
def test_interpolation_at_uforward(self, mode):
2967+
grid = Grid(shape=(10, 10, 10))
2968+
t = grid.stepping_dim
2969+
2970+
u = TimeFunction(name='u', grid=grid, space_order=2, time_order=2)
2971+
2972+
rec = SparseTimeFunction(name="rec", grid=grid, npoint=1, nt=10)
2973+
2974+
eqns = [Eq(u.forward, u.laplace + u.backward + 1),
2975+
rec.interpolate(expr=u.forward)]
2976+
2977+
op = Operator(eqns)
2978+
2979+
op.cfunction
2980+
2981+
calls, _ = check_halo_exchanges(op, 2, 1)
2982+
args = calls[0].arguments
2983+
assert args[-2].name == 't2'
2984+
assert args[-2].origin == t + 1
2985+
29652986

29662987
def gen_serial_norms(shape, so):
29672988
"""

0 commit comments

Comments
 (0)