Skip to content

Commit 2e52e88

Browse files
committed
FIX FIX FIX (now generate much better)
1 parent ab67838 commit 2e52e88

2 files changed

Lines changed: 44 additions & 12 deletions

File tree

devito/mpi/routines.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
IndexedPointer, Macro, cast, subs_op_args)
1919
from devito.tools import (as_mapper, dtype_to_mpitype, dtype_len, infer_datasize,
2020
flatten, generator, is_integer)
21-
from devito.types import (Array, Bag, Dimension, Eq, Symbol, LocalObject,
22-
CompositeObject, CustomDimension)
21+
from devito.types import (Array, Bag, BundleView, Dimension, Eq, Symbol,
22+
LocalObject, CompositeObject, CustomDimension)
2323

2424
__all__ = ['HaloExchangeBuilder', 'ReductionBuilder', 'mpi_registry']
2525

@@ -295,8 +295,15 @@ def _make_bundles(self, hs):
295295
halo_scheme = halo_scheme.drop(components)
296296

297297
# Existing Bundles are preserved
298-
if hse.bundle and set(components) == set(hse.bundle.components):
299-
halo_scheme = halo_scheme.add(hse.bundle, hse)
298+
if hse.bundle:
299+
if set(components) == set(hse.bundle.components):
300+
halo_scheme = halo_scheme.add(hse.bundle, hse)
301+
else:
302+
name = f'bundleview_{hse.bundle.name}'
303+
bundle_view = BundleView(
304+
name=name, components=components, parent=hse.bundle
305+
)
306+
halo_scheme = halo_scheme.add(bundle_view, hse)
300307
continue
301308

302309
# We recast everything else as Bags for simplicity -- worst case
@@ -367,13 +374,13 @@ def _make_copy(self, f, hse, key, swap=False):
367374
name = 'scatter%s' % key
368375

369376
if isinstance(f, Bag):
370-
for i, c0 in enumerate(f.components):
371-
if hse.bundle is not None:
372-
indices = [hse.bundle.components.index(c0), *findices]
373-
rhs = hse.bundle[indices]
374-
else:
375-
rhs = c0[findices]
376-
eqns.append(Eq(*swap(buf[[i] + bdims], rhs)))
377+
for i, c in enumerate(f.components):
378+
eqns.append(Eq(*swap(buf[[i] + bdims], c[findices])))
379+
elif isinstance(f, BundleView):
380+
assert f.parent is hse.bundle
381+
for i, c in enumerate(f.components):
382+
indices = [f.parent.components.index(c), *findices]
383+
eqns.append(Eq(*swap(buf[[i] + bdims], f.parent[indices])))
377384
else:
378385
assert f.is_Bundle
379386
for i in range(f.ncomp):

devito/types/array.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from devito.types.utils import CtypesFactory, DimensionTuple
1111

1212
__all__ = ['Array', 'ArrayMapped', 'ArrayObject', 'PointerArray', 'Bundle',
13-
'ComponentAccess', 'Bag']
13+
'ComponentAccess', 'Bag', 'BundleView']
1414

1515

1616
class ArrayBasic(AbstractFunction, LocalType):
@@ -518,6 +518,31 @@ def handles(self):
518518
return self.components
519519

520520

521+
class BundleView(Bundle):
522+
523+
"""
524+
A BundleView is like a Bundle but it doesn't represent a concrete object
525+
in the generated code. It's used by the compiler to represent a subset
526+
of the components of a Bundle.
527+
"""
528+
529+
__rkwargs__ = Bundle.__rkwargs__ + ('parent',)
530+
531+
def __new__(cls, *args, parent=None, **kwargs):
532+
obj = super().__new__(cls, *args, **kwargs)
533+
obj._parent = parent
534+
535+
return obj
536+
537+
@property
538+
def parent(self):
539+
return self._parent
540+
541+
@property
542+
def handles(self):
543+
return (self.parent,)
544+
545+
521546
class ComponentAccess(Expr, Pickable):
522547

523548
_component_names = ('x', 'y', 'z', 'w')

0 commit comments

Comments
 (0)