Skip to content

Commit b56e928

Browse files
committed
dsl: Fix nested derivative expansion
1 parent 41fb354 commit b56e928

1 file changed

Lines changed: 11 additions & 2 deletions

File tree

devito/finite_differences/derivative.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ def __new__(cls, expr, *dims, **kwargs):
113113
# - an iterable of Dimensions ie: (x, y)
114114
# - a single tuple of Dimension and order ie: (x, 2)
115115
# - or an iterable of Dimension, order ie: ((x, 2), (y, 2))
116+
# - any combination of the above ie: ((x, 2), y, x, (z, 3))
116117
if len(dims) == 0:
117118
raise ValueError('Expected Dimension w.r.t. which to differentiate')
118119
elif len(dims) == 1 and isinstance(dims[0], Iterable) and len(dims[0]) != 2:
@@ -496,11 +497,19 @@ def _eval_fd(self, expr, **kwargs):
496497
def _eval_expand_nest(self, **hints):
497498
expr = self.args[0]
498499
if isinstance(expr, self.__class__):
499-
return self.func(expr.args[0], *[(d, ii)
500+
new_expr = expr.args[0]
501+
new_dims = [
502+
(d, ii)
500503
for d, ii in zip(
501504
chain(self.dims, expr.dims),
502505
chain(self.deriv_order, expr.deriv_order)
503-
)])
506+
)]
507+
# This is necessary as tools.abc.Reconstructable._rebuild will copy
508+
# all kwargs from the self object
509+
# TODO: This dictionary merge needs to be a lot better
510+
# EG: Don't actually expand if derivatives are incompatible
511+
new_kwargs = {'deriv_order': tuple(chain(self.deriv_order, expr.deriv_order))}
512+
return self.func(new_expr, *new_dims, **new_kwargs)
504513
else:
505514
return self
506515

0 commit comments

Comments
 (0)