Skip to content

Commit 48e0c0a

Browse files
committed
tests: Fix up some halo tests with Buffer
1 parent 2b47ebd commit 48e0c0a

1 file changed

Lines changed: 14 additions & 12 deletions

File tree

tests/test_mpi.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1887,10 +1887,12 @@ def test_enforce_haloupdate_if_unwritten_function(self, mode):
18871887
@pytest.mark.parallel(mode=1)
18881888
def test_haloupdate_buffer1(self, mode):
18891889
grid = Grid(shape=(4, 4))
1890-
x, y = grid.dimensions
18911890

1892-
u = TimeFunction(name='u', grid=grid, time_order=1, save=Buffer(1))
1893-
v = TimeFunction(name='v', grid=grid, time_order=1, save=Buffer(1))
1891+
# With order 2 (1) forward derivatives, the loops can (and will) be fused,
1892+
# removing the need for a halo update so space_order=4 is used to ensure
1893+
# derivatives are centred resulting in parallel loops and halo updates
1894+
u = TimeFunction(name='u', grid=grid, time_order=1, space_order=4, save=Buffer(1))
1895+
v = TimeFunction(name='v', grid=grid, time_order=1, space_order=4, save=Buffer(1))
18941896

18951897
eqns = [Eq(u.forward, div(v) + 1.),
18961898
Eq(v.forward, div(u.forward) + 1.)]
@@ -1909,22 +1911,22 @@ def test_haloupdate_buffer1(self, mode):
19091911
@pytest.mark.parallel(mode=1)
19101912
@pytest.mark.parametrize('sz,fwd,expr,exp0,exp1,args', [
19111913
(1, True, 'rec.interpolate(v2)', 3, 2, ('v1', 'v2')),
1912-
(1, True, 'Eq(v3.forward, v2.laplace + 1)', 1, 1, ('v2',)),
1913-
(1, True, 'Eq(v3.forward, v2.forward.laplace + 1)', 3, 2, ('v1', 'v2',)),
1914-
(2, True, 'Eq(v3.forward, v2.forward.laplace + 1)', 3, 2, ('v1', 'v2',)),
1914+
(1, True, 'Eq(v3.forward, v2.laplace + 1)', 3, 2, ('v1', 'v2')),
1915+
(1, True, 'Eq(v3.forward, v2.forward.laplace + 1)', 3, 2, ('v1', 'v2')),
1916+
(2, True, 'Eq(v3.forward, v2.forward.laplace + 1)', 3, 2, ('v1', 'v2')),
19151917
(1, False, 'rec.interpolate(v2)', 3, 2, ('v1', 'v2')),
1916-
(1, False, 'Eq(v3.backward, v2.laplace + 1)', 1, 1, ('v2',)),
1917-
(1, False, 'Eq(v3.backward, v2.backward.laplace + 1)', 3, 2, ('v1', 'v2',)),
1918-
(2, False, 'Eq(v3.backward, v2.backward.laplace + 1)', 3, 2, ('v1', 'v2',)),
1918+
(1, False, 'Eq(v3.backward, v2.laplace + 1)', 3, 2, ('v1', 'v2')),
1919+
(1, False, 'Eq(v3.backward, v2.backward.laplace + 1)', 3, 3, ('v2', 'v1', 'v2')),
1920+
(2, False, 'Eq(v3.backward, v2.backward.laplace + 1)', 3, 3, ('v2', 'v1', 'v2')),
19191921
])
19201922
def test_haloupdate_buffer_cases(self, sz, fwd, expr, exp0, exp1, args, mode):
19211923
grid = Grid((65, 65, 65), topology=('*', 1, '*'))
19221924

1923-
v1 = TimeFunction(name='v1', grid=grid, space_order=2, time_order=1,
1925+
v1 = TimeFunction(name='v1', grid=grid, space_order=4, time_order=1,
19241926
save=Buffer(1))
1925-
v2 = TimeFunction(name='v2', grid=grid, space_order=2, time_order=1,
1927+
v2 = TimeFunction(name='v2', grid=grid, space_order=4, time_order=1,
19261928
save=Buffer(1))
1927-
v3 = TimeFunction(name='v3', grid=grid, space_order=2, time_order=1, # noqa
1929+
v3 = TimeFunction(name='v3', grid=grid, space_order=4, time_order=1, # noqa
19281930
save=Buffer(1))
19291931

19301932
rec = SparseTimeFunction(name='rec', grid=grid, nt=500, npoint=65) # noqa

0 commit comments

Comments
 (0)