Skip to content

Commit 5b59cec

Browse files
committed
mpi: Support halo exchanges from individual Bundle components
1 parent e21ccc5 commit 5b59cec

3 files changed

Lines changed: 45 additions & 23 deletions

File tree

devito/mpi/halo_scheme.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,14 @@ class HaloLabel(Tag):
3030
class HaloSchemeEntry(EnrichedTuple):
3131

3232
__rargs__ = ('loc_indices', 'loc_dirs', 'halos', 'dims')
33+
__rkwargs__ = ('bundle',)
3334

34-
def __new__(cls, loc_indices, loc_dirs, halos, dims, getters=None):
35+
def __new__(cls, loc_indices, loc_dirs, halos, dims, bundle=None, getters=None):
36+
getters = cls.__rargs__ + cls.__rkwargs__
3537
items = [frozendict(loc_indices), frozendict(loc_dirs),
36-
frozenset(halos), frozenset(dims)]
37-
kwargs = dict(zip(cls.__rargs__, items))
38-
return super().__new__(cls, *items, getters=cls.__rargs__, **kwargs)
38+
frozenset(halos), frozenset(dims), bundle]
39+
kwargs = dict(zip(getters, items))
40+
return super().__new__(cls, *items, getters=getters, **kwargs)
3941

4042
def __hash__(self):
4143
return hash((self.loc_indices, self.loc_dirs, self.halos, self.dims))
@@ -47,7 +49,8 @@ def union(self, other):
4749
exception is raised.
4850
"""
4951
if self.loc_indices != other.loc_indices or \
50-
self.loc_dirs != other.loc_dirs:
52+
self.loc_dirs != other.loc_dirs or \
53+
self.bundle is not other.bundle:
5154
raise HaloSchemeException(
5255
"Inconsistency found while building a HaloScheme"
5356
)
@@ -56,7 +59,7 @@ def union(self, other):
5659
dims = self.dims | other.dims
5760

5861
return HaloSchemeEntry(self.loc_indices, self.loc_dirs, halos, dims,
59-
getters=self.getters)
62+
bundle=self.bundle, getters=self.getters)
6063

6164

6265
Halo = namedtuple('Halo', 'dim side')
@@ -168,7 +171,7 @@ def union(self, halo_schemes):
168171
elif not v.loc_indices or hse.loc_indices == v.loc_indices:
169172
loc_indices, loc_dirs = hse.loc_indices, hse.loc_dirs
170173
else:
171-
# The `loc_dirs` must match otherwise it'd be a symptom there's
174+
# These must match otherwise it'd be a symptom there's
172175
# something horribly broken elsewhere!
173176
assert hse.loc_dirs == v.loc_dirs
174177
assert list(hse.loc_indices) == list(v.loc_indices)
@@ -185,7 +188,11 @@ def union(self, halo_schemes):
185188
halos = hse.halos | v.halos
186189
dims = hse.dims | v.dims
187190

188-
fmapper[k] = HaloSchemeEntry(loc_indices, loc_dirs, halos, dims)
191+
assert hse.bundle is v.bundle
192+
193+
fmapper[k] = HaloSchemeEntry(
194+
loc_indices, loc_dirs, halos, dims, bundle=hse.bundle
195+
)
189196

190197
# Compute the `honored` union
191198
for d, v in i.honored.items():
@@ -641,8 +648,12 @@ def _uxreplace_dispatch_haloscheme(hs0, rule):
641648
for i, v in rule.items():
642649
if i is f:
643650
# Yes!
644-
g = v
645-
hse = hse0
651+
if v.is_Bundle:
652+
g = f
653+
hse = hse0._rebuild(bundle=v)
654+
else:
655+
g = v
656+
hse = hse0
646657

647658
elif i.is_Indexed and i.function is f and v.is_Indexed:
648659
# Yes, but through an Indexed, hence the `loc_indices` may now

devito/mpi/routines.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from devito.symbolics import (Byref, CondNe, FieldFromPointer, FieldFromComposite,
1818
IndexedPointer, Macro, cast, subs_op_args)
1919
from devito.tools import (as_mapper, dtype_to_mpitype, dtype_len, infer_datasize,
20-
flatten, generator, is_integer, split)
20+
flatten, generator, is_integer)
2121
from devito.types import (Array, Bag, Dimension, Eq, Symbol, LocalObject,
2222
CompositeObject, CustomDimension)
2323

@@ -292,19 +292,21 @@ def _make_bundles(self, hs):
292292

293293
mapper = as_mapper(halo_scheme.fmapper, lambda i: halo_scheme.fmapper[i])
294294
for hse, components in mapper.items():
295-
# We recast everything as Bags for simplicity -- worst case scenario
296-
# all Bags only have one component. Existing Bundles are preserved
297295
halo_scheme = halo_scheme.drop(components)
298-
bundles, candidates = split(tuple(components), lambda i: i.is_Bundle)
299-
for b in bundles:
300-
halo_scheme = halo_scheme.add(b, hse)
301296

297+
# Existing Bundles are preserved
298+
if hse.bundle and set(components) == set(hse.bundle.components):
299+
halo_scheme = halo_scheme.add(hse.bundle, hse)
300+
continue
301+
302+
# We recast everything else as Bags for simplicity -- worst case
303+
# scenario all Bags only have one component.
302304
try:
303-
name = "bag_%s" % "".join(f.name for f in candidates)
304-
bag = Bag(name=name, components=candidates)
305+
name = "bag_%s" % "".join(f.name for f in components)
306+
bag = Bag(name=name, components=components)
305307
halo_scheme = halo_scheme.add(bag, hse)
306308
except ValueError:
307-
for i in candidates:
309+
for i in components:
308310
name = "bag_%s" % i.name
309311
bag = Bag(name=name, components=i)
310312
halo_scheme = halo_scheme.add(bag, hse)
@@ -363,10 +365,19 @@ def _make_copy(self, f, hse, key, swap=False):
363365
else:
364366
swap = lambda i, j: (j, i)
365367
name = 'scatter%s' % key
368+
366369
if isinstance(f, Bag):
367-
for i, c in enumerate(f.components):
368-
eqns.append(Eq(*swap(buf[[i] + bdims], c[findices])))
370+
if hse.bundle is not None:
371+
# `f` is the only component of `hse.bundle` that is
372+
# being communicated
373+
assert f.ncomp == 1
374+
i = hse.bundle.components.index(f.c0)
375+
eqns.append(Eq(*swap(buf[[0] + bdims], hse.bundle[[i] + findices])))
376+
else:
377+
for i, c in enumerate(f.components):
378+
eqns.append(Eq(*swap(buf[[i] + bdims], c[findices])))
369379
else:
380+
assert f.is_Bundle
370381
for i in range(f.ncomp):
371382
eqns.append(Eq(*swap(buf[[i] + bdims], f[[i] + findices])))
372383

devito/passes/iet/mpi.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,8 +198,8 @@ def _drop_if_unwritten(iet, options=None, **kwargs):
198198
writes = {i.write for i in FindNodes(Expression).visit(iet)}
199199
mapper = {}
200200
for hs in FindNodes(HaloSpot).visit(iet):
201-
for f in hs.fmapper:
202-
if f not in writes and key(f):
201+
for f, v in hs.fmapper.items():
202+
if not writes.intersection({f, v.bundle}) and key(f):
203203
mapper[hs] = mapper.get(hs, hs.halo_scheme).drop(f)
204204

205205
# Post-process analysis

0 commit comments

Comments
 (0)