Skip to content

Commit 8fd3824

Browse files
authored
Merge pull request #2714 from devitocodes/JDBetteridge/derivative_hotfix
dsl: JDBetteridge/derivative hotfix
2 parents 0d90d06 + b537581 commit 8fd3824

2 files changed

Lines changed: 63 additions & 8 deletions

File tree

devito/finite_differences/derivative.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,16 @@
55

66
import sympy
77

8-
from .finite_difference import generic_derivative, cross_derivative
9-
from .differentiable import Differentiable, diffify, interp_for_fd, Add, Mul
10-
from .tools import direct, transpose
11-
from .rsfd import d45
12-
from devito.tools import (as_mapper, as_tuple, frozendict, is_integer,
13-
Pickable)
8+
from devito.tools import Pickable, as_mapper, as_tuple, frozendict, is_integer
9+
from devito.types.dimension import Dimension
1410
from devito.types.utils import DimensionTuple
1511
from devito.warnings import warn
1612

13+
from .differentiable import Add, Differentiable, Mul, diffify, interp_for_fd
14+
from .finite_difference import cross_derivative, generic_derivative
15+
from .rsfd import d45
16+
from .tools import direct, transpose
17+
1718
__all__ = ['Derivative']
1819

1920

@@ -101,14 +102,30 @@ def __new__(cls, expr, *dims, **kwargs):
101102
# Count the derivatives w.r.t. each variable
102103
dcounter = cls._count_derivatives(deriv_order, dims)
103104

104-
# It's possible that the expr is a `sympy.Number` at this point, which
105+
# It is possible that the expr is a `sympy.Number` at this point, which
105106
# has derivative 0, unless we're taking a 0th derivative.
106107
if isinstance(expr, sympy.Number):
107108
if any(dcounter.values()):
108109
return 0
109110
else:
110111
return expr
111112

113+
# It is also possible that the expression itself is just a
114+
# `devito.Dimension` type which is:
115+
# - derivative 1 if the Dimension coincides and the number of derivatives
116+
# is 1 ie: `Derivative(x, (x, 1)) == 1`.
117+
# - derivative 0 if the Dimension coincides and the total number of
118+
# derivatives is greater than 1 ie: `Derivative(x, (x, 2)) == 0` and
119+
# `Derivative(x, x, y) == 0`.
120+
# - An unevaluated expression otherwise.
121+
if isinstance(expr, Dimension) and expr in dcounter:
122+
if dcounter[expr] == 0:
123+
pass
124+
elif dcounter.pop(expr) == 1 and not dcounter:
125+
return 1
126+
else:
127+
return 0
128+
112129
# Validate the finite difference order `fd_order`
113130
fd_order = cls._validate_fd_order(kwargs.get('fd_order'), expr, dims, dcounter)
114131

@@ -232,7 +249,11 @@ def _validate_fd_order(fd_order, expr, dims, dcounter):
232249
Required: `expr`, `dims`, and the derivative counter to validate.
233250
If not provided, the maximum supported order will be used.
234251
"""
235-
if fd_order is not None:
252+
if isinstance(expr, Dimension):
253+
# If the expression is just a dimension `expr.time_order` and
254+
# `expr.space_order` are not defined
255+
fd_order = (99,)*len(dcounter)
256+
elif fd_order is not None:
236257
# If `fd_order` is specified, then validate
237258
fcounter = defaultdict(int)
238259
# First create a dictionary mapping variable wrt which to differentiate

tests/test_derivatives.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1250,3 +1250,37 @@ def test_fallback_wrong_custom_size(self):
12501250
exp1 = - .5 * (v._subs(x, x - x.spacing) - v._subs(x, x + x.spacing))/x.spacing
12511251
assert simplify(eq0.rhs - exp0) == 0
12521252
assert simplify(eq1.rhs - exp1) == 0
1253+
1254+
1255+
class TestDimension:
1256+
"""
1257+
Check the few cases where differentiating a dimension is allowed work correctly
1258+
and errors are raised otherwise.
1259+
"""
1260+
1261+
@classmethod
1262+
def setup_class(cls):
1263+
cls.grid = Grid(shape=(11, 11), extent=(1, 1))
1264+
cls.x, cls.y = cls.grid.dimensions
1265+
u = TimeFunction(name='u', grid=cls.grid, space_order=1)
1266+
cls.t = u.time_dim
1267+
1268+
def test_constant(self):
1269+
assert Derivative(self.x, self.x) == 1
1270+
1271+
def test_null(self):
1272+
assert Derivative(self.x, (self.x, 2)) == 0
1273+
assert Derivative(self.x, self.x, self.y) == 0
1274+
assert Derivative(self.x, self.x, self.y) == 0
1275+
assert Derivative(self.x, self.x, self.y, self.t) == 0
1276+
assert Derivative(self.x, self.y, self.t, self.x) == 0
1277+
assert Derivative(self.x, self.y, self.t, (self.x, 2)) == 0
1278+
1279+
def test_unevaluated(self):
1280+
"""
1281+
The following should all be instantiatible without raising an
1282+
exception, but should not simplify.
1283+
"""
1284+
assert Derivative(self.x, self.t)
1285+
assert Derivative(self.x, self.y, self.t)
1286+
assert Derivative(self.x, (self.x, 0))

0 commit comments

Comments
 (0)