Skip to content

Commit 525a6f0

Browse files
committed
api: fix custom coeffs with cross derivatives
1 parent aed4fa6 commit 525a6f0

3 files changed

Lines changed: 24 additions & 5 deletions

File tree

devito/finite_differences/derivative.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,7 @@ def _eval_fd(self, expr, **kwargs):
484484
assert self.method == 'FD'
485485
res = cross_derivative(expr, self.dims, self.fd_order, self.deriv_order,
486486
matvec=self.transpose, x0=x0_deriv, expand=expand,
487-
side=self.side)
487+
side=self.side, weights=self.weights)
488488
else:
489489
assert self.method == 'FD'
490490
res = generic_derivative(expr, self.dims[0], self.fd_order[0],

devito/finite_differences/tools.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,8 @@ def diff_f(expr, deriv_order, dims, fd_order, side=None, **kwargs):
9292
dims = as_tuple(dims)
9393
deriv_order = as_tuple(deriv_order)
9494
fd_order = as_tuple(fd_order)
95-
for (d, do, fo) in zip(dims, deriv_order, fd_order):
96-
expr = Derivative(expr, d, deriv_order=do, fd_order=fo, side=side, **kwargs)
95+
expr = Derivative(expr, *dims, deriv_order=deriv_order, fd_order=fd_order,
96+
side=side, **kwargs)
9797
return expr
9898

9999
all_combs = dim_with_order(dims, orders)
@@ -331,7 +331,11 @@ def process_weights(weights, expr, dim):
331331
elif isinstance(weights, Function):
332332
if len(weights.dimensions) == 1:
333333
return weights.shape[0], weights.dimensions[0], False
334-
wdim = {d for d in weights.dimensions if d not in expr.dimensions}
334+
try:
335+
# Already a derivative
336+
wdim = {d for d in weights.dimensions if d not in expr.base.dimensions}
337+
except AttributeError:
338+
wdim = {d for d in weights.dimensions if d not in expr.dimensions}
335339
assert len(wdim) == 1
336340
wdim = wdim.pop()
337341
shape = weights.shape

tests/test_symbolic_coefficients.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def test_function_coefficients(self):
8181

8282
assert np.all(np.isclose(f0.data[:] - f1.data[:], 0.0, atol=1e-5, rtol=0))
8383

84-
def test_function_coefficients_xderiv(self):
84+
def test_function_coefficients_xderiv_legacy(self):
8585
p = Dimension('p')
8686

8787
nstc = 8
@@ -105,6 +105,21 @@ def test_function_coefficients_xderiv(self):
105105
op = Operator(eqn)
106106
op()
107107

108+
@pytest.mark.parametrize('order', [2, 4, 6, 8])
109+
def test_function_coefficients_xderiv(self, order):
110+
p = Dimension('p')
111+
112+
grid = Grid(shape=(51, 51, 51))
113+
x, y, z = grid.dimensions
114+
115+
f = Function(name='f', grid=grid, space_order=order)
116+
w = Function(name='w', space_order=0, shape=(*grid.shape, order + 1),
117+
dimensions=(x, y, z, p))
118+
119+
expr0 = f.dx(w=w).dy(w=w).evaluate
120+
expr1 = f.dxdy(w=w).evaluate
121+
assert sp.simplify(expr0 - expr1) == 0
122+
108123
def test_coefficients_w_xreplace(self):
109124
"""Test custom coefficients with an xreplace before they are applied"""
110125
grid = Grid(shape=(4, 4))

0 commit comments

Comments
 (0)