Skip to content

Commit 74de496

Browse files
authored
Merge pull request #2672 from devitocodes/xderiv-fix
api: fix custom coeffs with cross derivatives
2 parents aed4fa6 + 3ad8470 commit 74de496

5 files changed

Lines changed: 58 additions & 19 deletions

File tree

devito/finite_differences/derivative.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ def _process_weights(cls, **kwargs):
213213
weights = kwargs.get('weights', kwargs.get('w'))
214214
if weights is None:
215215
return None
216-
elif isinstance(weights, sympy.Function):
216+
elif isinstance(weights, Differentiable):
217217
return weights
218218
else:
219219
return as_tuple(weights)
@@ -484,7 +484,7 @@ def _eval_fd(self, expr, **kwargs):
484484
assert self.method == 'FD'
485485
res = cross_derivative(expr, self.dims, self.fd_order, self.deriv_order,
486486
matvec=self.transpose, x0=x0_deriv, expand=expand,
487-
side=self.side)
487+
side=self.side, weights=self.weights)
488488
else:
489489
assert self.method == 'FD'
490490
res = generic_derivative(expr, self.dims[0], self.fd_order[0],

devito/finite_differences/finite_difference.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def cross_derivative(expr, dims, fd_order, deriv_order, x0=None, side=None, **kw
5959
Semantically, this is equivalent to
6060
6161
>>> (f*g).dxdy
62-
Derivative(Derivative(f(x, y)*g(x, y), x), y)
62+
Derivative(f(x, y)*g(x, y), x, y)
6363
6464
The only difference is that in the latter case derivatives remain unevaluated.
6565
The expanded form is obtained via ``evaluate``
@@ -158,7 +158,9 @@ def make_derivative(expr, dim, fd_order, deriv_order, side, matvec, x0, coeffici
158158
# `coefficients` method (`taylor` or `symbolic`)
159159
if weights is None:
160160
weights = fd_weights_registry[coefficients](expr, deriv_order, indices, x0)
161-
elif wdim is not None:
161+
# Did fd_weights_registry return a new Function/Expression instead of a values?
162+
_, wdim, _ = process_weights(weights, expr, dim)
163+
if wdim is not None:
162164
weights = [weights._subs(wdim, i) for i in range(len(indices))]
163165

164166
# Enforce fixed precision FD coefficients to avoid variations in results

devito/finite_differences/tools.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from itertools import product
33

44
import numpy as np
5-
from sympy import S, finite_diff_weights, cacheit, sympify, Function, Rational
5+
from sympy import S, finite_diff_weights, cacheit, sympify, Rational, Expr
66

77
from devito.logger import warning
88
from devito.tools import Tag, as_tuple
@@ -92,8 +92,8 @@ def diff_f(expr, deriv_order, dims, fd_order, side=None, **kwargs):
9292
dims = as_tuple(dims)
9393
deriv_order = as_tuple(deriv_order)
9494
fd_order = as_tuple(fd_order)
95-
for (d, do, fo) in zip(dims, deriv_order, fd_order):
96-
expr = Derivative(expr, d, deriv_order=do, fd_order=fo, side=side, **kwargs)
95+
expr = Derivative(expr, *dims, deriv_order=deriv_order, fd_order=fd_order,
96+
side=side, **kwargs)
9797
return expr
9898

9999
all_combs = dim_with_order(dims, orders)
@@ -326,12 +326,20 @@ def make_shift_x0(shift, ndim):
326326

327327

328328
def process_weights(weights, expr, dim):
329+
from devito.symbolics import retrieve_functions
329330
if weights is None:
330331
return 0, None, False
331-
elif isinstance(weights, Function):
332+
elif isinstance(weights, Expr):
333+
w_func = retrieve_functions(weights)
334+
assert len(w_func) == 1, "Only one function expected in weights"
335+
weights = w_func[0]
332336
if len(weights.dimensions) == 1:
333337
return weights.shape[0], weights.dimensions[0], False
334-
wdim = {d for d in weights.dimensions if d not in expr.dimensions}
338+
try:
339+
# Already a derivative
340+
wdim = {d for d in weights.dimensions if d not in expr.base.dimensions}
341+
except AttributeError:
342+
wdim = {d for d in weights.dimensions if d not in expr.dimensions}
335343
assert len(wdim) == 1
336344
wdim = wdim.pop()
337345
shape = weights.shape

tests/test_derivatives.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -87,19 +87,19 @@ def test_stencil_derivative(self, SymbolType, dim):
8787

8888
@pytest.mark.parametrize('SymbolType, derivative, dim, expected', [
8989
(Function, ['dx2'], 3, 'Derivative(u(x, y, z), (x, 2))'),
90-
(Function, ['dx2dy'], 3, 'Derivative(Derivative(u(x, y, z), (x, 2)), y)'),
90+
(Function, ['dx2dy'], 3, 'Derivative(u(x, y, z), (x, 2), y)'),
9191
(Function, ['dx2dydz'], 3,
92-
'Derivative(Derivative(Derivative(u(x, y, z), (x, 2)), y), z)'),
92+
'Derivative(u(x, y, z), (x, 2), y, z)'),
9393
(Function, ['dx2', 'dy'], 3, 'Derivative(Derivative(u(x, y, z), (x, 2)), y)'),
9494
(Function, ['dx2dy', 'dz2'], 3,
95-
'Derivative(Derivative(Derivative(u(x, y, z), (x, 2)), y), (z, 2))'),
95+
'Derivative(Derivative(u(x, y, z), (x, 2), y), (z, 2))'),
9696
(TimeFunction, ['dx2'], 3, 'Derivative(u(t, x, y, z), (x, 2))'),
97-
(TimeFunction, ['dx2dy'], 3, 'Derivative(Derivative(u(t, x, y, z), (x, 2)), y)'),
97+
(TimeFunction, ['dx2dy'], 3, 'Derivative(u(t, x, y, z), (x, 2), y)'),
9898
(TimeFunction, ['dx2', 'dy'], 3,
9999
'Derivative(Derivative(u(t, x, y, z), (x, 2)), y)'),
100100
(TimeFunction, ['dx', 'dy', 'dx2', 'dz', 'dydz'], 3,
101-
'Derivative(Derivative(Derivative(Derivative(Derivative(Derivative(' +
102-
'u(t, x, y, z), x), y), (x, 2)), z), y), z)')
101+
'Derivative(Derivative(Derivative(Derivative(Derivative(' +
102+
'u(t, x, y, z), x), y), (x, 2)), z), y, z)')
103103
])
104104
def test_unevaluation(self, SymbolType, derivative, dim, expected):
105105
u = SymbolType(name='u', grid=self.grid, time_order=2, space_order=2)
@@ -112,13 +112,13 @@ def test_unevaluation(self, SymbolType, derivative, dim, expected):
112112

113113
@pytest.mark.parametrize('expr,expected', [
114114
('u.dx + u.dy', 'Derivative(u, x) + Derivative(u, y)'),
115-
('u.dxdy', 'Derivative(Derivative(u, x), y)'),
115+
('u.dxdy', 'Derivative(u, x, y)'),
116116
('u.laplace',
117117
'Derivative(u, (x, 2)) + Derivative(u, (y, 2)) + Derivative(u, (z, 2))'),
118118
('(u.dx + u.dy).dx', 'Derivative(Derivative(u, x) + Derivative(u, y), x)'),
119119
('((u.dx + u.dy).dx + u.dxdy).dx',
120120
'Derivative(Derivative(Derivative(u, x) + Derivative(u, y), x) +' +
121-
' Derivative(Derivative(u, x), y), x)'),
121+
' Derivative(u, x, y), x)'),
122122
('(u**4).dx', 'Derivative(u**4, x)'),
123123
('(u/4).dx', 'Derivative(u/4, x)'),
124124
('((u.dx + v.dy).dx * v.dx).dy.dz',
@@ -741,7 +741,7 @@ def test_cross_newnest(self):
741741
grid = Grid((11, 11))
742742
f = Function(name="f", grid=grid, space_order=2)
743743

744-
assert f.dxdy == f.dx.dy
744+
assert simplify((f.dxdy - f.dx.dy).evaluate) == 0
745745

746746
def test_nested_call(self):
747747
grid = Grid((11, 11))

tests/test_symbolic_coefficients.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def test_function_coefficients(self):
8181

8282
assert np.all(np.isclose(f0.data[:] - f1.data[:], 0.0, atol=1e-5, rtol=0))
8383

84-
def test_function_coefficients_xderiv(self):
84+
def test_function_coefficients_xderiv_legacy(self):
8585
p = Dimension('p')
8686

8787
nstc = 8
@@ -105,6 +105,35 @@ def test_function_coefficients_xderiv(self):
105105
op = Operator(eqn)
106106
op()
107107

108+
@pytest.mark.parametrize('order', [2, 4, 6, 8])
109+
def test_function_coefficients_xderiv(self, order):
110+
p = Dimension('p')
111+
112+
grid = Grid(shape=(51, 51, 51))
113+
x, y, z = grid.dimensions
114+
115+
f = Function(name='f', grid=grid, space_order=order)
116+
w = Function(name='w', space_order=0, shape=(*grid.shape, order + 1),
117+
dimensions=(x, y, z, p))
118+
119+
expr0 = f.dx(w=w).dy(w=w).evaluate
120+
expr1 = f.dxdy(w=w).evaluate
121+
assert sp.simplify(expr0 - expr1) == 0
122+
123+
def test_coefficients_expr(self):
124+
p = Dimension('p')
125+
126+
grid = Grid(shape=(51, 51, 51))
127+
x, y, z = grid.dimensions
128+
129+
f = Function(name='f', grid=grid, space_order=4)
130+
w = Function(name='w', space_order=0, shape=(*grid.shape, 5),
131+
dimensions=(x, y, z, p))
132+
133+
expr0 = f.dx(w=w/x.spacing).evaluate
134+
expr1 = f.dx(w=w).evaluate / x.spacing
135+
assert sp.simplify(expr0 - expr1) == 0
136+
108137
def test_coefficients_w_xreplace(self):
109138
"""Test custom coefficients with an xreplace before they are applied"""
110139
grid = Grid(shape=(4, 4))

0 commit comments

Comments
 (0)