Skip to content

Commit 32fa62c

Browse files
committed
compiler: Fix HaloTouch hashing
1 parent a1ebd69 commit 32fa62c

2 files changed

Lines changed: 45 additions & 8 deletions

File tree

devito/mpi/halo_scheme.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -674,7 +674,7 @@ class HaloTouch(sympy.Function, Reconstructable):
674674
A SymPy object representing halo accesses through a HaloScheme.
675675
"""
676676

677-
__rargs__ = ('args',)
677+
__rargs__ = ('*args',)
678678
__rkwargs__ = ('halo_scheme',)
679679

680680
def __new__(cls, *args, halo_scheme=None, **kwargs):
@@ -691,10 +691,12 @@ def _sympystr(self, printer):
691691
return str(self)
692692

693693
def __hash__(self):
694-
return hash(self.halo_scheme)
694+
return hash((self.halo_scheme, *super()._hashable_content()))
695695

696696
def __eq__(self, other):
697-
return isinstance(other, HaloTouch) and self.halo_scheme == other.halo_scheme
697+
return (isinstance(other, HaloTouch) and
698+
self.args == other.args and
699+
self.halo_scheme == other.halo_scheme)
698700

699701
func = Reconstructable._rebuild
700702

tests/test_symbolics.py

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,13 @@
1010
Min, Max, Real, Imag, Conj, SubDomain, configuration)
1111
from devito.finite_differences.differentiable import SafeInv, Weights, Mul
1212
from devito.ir import Expression, FindNodes, ccode
13-
from devito.symbolics import (retrieve_functions, retrieve_indexed, evalrel, # noqa
14-
CallFromPointer, Cast, DefFunction, FieldFromPointer,
15-
INT, FieldFromComposite, IntDiv, Namespace, Rvalue,
16-
ReservedWord, ListInitializer, uxreplace, pow_to_mul,
17-
retrieve_derivatives, BaseCast, SizeOf, VectorAccess)
13+
from devito.mpi.halo_scheme import HaloTouch
14+
from devito.symbolics import (
15+
retrieve_functions, retrieve_indexed, evalrel, CallFromPointer, Cast, # noqa
16+
DefFunction, FieldFromPointer, INT, FieldFromComposite, IntDiv, Namespace,
17+
Rvalue, ReservedWord, ListInitializer, uxreplace, pow_to_mul,
18+
retrieve_derivatives, BaseCast, SizeOf, VectorAccess
19+
)
1820
from devito.tools import as_tuple, CustomDtype
1921
from devito.types import (Array, Bundle, FIndexed, LocalObject, Object,
2022
ComponentAccess, StencilDimension, Symbol as dSymbol)
@@ -520,6 +522,27 @@ def test_vector_access():
520522
assert ccode(v1) == 'VL<g[x, y, z]>'
521523

522524

525+
def test_halo_touch():
526+
grid = Grid(shape=(3, 3))
527+
x, y = grid.dimensions
528+
529+
f = Function(name='f', grid=grid)
530+
g = Function(name='g', grid=grid)
531+
532+
# Hashing and equality
533+
ht0 = HaloTouch(f[x, y], g[x, y])
534+
ht1 = HaloTouch(f[x, y], g[x, y])
535+
ht2 = HaloTouch(f[x, y], g[x + 1, y + 1])
536+
assert hash(ht0) == hash(ht1)
537+
assert ht0 == ht1
538+
assert ht0 != ht2
539+
assert hash(ht0) != hash(ht2)
540+
541+
# Reconstruction
542+
assert ht0 == ht0._rebuild()
543+
assert hash(ht0) == hash(ht0._rebuild())
544+
545+
523546
def test_canonical_ordering_of_weights():
524547
grid = Grid(shape=(3, 3, 3))
525548
x, y, z = grid.dimensions
@@ -670,6 +693,18 @@ def test_reduce_to_number(self):
670693
assert not w_sub.is_Mul
671694
assert w_sub.is_Number
672695

696+
def test_halo_touch(self):
697+
grid = Grid(shape=(3, 3))
698+
x, y = grid.dimensions
699+
700+
f = Function(name='f', grid=grid)
701+
g = Function(name='g', grid=grid)
702+
703+
ht0 = HaloTouch(f[x, y])
704+
ht1 = uxreplace(ht0, {f.indexed: g.indexed})
705+
706+
assert ht1.args == (g[x, y],)
707+
673708

674709
def test_minmax():
675710
grid = Grid(shape=(5, 5))

0 commit comments

Comments
 (0)