@@ -1887,10 +1887,12 @@ def test_enforce_haloupdate_if_unwritten_function(self, mode):
18871887 @pytest .mark .parallel (mode = 1 )
18881888 def test_haloupdate_buffer1 (self , mode ):
18891889 grid = Grid (shape = (4 , 4 ))
1890- x , y = grid .dimensions
18911890
1892- u = TimeFunction (name = 'u' , grid = grid , time_order = 1 , save = Buffer (1 ))
1893- v = TimeFunction (name = 'v' , grid = grid , time_order = 1 , save = Buffer (1 ))
1891+ # With order 2 (1) forward derivatives, the loops can (and will) be fused,
1892+ # removing the need for a halo update so space_order=4 is used to ensure
1893+ # derivatives are centred resulting in parallel loops and halo updates
1894+ u = TimeFunction (name = 'u' , grid = grid , time_order = 1 , space_order = 4 , save = Buffer (1 ))
1895+ v = TimeFunction (name = 'v' , grid = grid , time_order = 1 , space_order = 4 , save = Buffer (1 ))
18941896
18951897 eqns = [Eq (u .forward , div (v ) + 1. ),
18961898 Eq (v .forward , div (u .forward ) + 1. )]
@@ -1909,22 +1911,22 @@ def test_haloupdate_buffer1(self, mode):
19091911 @pytest .mark .parallel (mode = 1 )
19101912 @pytest .mark .parametrize ('sz,fwd,expr,exp0,exp1,args' , [
19111913 (1 , True , 'rec.interpolate(v2)' , 3 , 2 , ('v1' , 'v2' )),
1912- (1 , True , 'Eq(v3.forward, v2.laplace + 1)' , 1 , 1 , ('v2' , )),
1913- (1 , True , 'Eq(v3.forward, v2.forward.laplace + 1)' , 3 , 2 , ('v1' , 'v2' , )),
1914- (2 , True , 'Eq(v3.forward, v2.forward.laplace + 1)' , 3 , 2 , ('v1' , 'v2' , )),
1914+ (1 , True , 'Eq(v3.forward, v2.laplace + 1)' , 3 , 2 , ('v1' , 'v2' )),
1915+ (1 , True , 'Eq(v3.forward, v2.forward.laplace + 1)' , 3 , 2 , ('v1' , 'v2' )),
1916+ (2 , True , 'Eq(v3.forward, v2.forward.laplace + 1)' , 3 , 2 , ('v1' , 'v2' )),
19151917 (1 , False , 'rec.interpolate(v2)' , 3 , 2 , ('v1' , 'v2' )),
1916- (1 , False , 'Eq(v3.backward, v2.laplace + 1)' , 1 , 1 , ('v2' , )),
1917- (1 , False , 'Eq(v3.backward, v2.backward.laplace + 1)' , 3 , 2 , ('v1' , 'v2' , )),
1918- (2 , False , 'Eq(v3.backward, v2.backward.laplace + 1)' , 3 , 2 , ('v1' , 'v2' , )),
1918+ (1 , False , 'Eq(v3.backward, v2.laplace + 1)' , 3 , 2 , ('v1' , 'v2' )),
1919+ (1 , False , 'Eq(v3.backward, v2.backward.laplace + 1)' , 3 , 3 , ('v2' , ' v1' , 'v2' )),
1920+ (2 , False , 'Eq(v3.backward, v2.backward.laplace + 1)' , 3 , 3 , ('v2' , ' v1' , 'v2' )),
19191921 ])
19201922 def test_haloupdate_buffer_cases (self , sz , fwd , expr , exp0 , exp1 , args , mode ):
19211923 grid = Grid ((65 , 65 , 65 ), topology = ('*' , 1 , '*' ))
19221924
1923- v1 = TimeFunction (name = 'v1' , grid = grid , space_order = 2 , time_order = 1 ,
1925+ v1 = TimeFunction (name = 'v1' , grid = grid , space_order = 4 , time_order = 1 ,
19241926 save = Buffer (1 ))
1925- v2 = TimeFunction (name = 'v2' , grid = grid , space_order = 2 , time_order = 1 ,
1927+ v2 = TimeFunction (name = 'v2' , grid = grid , space_order = 4 , time_order = 1 ,
19261928 save = Buffer (1 ))
1927- v3 = TimeFunction (name = 'v3' , grid = grid , space_order = 2 , time_order = 1 , # noqa
1929+ v3 = TimeFunction (name = 'v3' , grid = grid , space_order = 4 , time_order = 1 , # noqa
19281930 save = Buffer (1 ))
19291931
19301932 rec = SparseTimeFunction (name = 'rec' , grid = grid , nt = 500 , npoint = 65 ) # noqa
0 commit comments