Skip to content

Commit f3ed91f

Browse files
committed
dsl: Handle nested derivatives that have specified fd_order
1 parent afa51b2 commit f3ed91f

2 files changed

Lines changed: 81 additions & 10 deletions

File tree

devito/finite_differences/derivative.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -149,16 +149,31 @@ def __new__(cls, expr, *dims, **kwargs):
149149
else:
150150
dcounter[d] += o
151151

152-
# Default finite difference orders depending on input dimension (.dt or .dx)
153152
# It's possible that the expr is a `sympy.Number` at this point, which
154153
# has derivative 0, unless we're taking a 0th derivative.
155154
if isinstance(expr, sympy.Number):
156155
if any(dcounter.values()):
157156
return 0
158157
else:
159158
return expr
159+
160+
# Use `fd_order` if specified
161+
fd_order = kwargs.get('fd_order')
162+
if fd_order is not None:
163+
_fd_order_specified = True
164+
# If `fd_order` is specified collect these together
165+
fcounter = defaultdict(int)
166+
for d, o in zip(dims, as_tuple(fd_order)):
167+
fcounter[d] += o
168+
for d, o in dcounter.items():
169+
order = expr.time_order if getattr(d, 'is_Time', False) else expr.space_order
170+
if o > order:
171+
raise ValueError(f'Function does not support {d} derivative of order {o}')
172+
fd_order = fcounter.values()
160173
else:
161-
default_fdo = tuple([
174+
_fd_order_specified = False
175+
# Default finite difference orders depending on input dimension (.dt or .dx)
176+
fd_order = tuple([
162177
expr.time_order
163178
if getattr(d, 'is_Time', False)
164179
else expr.space_order
@@ -174,8 +189,9 @@ def __new__(cls, expr, *dims, **kwargs):
174189
obj = Differentiable.__new__(cls, expr, *derivatives)
175190
obj._dims = tuple(dcounter.keys())
176191

192+
obj._fd_order_specified = _fd_order_specified
177193
obj._fd_order = DimensionTuple(
178-
*as_tuple(kwargs.get('fd_order', default_fdo)),
194+
*as_tuple(fd_order),
179195
getters=obj._dims
180196
)
181197
obj._deriv_order = DimensionTuple(
@@ -533,7 +549,11 @@ def _eval_expand_nest(self, **hints):
533549
# all kwargs from the self object
534550
# TODO: This dictionary merge needs to be a lot better
535551
# EG: Don't actually expand if derivatives are incompatible
536-
new_kwargs = {'deriv_order': tuple(chain(self.deriv_order, expr.deriv_order))}
552+
new_kwargs = {
553+
'deriv_order': tuple(chain(self.deriv_order, expr.deriv_order))
554+
}
555+
if self._fd_order_specified or expr._fd_order_specified:
556+
new_kwargs.update({'fd_order': tuple(chain(self.fd_order, expr.fd_order))})
537557
return self.func(new_expr, *new_dims, **new_kwargs)
538558
else:
539559
return self
@@ -564,3 +584,16 @@ def _eval_expand_add(self, **hints):
564584
return self.func(dep, *self.args[1:])
565585
else:
566586
return self
587+
588+
def _eval_expand_product_rule(self, **hints):
589+
''' Expands products, of functions of the dependent variable
590+
`Derivative(f(x)·g(x), x)
591+
--> Derivative(f(x), x)·g(x) + f(x)·Derivative(g(x), x)`
592+
This is only implemented for first derivatives with an arbitrary number
593+
of multiplicands and second derivatives with two multiplicands. The
594+
resultant expression for higher derivatives and mixed derivatives is much
595+
more difficult to implement.
596+
'''
597+
# expr = self.args[0]
598+
# if isinstance(expr, sympy.Mul):
599+
raise NotImplementedError('Product rule expansion has not been written')

tests/test_derivatives.py

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1051,19 +1051,19 @@ def test_deriv_order(self):
10511051
assert du11 == du2
10521052
assert du11.deriv_order == du2.deriv_order
10531053

1054-
@pytest.mark.xfail(raises=ValueError)
10551054
def test_wrong_deriv_order(self):
10561055
''' Check an exception is raises with incompatible arguments
10571056
'''
1058-
_ = Derivative(self.u, self.x, deriv_order=(2, 4))
1057+
with pytest.raises(ValueError):
1058+
_ = Derivative(self.u, self.x, deriv_order=(2, 4))
10591059

1060-
@pytest.mark.xfail(raises=ValueError)
10611060
def test_no_derivative(self):
1062-
_ = Derivative(sympify(-1))
1061+
with pytest.raises(ValueError):
1062+
_ = Derivative(sympify(-1))
10631063

1064-
@pytest.mark.xfail(raises=ValueError)
10651064
def test_no_dimension(self):
1066-
_ = Derivative(sympify(-1), deriv_order=0)
1065+
with pytest.raises(ValueError):
1066+
_ = Derivative(sympify(-1), deriv_order=0)
10671067

10681068
def test_constant(self):
10691069
''' Check constant derivative is zero for non-0th order derivatives
@@ -1144,3 +1144,41 @@ def test_expand_nest(self):
11441144
'''
11451145
expanded = 4*Derivative(self.u, self.x) - 5*Derivative(self.u, (self.x, 2))
11461146
assert self.b.expand(add=True, nest=True) == expanded
1147+
1148+
def test_nested_orders(self):
1149+
''' Check nested expansion results in correct derivative and fd order
1150+
'''
1151+
# Default fd_order
1152+
du22 = Derivative(Derivative(self.u, (self.x, 2)), (self.x, 2))
1153+
du22_expanded = du22.expand(nest=True)
1154+
du4 = Derivative(self.u, (self.x, 4))
1155+
assert du22_expanded == du4
1156+
assert du22_expanded.deriv_order == du4.deriv_order
1157+
assert du22_expanded.fd_order == du4.fd_order
1158+
1159+
# Specified fd_order
1160+
du22 = Derivative(
1161+
Derivative(self.u, (self.x, 2), fd_order=2),
1162+
(self.x, 2),
1163+
fd_order=2
1164+
)
1165+
du22_expanded = du22.expand(nest=True)
1166+
du4 = Derivative(self.u, (self.x, 4), fd_order=4)
1167+
assert du22_expanded == du4
1168+
assert du22_expanded.deriv_order == du4.deriv_order
1169+
assert du22_expanded.fd_order == du4.fd_order
1170+
1171+
# Specified fd_order greater than the space order
1172+
## When no order specified
1173+
du44 = Derivative(Derivative(self.u, (self.x, 4)), (self.x, 4))
1174+
with pytest.raises(ValueError):
1175+
du44_expanded = du44.expand(nest=True)
1176+
1177+
## When order specified is too large
1178+
du44 = Derivative(
1179+
Derivative(self.u, (self.x, 4), fd_order=4),
1180+
(self.x, 4),
1181+
fd_order=4
1182+
)
1183+
with pytest.raises(ValueError):
1184+
du44_expanded = du44.expand(nest=True)

0 commit comments

Comments
 (0)