Skip to content

Commit a51c3c3

Browse files
committed
Nested derivatives are tricky
1 parent f3ed91f commit a51c3c3

2 files changed

Lines changed: 64 additions & 23 deletions

File tree

devito/finite_differences/derivative.py

Lines changed: 54 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from collections.abc import Iterable
33
from functools import cached_property
44
from itertools import chain
5+
from warnings import warn
56

67
import sympy
78

@@ -160,18 +161,24 @@ def __new__(cls, expr, *dims, **kwargs):
160161
# Use `fd_order` if specified
161162
fd_order = kwargs.get('fd_order')
162163
if fd_order is not None:
163-
_fd_order_specified = True
164164
# If `fd_order` is specified collect these together
165165
fcounter = defaultdict(int)
166166
for d, o in zip(dims, as_tuple(fd_order)):
167-
fcounter[d] += o
168-
for d, o in dcounter.items():
169-
order = expr.time_order if getattr(d, 'is_Time', False) else expr.space_order
167+
if isinstance(d, Iterable):
168+
fcounter[d[0]] += o
169+
else:
170+
fcounter[d] += o
171+
for d, o in fcounter.items():
172+
if getattr(d, 'is_Time', False):
173+
order = expr.time_order
174+
else:
175+
order = expr.space_order
170176
if o > order:
171-
raise ValueError(f'Function does not support {d} derivative of order {o}')
177+
raise ValueError(
178+
f'Function does not support {d}-derivative with `fd_order` {o}'
179+
)
172180
fd_order = fcounter.values()
173181
else:
174-
_fd_order_specified = False
175182
# Default finite difference orders depending on input dimension (.dt or .dx)
176183
fd_order = tuple([
177184
expr.time_order
@@ -180,16 +187,16 @@ def __new__(cls, expr, *dims, **kwargs):
180187
for d in dcounter.keys()
181188
])
182189

183-
# SymPy expects the list of variable w.r.t. which we differentiate to be a list
184-
# of 2-tuple `(s, count)` where s is the entity to diff wrt and count is the order
185-
# of the derivative
190+
# SymPy expects the list of variables w.r.t. which we differentiate to be a list
191+
# of 2-tuples: `(s, count)` where:
192+
# - `s` is the entity to diff w.r.t. and
193+
# - `count` is the order of the derivative
186194
derivatives = [sympy.Tuple(d, o) for d, o in dcounter.items()]
187195

188196
# Construct the actual Derivative object
189197
obj = Differentiable.__new__(cls, expr, *derivatives)
190198
obj._dims = tuple(dcounter.keys())
191199

192-
obj._fd_order_specified = _fd_order_specified
193200
obj._fd_order = DimensionTuple(
194201
*as_tuple(fd_order),
195202
getters=obj._dims
@@ -544,17 +551,49 @@ def _eval_expand_nest(self, **hints):
544551
for d, ii in zip(
545552
chain(self.dims, expr.dims),
546553
chain(self.deriv_order, expr.deriv_order)
547-
)]
554+
)
555+
]
548556
# This is necessary as tools.abc.Reconstructable._rebuild will copy
549557
# all kwargs from the self object
550558
# TODO: This dictionary merge needs to be a lot better
551559
# EG: Don't actually expand if derivatives are incompatible
560+
new_deriv_order = tuple(chain(self.deriv_order, expr.deriv_order))
561+
# The `fd_order` may need to be reduced to construct the nested derivative
562+
dcounter = defaultdict(int)
563+
fcounter = defaultdict(int)
564+
new_fd_order = tuple(chain(self.fd_order, expr.fd_order))
565+
for d, do, fo in zip(new_dims, new_deriv_order, new_fd_order):
566+
if isinstance(d, Iterable):
567+
dcounter[d[0]] += d[1]
568+
fcounter[d[0]] += fo
569+
else:
570+
dcounter[d] += do
571+
fcounter[d] += fo
572+
for (d, do), (_, fo) in zip(dcounter.items(), fcounter.items()):
573+
if getattr(d, 'is_Time', False):
574+
dim_name = 'time'
575+
order = expr.time_order
576+
else:
577+
dim_name = 'space'
578+
order = expr.space_order
579+
if fo > order:
580+
if do > order:
581+
raise ValueError(
582+
f'Nested {do}-derivative constructed which is bigger '
583+
f'than the {dim_name}_order={order}'
584+
)
585+
else:
586+
warn(
587+
f'Nested derivative constructed with fd_order={fo}, '
588+
f'but {dim_name}_order={order}. Adjusting '
589+
f'fd_order={order} for the {d} dimension.'
590+
)
591+
fcounter[d] = order
552592
new_kwargs = {
553-
'deriv_order': tuple(chain(self.deriv_order, expr.deriv_order))
593+
'deriv_order': tuple(dcounter.values()),
594+
'fd_order': tuple(fcounter.values())
554595
}
555-
if self._fd_order_specified or expr._fd_order_specified:
556-
new_kwargs.update({'fd_order': tuple(chain(self.fd_order, expr.fd_order))})
557-
return self.func(new_expr, *new_dims, **new_kwargs)
596+
return self.func(new_expr, *dcounter.items(), **new_kwargs)
558597
else:
559598
return self
560599

tests/test_derivatives.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -378,9 +378,9 @@ def test_fd_new_order(self, so):
378378
grid = Grid((10,))
379379
u = Function(name="u", grid=grid, space_order=so)
380380
u1 = Function(name="u", grid=grid, space_order=so//2)
381-
u2 = Function(name="u", grid=grid, space_order=2*so)
381+
u2 = Function(name="u", grid=grid, space_order=so//2 + 1)
382382
assert str(u.dx(fd_order=so//2).evaluate) == str(u1.dx.evaluate)
383-
assert str(u.dx(fd_order=2*so).evaluate) == str(u2.dx.evaluate)
383+
assert str(u.dx(fd_order=so//2 + 1).evaluate) == str(u2.dx.evaluate)
384384

385385
def test_xderiv_order(self):
386386
grid = Grid(shape=(11, 11), extent=(10., 10.))
@@ -1032,8 +1032,9 @@ def setup_class(cls):
10321032
cls.x = cls.grid.dimensions[0]
10331033
cls.u = Function(name='u', grid=cls.grid, space_order=4)
10341034

1035+
# Note that using the `.dx` shortcut method specifies the fd_order kwarg
10351036
a = cls.u.dx
1036-
cls.b = a.subs({cls.u: -5*cls.u.dx + 4*cls.u + 3})
1037+
cls.b = a.subs({cls.u: -5*cls.u.dx + 4*cls.u + 3}, postprocess=False)
10371038

10381039
def test_reconstructible(self):
10391040
''' Check that devito.Derivatives are reconstructible from func and args
@@ -1150,7 +1151,8 @@ def test_nested_orders(self):
11501151
'''
11511152
# Default fd_order
11521153
du22 = Derivative(Derivative(self.u, (self.x, 2)), (self.x, 2))
1153-
du22_expanded = du22.expand(nest=True)
1154+
with pytest.warns(UserWarning):
1155+
du22_expanded = du22.expand(nest=True)
11541156
du4 = Derivative(self.u, (self.x, 4))
11551157
assert du22_expanded == du4
11561158
assert du22_expanded.deriv_order == du4.deriv_order
@@ -1169,16 +1171,16 @@ def test_nested_orders(self):
11691171
assert du22_expanded.fd_order == du4.fd_order
11701172

11711173
# Specified fd_order greater than the space order
1172-
## When no order specified
1174+
# > When no order specified
11731175
du44 = Derivative(Derivative(self.u, (self.x, 4)), (self.x, 4))
11741176
with pytest.raises(ValueError):
1175-
du44_expanded = du44.expand(nest=True)
1177+
_ = du44.expand(nest=True)
11761178

1177-
## When order specified is too large
1179+
# > When order specified is too large
11781180
du44 = Derivative(
11791181
Derivative(self.u, (self.x, 4), fd_order=4),
11801182
(self.x, 4),
11811183
fd_order=4
11821184
)
11831185
with pytest.raises(ValueError):
1184-
du44_expanded = du44.expand(nest=True)
1186+
_ = du44.expand(nest=True)

0 commit comments

Comments
 (0)