|
10 | 10 | Min, Max, Real, Imag, Conj, SubDomain, configuration) |
11 | 11 | from devito.finite_differences.differentiable import SafeInv, Weights, Mul |
12 | 12 | from devito.ir import Expression, FindNodes, ccode |
13 | | -from devito.symbolics import (retrieve_functions, retrieve_indexed, evalrel, # noqa |
14 | | - CallFromPointer, Cast, DefFunction, FieldFromPointer, |
15 | | - INT, FieldFromComposite, IntDiv, Namespace, Rvalue, |
16 | | - ReservedWord, ListInitializer, uxreplace, pow_to_mul, |
17 | | - retrieve_derivatives, BaseCast, SizeOf, VectorAccess) |
| 13 | +from devito.mpi.halo_scheme import HaloTouch |
| 14 | +from devito.symbolics import ( |
| 15 | + retrieve_functions, retrieve_indexed, evalrel, CallFromPointer, Cast, # noqa |
| 16 | + DefFunction, FieldFromPointer, INT, FieldFromComposite, IntDiv, Namespace, |
| 17 | + Rvalue, ReservedWord, ListInitializer, uxreplace, pow_to_mul, |
| 18 | + retrieve_derivatives, BaseCast, SizeOf, VectorAccess |
| 19 | +) |
18 | 20 | from devito.tools import as_tuple, CustomDtype |
19 | 21 | from devito.types import (Array, Bundle, FIndexed, LocalObject, Object, |
20 | 22 | ComponentAccess, StencilDimension, Symbol as dSymbol) |
@@ -520,6 +522,27 @@ def test_vector_access(): |
520 | 522 | assert ccode(v1) == 'VL<g[x, y, z]>' |
521 | 523 |
|
522 | 524 |
|
| 525 | +def test_halo_touch(): |
| 526 | + grid = Grid(shape=(3, 3)) |
| 527 | + x, y = grid.dimensions |
| 528 | + |
| 529 | + f = Function(name='f', grid=grid) |
| 530 | + g = Function(name='g', grid=grid) |
| 531 | + |
| 532 | + # Hashing and equality |
| 533 | + ht0 = HaloTouch(f[x, y], g[x, y]) |
| 534 | + ht1 = HaloTouch(f[x, y], g[x, y]) |
| 535 | + ht2 = HaloTouch(f[x, y], g[x + 1, y + 1]) |
| 536 | + assert hash(ht0) == hash(ht1) |
| 537 | + assert ht0 == ht1 |
| 538 | + assert ht0 != ht2 |
| 539 | + assert hash(ht0) != hash(ht2) |
| 540 | + |
| 541 | + # Reconstruction |
| 542 | + assert ht0 == ht0._rebuild() |
| 543 | + assert hash(ht0) == hash(ht0._rebuild()) |
| 544 | + |
| 545 | + |
523 | 546 | def test_canonical_ordering_of_weights(): |
524 | 547 | grid = Grid(shape=(3, 3, 3)) |
525 | 548 | x, y, z = grid.dimensions |
@@ -670,6 +693,18 @@ def test_reduce_to_number(self): |
670 | 693 | assert not w_sub.is_Mul |
671 | 694 | assert w_sub.is_Number |
672 | 695 |
|
| 696 | + def test_halo_touch(self): |
| 697 | + grid = Grid(shape=(3, 3)) |
| 698 | + x, y = grid.dimensions |
| 699 | + |
| 700 | + f = Function(name='f', grid=grid) |
| 701 | + g = Function(name='g', grid=grid) |
| 702 | + |
| 703 | + ht0 = HaloTouch(f[x, y]) |
| 704 | + ht1 = uxreplace(ht0, {f.indexed: g.indexed}) |
| 705 | + |
| 706 | + assert ht1.args == (g[x, y],) |
| 707 | + |
673 | 708 |
|
674 | 709 | def test_minmax(): |
675 | 710 | grid = Grid(shape=(5, 5)) |
|
0 commit comments