|
1 | 1 | import numpy as np |
2 | 2 | import pytest |
3 | | -from sympy import simplify, diff, Float |
| 3 | +from sympy import sympify, simplify, diff, Float, Symbol |
4 | 4 |
|
5 | 5 | from devito import (Grid, Function, TimeFunction, Eq, Operator, NODE, cos, sin, |
6 | 6 | ConditionalDimension, left, right, centered, div, grad) |
@@ -1023,3 +1023,124 @@ def bypass_uneval(expr): |
1023 | 1023 | unevals = expr.find(EvalDerivative) |
1024 | 1024 | mapper = {i: Add(*i.args) for i in unevals} |
1025 | 1025 | return expr.xreplace(mapper) |
| 1026 | + |
| 1027 | + |
| 1028 | +class TestExpansion: |
| 1029 | + @classmethod |
| 1030 | + def setup_class(cls): |
| 1031 | + cls.grid = Grid(shape=(11,), extent=(1,)) |
| 1032 | + cls.x = cls.grid.dimensions[0] |
| 1033 | + cls.u = Function(name='u', grid=cls.grid, space_order=4) |
| 1034 | + |
| 1035 | + a = cls.u.dx |
| 1036 | + cls.b = a.subs({cls.u: -5*cls.u.dx + 4*cls.u + 3}) |
| 1037 | + |
| 1038 | + def test_reconstructible(self): |
| 1039 | + ''' Check that devito.Derivatives are reconstructible from func and args |
| 1040 | + (as per sympy docs) |
| 1041 | + ''' |
| 1042 | + du = self.u.dx |
| 1043 | + assert du.func(*du.args) == du |
| 1044 | + assert du.func(*du.args).args == (self.u, (self.x, 1)) |
| 1045 | + |
| 1046 | + def test_deriv_order(self): |
| 1047 | + ''' Check default simplification causes the same result |
| 1048 | + ''' |
| 1049 | + du11 = Derivative(self.u, self.x, self.x) |
| 1050 | + du2 = Derivative(self.u, (self.x, 2)) |
| 1051 | + assert du11 == du2 |
| 1052 | + assert du11.deriv_order == du2.deriv_order |
| 1053 | + |
| 1054 | + @pytest.mark.xfail(raises=ValueError) |
| 1055 | + def test_wrong_deriv_order(self): |
| 1056 | + ''' Check an exception is raises with incompatible arguments |
| 1057 | + ''' |
| 1058 | + _ = Derivative(self.u, self.x, deriv_order=(2, 4)) |
| 1059 | + |
| 1060 | + @pytest.mark.xfail(raises=ValueError) |
| 1061 | + def test_no_derivative(self): |
| 1062 | + _ = Derivative(sympify(-1)) |
| 1063 | + |
| 1064 | + @pytest.mark.xfail(raises=ValueError) |
| 1065 | + def test_no_dimension(self): |
| 1066 | + _ = Derivative(sympify(-1), deriv_order=0) |
| 1067 | + |
| 1068 | + def test_constant(self): |
| 1069 | + ''' Check constant derivative is zero for non-0th order derivatives |
| 1070 | + ''' |
| 1071 | + assert Derivative(sympify(-1), (self.x, 1)) == 0 |
| 1072 | + assert Derivative(sympify(-1), (self.x, 2)) == 0 |
| 1073 | + assert Derivative(sympify(-1), (self.x, 0)) == -1 |
| 1074 | + |
| 1075 | + def test_dims_validation(self): |
| 1076 | + ''' Validate `dims` kwarg |
| 1077 | + ''' |
| 1078 | + grid = Grid(shape=(11, 11, 11), extent=(1, 1, 1)) |
| 1079 | + x, y, z = grid.dimensions |
| 1080 | + u = Function(name='u', grid=grid, space_order=4) |
| 1081 | + |
| 1082 | + d = Derivative(u, x) |
| 1083 | + assert d.dims == (x, ) |
| 1084 | + assert d.deriv_order == (1, ) |
| 1085 | + |
| 1086 | + d = Derivative(u, x, y) |
| 1087 | + assert d.dims == (x, y) |
| 1088 | + assert d.deriv_order == (1, 1) |
| 1089 | + |
| 1090 | + d = Derivative(u, (x, 2)) |
| 1091 | + assert d.dims == (x, ) |
| 1092 | + assert d.deriv_order == (2, ) |
| 1093 | + |
| 1094 | + d = Derivative(u, (x, 2), (y, 2)) |
| 1095 | + assert d.dims == (x, y) |
| 1096 | + assert d.deriv_order == (2, 2) |
| 1097 | + |
| 1098 | + d = Derivative(u, (x, 2), y, x, (z, 3)) |
| 1099 | + assert d.dims == (x, y, z) |
| 1100 | + assert d.deriv_order == (3, 1, 3) |
| 1101 | + |
| 1102 | + def test_dims_exceptions(self): |
| 1103 | + ''' Check invalid dimensions and orders raise exceptions |
| 1104 | + ''' |
| 1105 | + grid = Grid(shape=(11, 11, 11), extent=(1, 1, 1)) |
| 1106 | + x, y, z = grid.dimensions |
| 1107 | + u = Function(name='u', grid=grid, space_order=4) |
| 1108 | + |
| 1109 | + # Don't allow negative derivatives |
| 1110 | + with pytest.raises(TypeError): |
| 1111 | + _ = Derivative(u, (x, -1)) |
| 1112 | + |
| 1113 | + # Don't allow fractional derivatives |
| 1114 | + with pytest.raises(TypeError): |
| 1115 | + _ = Derivative(u, (x, 0.5)) |
| 1116 | + |
| 1117 | + # Don't allow common mistake |
| 1118 | + # NB: Derivative(u, x, y) is probably what was intended |
| 1119 | + with pytest.raises(TypeError): |
| 1120 | + _ = Derivative(u, (x, y)) |
| 1121 | + |
| 1122 | + # Don't allow derivative order to be symbolic |
| 1123 | + a = Symbol('a', integer=True) |
| 1124 | + with pytest.raises(TypeError): |
| 1125 | + _ = Derivative(u, (x, a)) |
| 1126 | + |
| 1127 | + def test_expand_mul(self): |
| 1128 | + ''' Check independent terms can be extracted from the derivative. |
| 1129 | + The multiply expansion is the only hint executed by default when |
| 1130 | + `.expand()` is called. |
| 1131 | + ''' |
| 1132 | + expanded = Derivative(4*self.u - 5*Derivative(self.u, self.x) + 3, self.x) |
| 1133 | + assert self.b.expand() == expanded |
| 1134 | + |
| 1135 | + def test_expand_add(self): |
| 1136 | + ''' Check linearity |
| 1137 | + ''' |
| 1138 | + expanded = 4*Derivative(self.u, self.x) |
| 1139 | + expanded -= 5*Derivative(Derivative(self.u, self.x), self.x) |
| 1140 | + assert self.b.expand(add=True) == expanded |
| 1141 | + |
| 1142 | + def test_expand_nest(self): |
| 1143 | + ''' Check valid nested derivative expands (combining x derivatives) |
| 1144 | + ''' |
| 1145 | + expanded = 4*Derivative(self.u, self.x) - 5*Derivative(self.u, (self.x, 2)) |
| 1146 | + assert self.b.expand(add=True, nest=True) == expanded |
0 commit comments