Skip to content

Commit b32ad78

Browse files
committed
sympy: Add product rule expansion hint and tests
1 parent 4de0237 commit b32ad78

2 files changed

Lines changed: 52 additions & 4 deletions

File tree

devito/finite_differences/derivative.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import sympy
88

99
from .finite_difference import generic_derivative, cross_derivative
10-
from .differentiable import Differentiable, diffify, interp_for_fd, Add
10+
from .differentiable import Differentiable, diffify, interp_for_fd, Add, Mul
1111
from .tools import direct, transpose
1212
from .rsfd import d45
1313
from devito.tools import (as_mapper, as_tuple, frozendict, is_integer,
@@ -695,4 +695,27 @@ def _eval_expand_product_rule(self, **hints):
695695
resultant expression for higher derivatives and mixed derivatives is much
696696
more difficult to implement.
697697
"""
698-
raise NotImplementedError('Product rule expansion has not been written')
698+
if self.expr.is_Mul and len(self.dims) == 1:
699+
if self.deriv_order == (1,):
700+
return Add(*[
701+
Mul(*self.expr.args[:ii], self.func(m), *self.expr.args[ii + 1:])
702+
for ii, m in enumerate(self.expr.args)
703+
])
704+
elif self.deriv_order == (2,) and len(self.expr.args) == 2:
705+
return Add(
706+
Mul(self.func(self.expr.args[0]), self.expr.args[1]),
707+
Mul(
708+
2,
709+
self.func(self.expr.args[0], deriv_order=1),
710+
self.func(self.expr.args[1], deriv_order=1)
711+
),
712+
Mul(self.expr.args[0], self.func(self.expr.args[1]))
713+
)
714+
else:
715+
# Note: It _is_ possible to implement the product rule for many
716+
# more cases, but the number of terms in the resultant expression
717+
# will grow.
718+
return self
719+
else:
720+
return self
721+

tests/test_derivatives.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from devito import (Grid, Function, TimeFunction, Eq, Operator, NODE, cos, sin,
66
ConditionalDimension, left, right, centered, div, grad)
7-
from devito.finite_differences import Derivative, Differentiable
7+
from devito.finite_differences import Derivative, Differentiable, diffify
88
from devito.finite_differences.differentiable import (Add, EvalDerivative, IndexSum,
99
IndexDerivative, Weights,
1010
DiffDerivative)
@@ -1030,7 +1030,12 @@ class TestExpansion:
10301030
def setup_class(cls):
10311031
cls.grid = Grid(shape=(11,), extent=(1,))
10321032
cls.x = cls.grid.dimensions[0]
1033-
cls.u = Function(name='u', grid=cls.grid, space_order=4)
1033+
cls.u = TimeFunction(name='u', grid=cls.grid, space_order=4)
1034+
cls.t = cls.u.time_dim
1035+
1036+
cls.v = Function(name='v', grid=cls.grid, space_order=4)
1037+
cls.w = Function(name='w', grid=cls.grid, space_order=4)
1038+
cls.f = Function(name='f', dimensions=(cls.t,), shape=(5,))
10341039

10351040
# Note that using the `.dx` shortcut method specifies the fd_order kwarg
10361041
a = cls.u.dx
@@ -1140,6 +1145,10 @@ def test_expand_mul(self):
11401145
expanded = Derivative(4*self.u - 5*Derivative(self.u, self.x) + 3, self.x)
11411146
assert self.b.expand() == expanded
11421147

1148+
expr = self.u.dx.subs({self.u: 5*self.f*self.u*self.v*self.w})
1149+
expanded = 5*self.f*Derivative(self.u*self.v*self.w, self.x)
1150+
assert diffify(expr.expand()) == expanded
1151+
11431152
def test_expand_add(self):
11441153
"""
11451154
Check linearity
@@ -1194,3 +1203,19 @@ def test_nested_orders(self):
11941203
)
11951204
with pytest.raises(ValueError):
11961205
_ = du44.expand(nest=True)
1206+
1207+
def test_expand_product_rule(self):
1208+
"""
1209+
Check the two implemented cases for product rule expansion
1210+
"""
1211+
expr = self.u.dx.subs({self.u: 5*self.f*self.u*self.v*self.w}, postprocess=False)
1212+
expanded = 5*self.f*self.v*self.w*Derivative(self.u, self.x) \
1213+
+ 5*self.f*self.v*self.u*Derivative(self.w, self.x) \
1214+
+ 5*self.f*self.w*self.u*Derivative(self.v, self.x)
1215+
assert diffify(expr.expand(product_rule=True)) == expanded
1216+
1217+
expr = self.u.dx2.subs({self.u: 5*self.f*self.u*self.v}, postprocess=False)
1218+
expanded = 5*self.f*self.v*Derivative(self.u, (self.x, 2)) \
1219+
+ 10*self.f*Derivative(self.v, self.x)*Derivative(self.u, self.x) \
1220+
+ 5*self.f*self.u*Derivative(self.v, (self.x, 2))
1221+
assert diffify(expr.expand(product_rule=True)) == expanded

0 commit comments

Comments
 (0)