|
18 | 18 | IndexedPointer, Macro, cast, subs_op_args) |
19 | 19 | from devito.tools import (as_mapper, dtype_to_mpitype, dtype_len, infer_datasize, |
20 | 20 | flatten, generator, is_integer) |
21 | | -from devito.types import (Array, Bag, Dimension, Eq, Symbol, LocalObject, |
22 | | - CompositeObject, CustomDimension) |
| 21 | +from devito.types import (Array, Bag, BundleView, Dimension, Eq, Symbol, |
| 22 | + LocalObject, CompositeObject, CustomDimension) |
23 | 23 |
|
24 | 24 | __all__ = ['HaloExchangeBuilder', 'ReductionBuilder', 'mpi_registry'] |
25 | 25 |
|
@@ -295,8 +295,15 @@ def _make_bundles(self, hs): |
295 | 295 | halo_scheme = halo_scheme.drop(components) |
296 | 296 |
|
297 | 297 | # Existing Bundles are preserved |
298 | | - if hse.bundle and set(components) == set(hse.bundle.components): |
299 | | - halo_scheme = halo_scheme.add(hse.bundle, hse) |
| 298 | + if hse.bundle: |
| 299 | + if set(components) == set(hse.bundle.components): |
| 300 | + halo_scheme = halo_scheme.add(hse.bundle, hse) |
| 301 | + else: |
| 302 | + name = f'bundleview_{hse.bundle.name}' |
| 303 | + bundle_view = BundleView( |
| 304 | + name=name, components=components, parent=hse.bundle |
| 305 | + ) |
| 306 | + halo_scheme = halo_scheme.add(bundle_view, hse) |
300 | 307 | continue |
301 | 308 |
|
302 | 309 | # We recast everything else as Bags for simplicity -- worst case |
@@ -367,13 +374,13 @@ def _make_copy(self, f, hse, key, swap=False): |
367 | 374 | name = 'scatter%s' % key |
368 | 375 |
|
369 | 376 | if isinstance(f, Bag): |
370 | | - for i, c0 in enumerate(f.components): |
371 | | - if hse.bundle is not None: |
372 | | - indices = [hse.bundle.components.index(c0), *findices] |
373 | | - rhs = hse.bundle[indices] |
374 | | - else: |
375 | | - rhs = c0[findices] |
376 | | - eqns.append(Eq(*swap(buf[[i] + bdims], rhs))) |
| 377 | + for i, c in enumerate(f.components): |
| 378 | + eqns.append(Eq(*swap(buf[[i] + bdims], c[findices]))) |
| 379 | + elif isinstance(f, BundleView): |
| 380 | + assert f.parent is hse.bundle |
| 381 | + for i, c in enumerate(f.components): |
| 382 | + indices = [f.parent.components.index(c), *findices] |
| 383 | + eqns.append(Eq(*swap(buf[[i] + bdims], f.parent[indices]))) |
377 | 384 | else: |
378 | 385 | assert f.is_Bundle |
379 | 386 | for i in range(f.ncomp): |
|
0 commit comments