Skip to content

Commit ca75df1

Browse files
committed
tests: Add some tests for new functionality
1 parent b56e928 commit ca75df1

2 files changed

Lines changed: 159 additions & 16 deletions

File tree

devito/finite_differences/derivative.py

Lines changed: 37 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,6 @@ def __new__(cls, expr, *dims, **kwargs):
105105
expr = diffify(expr)
106106
except Exception as e:
107107
raise ValueError("`expr` must be a Differentiable object") from e
108-
if isinstance(expr, sympy.Number):
109-
return 0
110108

111109
# Validate `dims`. It can be:
112110
# - a single Dimension ie: x
@@ -131,22 +129,39 @@ def __new__(cls, expr, *dims, **kwargs):
131129
raise ValueError(
132130
'Length of `deriv_order` does not match the length of dimensions'
133131
)
132+
if any([not is_integer(d) or d < 0 for d in deriv_order]):
133+
raise TypeError(
134+
'Invalid type in `deriv_order`, all elements must be non-negative Python `int`s'
135+
)
134136

135137
# Count the number of derivatives for each dimension
136138
dcounter = defaultdict(int)
137139
for d, o in zip(dims, deriv_order):
138140
if isinstance(d, Iterable):
139-
dcounter[d[0]] += d[1]
141+
if not is_integer(d[1]) or d[1] < 0:
142+
raise TypeError(
143+
'Invalid type for derivative order, it must be non-negative Python `int`'
144+
)
145+
else:
146+
dcounter[d[0]] += d[1]
140147
else:
141148
dcounter[d] += o
142149

143150
# Default finite difference orders depending on input dimension (.dt or .dx)
144-
default_fdo = tuple([
145-
expr.time_order
146-
if getattr(d, 'is_Time', False)
147-
else expr.space_order
148-
for d in dcounter.keys()
149-
])
151+
# It's possible that the expr is a `sympy.Number` at this point, which
152+
# has derivative 0, unless we're taking a 0th derivative.
153+
if isinstance(expr, sympy.Number):
154+
if any(dcounter.values()):
155+
return 0
156+
else:
157+
return expr
158+
else:
159+
default_fdo = tuple([
160+
expr.time_order
161+
if getattr(d, 'is_Time', False)
162+
else expr.space_order
163+
for d in dcounter.keys()
164+
])
150165

151166
# SymPy expects the list of variable w.r.t. which we differentiate to be a list
152167
# of 2-tuple `(s, count)` where s is the entity to diff wrt and count is the order
@@ -253,12 +268,7 @@ def __call__(self, x0=None, fd_order=None, side=None, method=None, **kwargs):
253268
return self._rebuild(*self.args, **rkw)
254269

255270
def _rebuild(self, *args, **kwargs):
256-
if not args:
257-
kwargs['preprocessed'] = True
258-
expr = super()._rebuild(**kwargs)
259-
else:
260-
expr = super()._rebuild(*args, **kwargs)
261-
return expr
271+
return super()._rebuild(*args, **kwargs)
262272

263273
func = _rebuild
264274

@@ -495,6 +505,12 @@ def _eval_fd(self, expr, **kwargs):
495505
return res
496506

497507
def _eval_expand_nest(self, **hints):
508+
''' Expands nested derivatives
509+
`Derivative(Derivative(f(x), (x, b)), (x, a)) --> Derivative(f(x), (x, a+b))`
510+
`Derivative(Derivative(f(x), (y, b)), (x, a)) --> Derivative(f(x), (x, a), (y, b))`
511+
Note that this is not always a valid expansion depending on the kwargs
512+
used to construct the derivative.
513+
'''
498514
expr = self.args[0]
499515
if isinstance(expr, self.__class__):
500516
new_expr = expr.args[0]
@@ -514,6 +530,9 @@ def _eval_expand_nest(self, **hints):
514530
return self
515531

516532
def _eval_expand_mul(self, **hints):
533+
''' Expands products, moving independent terms outside the derivative
534+
`Derivative(C·f(x)·g(c, y), x) --> C·g(y)·Derivative(f(x), x)`
535+
'''
517536
expr = self.args[0]
518537
if isinstance(expr, sympy.Mul):
519538
ind, dep = expr.as_independent(*self.dims, as_Mul=True)
@@ -522,6 +541,9 @@ def _eval_expand_mul(self, **hints):
522541
return self
523542

524543
def _eval_expand_add(self, **hints):
544+
''' Expands sums, using linearity of derivative
545+
`Derivative(f(x) + g(x), x) --> Derivative(f(x), x) + Derivative(g(x), x)`
546+
'''
525547
expr = self.args[0]
526548
if isinstance(expr, sympy.Add):
527549
ind, dep = expr.as_independent(*self.dims, as_Add=True)

tests/test_derivatives.py

Lines changed: 122 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numpy as np
22
import pytest
3-
from sympy import simplify, diff, Float
3+
from sympy import sympify, simplify, diff, Float, Symbol
44

55
from devito import (Grid, Function, TimeFunction, Eq, Operator, NODE, cos, sin,
66
ConditionalDimension, left, right, centered, div, grad)
@@ -1023,3 +1023,124 @@ def bypass_uneval(expr):
10231023
unevals = expr.find(EvalDerivative)
10241024
mapper = {i: Add(*i.args) for i in unevals}
10251025
return expr.xreplace(mapper)
1026+
1027+
1028+
class TestExpansion:
1029+
@classmethod
1030+
def setup_class(cls):
1031+
cls.grid = Grid(shape=(11,), extent=(1,))
1032+
cls.x = cls.grid.dimensions[0]
1033+
cls.u = Function(name='u', grid=cls.grid, space_order=4)
1034+
1035+
a = cls.u.dx
1036+
cls.b = a.subs({cls.u: -5*cls.u.dx + 4*cls.u + 3})
1037+
1038+
def test_reconstructible(self):
1039+
''' Check that devito.Derivatives are reconstructible from func and args
1040+
(as per sympy docs)
1041+
'''
1042+
du = self.u.dx
1043+
assert du.func(*du.args) == du
1044+
assert du.func(*du.args).args == (self.u, (self.x, 1))
1045+
1046+
def test_deriv_order(self):
1047+
''' Check default simplification causes the same result
1048+
'''
1049+
du11 = Derivative(self.u, self.x, self.x)
1050+
du2 = Derivative(self.u, (self.x, 2))
1051+
assert du11 == du2
1052+
assert du11.deriv_order == du2.deriv_order
1053+
1054+
@pytest.mark.xfail(raises=ValueError)
1055+
def test_wrong_deriv_order(self):
1056+
''' Check an exception is raises with incompatible arguments
1057+
'''
1058+
_ = Derivative(self.u, self.x, deriv_order=(2, 4))
1059+
1060+
@pytest.mark.xfail(raises=ValueError)
1061+
def test_no_derivative(self):
1062+
_ = Derivative(sympify(-1))
1063+
1064+
@pytest.mark.xfail(raises=ValueError)
1065+
def test_no_dimension(self):
1066+
_ = Derivative(sympify(-1), deriv_order=0)
1067+
1068+
def test_constant(self):
1069+
''' Check constant derivative is zero for non-0th order derivatives
1070+
'''
1071+
assert Derivative(sympify(-1), (self.x, 1)) == 0
1072+
assert Derivative(sympify(-1), (self.x, 2)) == 0
1073+
assert Derivative(sympify(-1), (self.x, 0)) == -1
1074+
1075+
def test_dims_validation(self):
1076+
''' Validate `dims` kwarg
1077+
'''
1078+
grid = Grid(shape=(11, 11, 11), extent=(1, 1, 1))
1079+
x, y, z = grid.dimensions
1080+
u = Function(name='u', grid=grid, space_order=4)
1081+
1082+
d = Derivative(u, x)
1083+
assert d.dims == (x, )
1084+
assert d.deriv_order == (1, )
1085+
1086+
d = Derivative(u, x, y)
1087+
assert d.dims == (x, y)
1088+
assert d.deriv_order == (1, 1)
1089+
1090+
d = Derivative(u, (x, 2))
1091+
assert d.dims == (x, )
1092+
assert d.deriv_order == (2, )
1093+
1094+
d = Derivative(u, (x, 2), (y, 2))
1095+
assert d.dims == (x, y)
1096+
assert d.deriv_order == (2, 2)
1097+
1098+
d = Derivative(u, (x, 2), y, x, (z, 3))
1099+
assert d.dims == (x, y, z)
1100+
assert d.deriv_order == (3, 1, 3)
1101+
1102+
def test_dims_exceptions(self):
1103+
''' Check invalid dimensions and orders raise exceptions
1104+
'''
1105+
grid = Grid(shape=(11, 11, 11), extent=(1, 1, 1))
1106+
x, y, z = grid.dimensions
1107+
u = Function(name='u', grid=grid, space_order=4)
1108+
1109+
# Don't allow negative derivatives
1110+
with pytest.raises(TypeError):
1111+
_ = Derivative(u, (x, -1))
1112+
1113+
# Don't allow fractional derivatives
1114+
with pytest.raises(TypeError):
1115+
_ = Derivative(u, (x, 0.5))
1116+
1117+
# Don't allow common mistake
1118+
# NB: Derivative(u, x, y) is probably what was intended
1119+
with pytest.raises(TypeError):
1120+
_ = Derivative(u, (x, y))
1121+
1122+
# Don't allow derivative order to be symbolic
1123+
a = Symbol('a', integer=True)
1124+
with pytest.raises(TypeError):
1125+
_ = Derivative(u, (x, a))
1126+
1127+
def test_expand_mul(self):
1128+
''' Check independent terms can be extracted from the derivative.
1129+
The multiply expansion is the only hint executed by default when
1130+
`.expand()` is called.
1131+
'''
1132+
expanded = Derivative(4*self.u - 5*Derivative(self.u, self.x) + 3, self.x)
1133+
assert self.b.expand() == expanded
1134+
1135+
def test_expand_add(self):
1136+
''' Check linearity
1137+
'''
1138+
expanded = 4*Derivative(self.u, self.x)
1139+
expanded -= 5*Derivative(Derivative(self.u, self.x), self.x)
1140+
assert self.b.expand(add=True) == expanded
1141+
1142+
def test_expand_nest(self):
1143+
''' Check valid nested derivative expands (combining x derivatives)
1144+
'''
1145+
expanded = 4*Derivative(self.u, self.x) - 5*Derivative(self.u, (self.x, 2))
1146+
assert self.b.expand(add=True, nest=True) == expanded

0 commit comments

Comments
 (0)