|
4 | 4 |
|
5 | 5 | from devito import (Grid, Function, TimeFunction, Eq, Operator, NODE, cos, sin, |
6 | 6 | ConditionalDimension, left, right, centered, div, grad) |
7 | | -from devito.finite_differences import Derivative, Differentiable |
| 7 | +from devito.finite_differences import Derivative, Differentiable, diffify |
8 | 8 | from devito.finite_differences.differentiable import (Add, EvalDerivative, IndexSum, |
9 | 9 | IndexDerivative, Weights, |
10 | 10 | DiffDerivative) |
@@ -1030,7 +1030,12 @@ class TestExpansion: |
1030 | 1030 | def setup_class(cls): |
1031 | 1031 | cls.grid = Grid(shape=(11,), extent=(1,)) |
1032 | 1032 | cls.x = cls.grid.dimensions[0] |
1033 | | - cls.u = Function(name='u', grid=cls.grid, space_order=4) |
| 1033 | + cls.u = TimeFunction(name='u', grid=cls.grid, space_order=4) |
| 1034 | + cls.t = cls.u.time_dim |
| 1035 | + |
| 1036 | + cls.v = Function(name='v', grid=cls.grid, space_order=4) |
| 1037 | + cls.w = Function(name='w', grid=cls.grid, space_order=4) |
| 1038 | + cls.f = Function(name='f', dimensions=(cls.t,), shape=(5,)) |
1034 | 1039 |
|
1035 | 1040 | # Note that using the `.dx` shortcut method specifies the fd_order kwarg |
1036 | 1041 | a = cls.u.dx |
@@ -1140,6 +1145,10 @@ def test_expand_mul(self): |
1140 | 1145 | expanded = Derivative(4*self.u - 5*Derivative(self.u, self.x) + 3, self.x) |
1141 | 1146 | assert self.b.expand() == expanded |
1142 | 1147 |
|
| 1148 | + expr = self.u.dx.subs({self.u: 5*self.f*self.u*self.v*self.w}) |
| 1149 | + expanded = 5*self.f*Derivative(self.u*self.v*self.w, self.x) |
| 1150 | + assert diffify(expr.expand()) == expanded |
| 1151 | + |
1143 | 1152 | def test_expand_add(self): |
1144 | 1153 | """ |
1145 | 1154 | Check linearity |
@@ -1194,3 +1203,19 @@ def test_nested_orders(self): |
1194 | 1203 | ) |
1195 | 1204 | with pytest.raises(ValueError): |
1196 | 1205 | _ = du44.expand(nest=True) |
| 1206 | + |
| 1207 | + def test_expand_product_rule(self): |
| 1208 | + """ |
| 1209 | + Check the two implemented cases for product rule expansion |
| 1210 | + """ |
| 1211 | + expr = self.u.dx.subs({self.u: 5*self.f*self.u*self.v*self.w}, postprocess=False) |
| 1212 | + expanded = 5*self.f*self.v*self.w*Derivative(self.u, self.x) \ |
| 1213 | + + 5*self.f*self.v*self.u*Derivative(self.w, self.x) \ |
| 1214 | + + 5*self.f*self.w*self.u*Derivative(self.v, self.x) |
| 1215 | + assert diffify(expr.expand(product_rule=True)) == expanded |
| 1216 | + |
| 1217 | + expr = self.u.dx2.subs({self.u: 5*self.f*self.u*self.v}, postprocess=False) |
| 1218 | + expanded = 5*self.f*self.v*Derivative(self.u, (self.x, 2)) \ |
| 1219 | + + 10*self.f*Derivative(self.v, self.x)*Derivative(self.u, self.x) \ |
| 1220 | + + 5*self.f*self.u*Derivative(self.v, (self.x, 2)) |
| 1221 | + assert diffify(expr.expand(product_rule=True)) == expanded |
0 commit comments