@@ -780,6 +780,44 @@ def test_param_stagg_inner(self):
780780 eqne = eqn .evaluate .rhs
781781 assert simplify (eqne - (p ._subs (y , yp ).evaluate * f ).dx (x0 = xp ).evaluate ) == 0
782782
783+ def test_param_stagg_add (self ):
784+ space_order = 2
785+ nx , ny = 5 , 5
786+ extent = (nx - 1 , ny - 1 )
787+
788+ grid = Grid (shape = (nx , ny ), extent = extent )
789+ x , y = grid .dimensions
790+ yp = y + y .spacing / 2
791+ xp = x + x .spacing / 2
792+
793+ x , y = grid .dimensions
794+
795+ vx = TimeFunction (name = "vx" , grid = grid , space_order = space_order ,
796+ time_order1 = 1 , staggered = x )
797+ txx = TimeFunction (name = "txx" , grid = grid , space_order = space_order ,
798+ time_order = 1 , staggered = NODE )
799+ txy = TimeFunction (name = "txy" , grid = grid , space_order = space_order ,
800+ time_order = 1 , staggered = (x , y ))
801+ c11 = Function (name = "c11" , grid = grid , space_order = space_order , parameter = True )
802+ c66 = Function (name = "c66" , grid = grid , space_order = space_order , parameter = True )
803+
804+ eq0 = Eq (vx , (c66 * txy ).dy )
805+ eq1 = Eq (vx , (c11 * txx ).dy )
806+ eq2 = Eq (vx , (c11 * txx + c66 * txy ).dy )
807+
808+ # C66 is a paramater. Expects to evaluate c66 at xp then the derivative at yp
809+ # and the derivative will interpolate txy at xp
810+ expect0 = (c66 .subs ({x : xp , y : yp }).evaluate * txy ).dy .evaluate
811+ assert simplify (eq0 .evaluate .rhs - expect0 ) == 0
812+
813+ # C11 is a paramater and txy is staggered in x.
814+ # Expects to evaluate c11 and txy xp then the derivative at yp
815+ expect1 = (c11 ._subs (x , xp ).evaluate * txx ._subs (x , xp ).evaluate ).dy .evaluate
816+ assert simplify (eq1 .evaluate .rhs - expect1 ) == 0
817+
818+ # Addition should apply the same logic as above for each term
819+ assert simplify (eq2 .evaluate .rhs - (expect1 + expect0 )) == 0
820+
783821
784822class TestTwoStageEvaluation :
785823
0 commit comments