Skip to content

Commit 532d25c

Browse files
committed
compiler: Patch ideriv._subs
1 parent f476873 commit 532d25c

2 files changed

Lines changed: 33 additions & 0 deletions

File tree

devito/finite_differences/differentiable.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -948,6 +948,20 @@ def _evaluate(self, **kwargs):
948948

949949
return EvalDerivative(*expr.args, base=self.base)
950950

951+
def _subs(self, old, new, **hints):
952+
# We have to work around SymPy's weak implementation of `subs` when
953+
# it gets to replacing sub-operations such as `a*b*c` (i.e., potentially
954+
# `self`'s `base`) within say `a*b*c*w[i0]` (i.e., the corresponding
955+
# `self.expr`), because depending on the complexity of `a/b/c`, SymPy
956+
# may fail to identify the sub-expression to be replaced (note: if
957+
# `a/b/c` are atoms or Indexeds, it's generally fine)
958+
959+
if not old.is_Mul or \
960+
old is not self.base:
961+
return super()._subs(old, new, **hints)
962+
963+
return self._rebuild(new * self.weights)
964+
951965

952966
class DiffDerivative(IndexDerivative, DifferentiableOp):
953967
pass

tests/test_symbolics.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -717,6 +717,25 @@ def test_canonical_ordering_of_weights():
717717
assert ccode(cf*wi) == 'f[x][y + i0][z].x*w0[i0]'
718718

719719

720+
def test_ideriv_subs_complex():
721+
grid = Grid(shape=(3, 3))
722+
x, _ = grid.dimensions
723+
724+
f = Function(name='f', grid=grid, space_order=4)
725+
g = f.func(name='g')
726+
h = f.func(name='h')
727+
b = f.func(name='b')
728+
729+
ideriv = (f*g*h).dx._evaluate(expand=False)
730+
731+
i0, = ideriv.dimensions
732+
base = ideriv.base
733+
734+
v = ideriv._subs(base, b._subs(x, x + i0))
735+
736+
assert str(v) == 'DiffDerivative(w(i0)*b(x + i0, y), (i0))'
737+
738+
720739
def test_symbolic_printing():
721740
b = Symbol('b')
722741

0 commit comments

Comments
 (0)