Skip to content

Commit 41fb354

Browse files
committed
misc: Refactor derivative
1 parent 35994d0 commit 41fb354

2 files changed

Lines changed: 68 additions & 87 deletions

File tree

devito/finite_differences/derivative.py

Lines changed: 67 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from collections import OrderedDict
1+
from collections import defaultdict
22
from collections.abc import Iterable
33
from functools import cached_property
44
from itertools import chain
@@ -9,7 +9,7 @@
99
from .differentiable import Differentiable, diffify, interp_for_fd, Add
1010
from .tools import direct, transpose
1111
from .rsfd import d45
12-
from devito.tools import (as_mapper, as_tuple, filter_ordered, frozendict, is_integer,
12+
from devito.tools import (as_mapper, as_tuple, frozendict, is_integer,
1313
Pickable)
1414
from devito.types.utils import DimensionTuple
1515

@@ -29,7 +29,7 @@ class Derivative(sympy.Derivative, Differentiable, Pickable):
2929
Expression for which the Derivative is produced.
3030
dims : Dimension or tuple of Dimension
3131
Dimensions w.r.t. which to differentiate.
32-
fd_order : int or tuple of int, optional, default=1
32+
fd_order : int, tuple of int or dict of {Dimension: int}, optional, default=1
3333
Coefficient discretization order. Note: this impacts the width of
3434
the resulting stencil.
3535
deriv_order: int or tuple of int, optional, default=1
@@ -93,6 +93,11 @@ class Derivative(sympy.Derivative, Differentiable, Pickable):
9393
'x0', 'method', 'weights')
9494

9595
def __new__(cls, expr, *dims, **kwargs):
96+
# TODO: Delete this
97+
if kwargs.get('preprocessed', False):
98+
from warnings import warn
99+
warn('I removed the `preprocessed` kwarg')
100+
96101
if type(expr) is sympy.Derivative:
97102
raise ValueError("Cannot nest sympy.Derivative with devito.Derivative")
98103
if not isinstance(expr, Differentiable):
@@ -103,14 +108,62 @@ def __new__(cls, expr, *dims, **kwargs):
103108
if isinstance(expr, sympy.Number):
104109
return 0
105110

106-
new_dims, orders, fd_o, var_count = cls._process_kwargs(expr, *dims, **kwargs)
111+
# Validate `dims`. It can be:
112+
# - a single Dimension ie: x
113+
# - an iterable of Dimensions ie: (x, y)
114+
# - a single tuple of Dimension and order ie: (x, 2)
115+
# - or an iterable of Dimension, order ie: ((x, 2), (y, 2))
116+
if len(dims) == 0:
117+
raise ValueError('Expected Dimension w.r.t. which to differentiate')
118+
elif len(dims) == 1 and isinstance(dims[0], Iterable) and len(dims[0]) != 2:
119+
# Iterable of Dimensions
120+
raise ValueError(f'Expected `(dim, deriv_order)`, got {dims[0]}')
121+
elif len(dims) == 2 and not isinstance(dims[1], Iterable) and is_integer(dims[1]):
122+
# special case of single dimension and order
123+
dims = (dims, )
124+
125+
# Use `deriv_order` if specified
126+
deriv_order = kwargs.get('deriv_order', (1,)*len(dims))
127+
if not isinstance(deriv_order, Iterable):
128+
deriv_order = as_tuple(deriv_order)
129+
if len(deriv_order) != len(dims):
130+
raise ValueError(
131+
'Length of `deriv_order` does not match the length of dimensions'
132+
)
133+
134+
# Count the number of derivatives for each dimension
135+
dcounter = defaultdict(int)
136+
for d, o in zip(dims, deriv_order):
137+
if isinstance(d, Iterable):
138+
dcounter[d[0]] += d[1]
139+
else:
140+
dcounter[d] += o
107141

108-
# Construct the actual Derivative object
109-
obj = Differentiable.__new__(cls, expr, *var_count)
110-
obj._dims = tuple(OrderedDict.fromkeys(new_dims))
142+
# Default finite difference orders depending on input dimension (.dt or .dx)
143+
default_fdo = tuple([
144+
expr.time_order
145+
if getattr(d, 'is_Time', False)
146+
else expr.space_order
147+
for d in dcounter.keys()
148+
])
111149

112-
obj._fd_order = DimensionTuple(*as_tuple(fd_o), getters=obj._dims)
113-
obj._deriv_order = DimensionTuple(*as_tuple(orders), getters=obj._dims)
150+
# SymPy expects the list of variable w.r.t. which we differentiate to be a list
151+
# of 2-tuple `(s, count)` where s is the entity to diff wrt and count is the order
152+
# of the derivative
153+
derivatives = [sympy.Tuple(d, o) for d, o in dcounter.items()]
154+
155+
# Construct the actual Derivative object
156+
obj = Differentiable.__new__(cls, expr, *derivatives)
157+
obj._dims = tuple(dcounter.keys())
158+
159+
obj._fd_order = DimensionTuple(
160+
*as_tuple(kwargs.get('fd_order', default_fdo)),
161+
getters=obj._dims
162+
)
163+
obj._deriv_order = DimensionTuple(
164+
*as_tuple(dcounter.values()),
165+
getters=obj._dims
166+
)
114167
obj._side = kwargs.get("side")
115168
obj._transpose = kwargs.get("transpose", direct)
116169
obj._method = kwargs.get("method", 'FD')
@@ -131,78 +184,6 @@ def __new__(cls, expr, *dims, **kwargs):
131184

132185
return obj
133186

134-
@classmethod
135-
def _process_kwargs(cls, expr, *dims, **kwargs):
136-
"""
137-
Process arguments for the construction of a Derivative
138-
"""
139-
# Skip costly processing if constructiong from preprocessed
140-
if kwargs.get('preprocessed', False):
141-
fd_orders = kwargs.get('fd_order')
142-
deriv_orders = kwargs.get('deriv_order')
143-
if len(dims) == 1:
144-
if isinstance(dims[0], Iterable):
145-
assert dims[0][1] == deriv_orders[0]
146-
dims = tuple([dims[0][0]]*max(1, deriv_orders[0]))
147-
else:
148-
dims = tuple([dims[0]]*max(1, deriv_orders[0]))
149-
variable_count = [sympy.Tuple(s, dims.count(s))
150-
for s in filter_ordered(dims)]
151-
return dims, deriv_orders, fd_orders, variable_count
152-
153-
# Check `dims`. It can be a single Dimension, an iterable of Dimensions, or even
154-
# an iterable of 2-tuple (Dimension, deriv_order)
155-
if len(dims) == 0:
156-
raise ValueError("Expected Dimension w.r.t. which to differentiate")
157-
elif len(dims) == 1:
158-
if isinstance(dims[0], Iterable):
159-
# Iterable of Dimensions
160-
if len(dims[0]) != 2:
161-
raise ValueError("Expected `(dim, deriv_order)`, got %s" % dims[0])
162-
orders = kwargs.get('deriv_order', dims[0][1])
163-
if dims[0][1] != orders:
164-
raise ValueError("Two different values of `deriv_order`")
165-
new_dims = tuple([dims[0][0]]*max(1, dims[0][1]))
166-
else:
167-
# Single Dimension
168-
orders = kwargs.get('deriv_order', 1)
169-
if isinstance(orders, Iterable):
170-
orders = orders[0]
171-
new_dims = tuple([dims[0]]*max(1, orders))
172-
elif len(dims) == 2 and not isinstance(dims[1], Iterable) and is_integer(dims[1]):
173-
# special case of single dimension and order
174-
orders = dims[1]
175-
new_dims = tuple([dims[0]]*max(1, orders))
176-
else:
177-
# Iterable of 2-tuple, e.g. ((x, 2), (y, 3))
178-
new_dims = []
179-
orders = []
180-
d_ord = kwargs.get('deriv_order', tuple([1]*len(dims)))
181-
for d, o in zip(dims, d_ord):
182-
if isinstance(d, Iterable):
183-
new_dims.extend([d[0]]*max(1, d[1]))
184-
orders.append(d[1])
185-
else:
186-
new_dims.extend([d]*max(1, o))
187-
orders.append(o)
188-
new_dims = as_tuple(new_dims)
189-
orders = as_tuple(orders)
190-
191-
# Finite difference orders depending on input dimension (.dt or .dx)
192-
odims = filter_ordered(new_dims)
193-
fd_orders = kwargs.get('fd_order', tuple([expr.time_order if
194-
getattr(d, 'is_Time', False) else
195-
expr.space_order for d in odims]))
196-
if len(odims) == 1 and isinstance(fd_orders, Iterable):
197-
fd_orders = fd_orders[0]
198-
199-
# SymPy expects the list of variable w.r.t. which we differentiate to be a list
200-
# of 2-tuple `(s, count)` where s is the entity to diff wrt and count is the order
201-
# of the derivative
202-
variable_count = [sympy.Tuple(s, new_dims.count(s))
203-
for s in odims]
204-
return new_dims, orders, fd_orders, variable_count
205-
206187
@classmethod
207188
def _process_x0(cls, dims, **kwargs):
208189
try:
@@ -268,7 +249,7 @@ def __call__(self, x0=None, fd_order=None, side=None, method=None, **kwargs):
268249
if fd_order is not None:
269250
rkw['fd_order'] = self._filter_dims(_fd_order, as_tuple=True)
270251

271-
return self._rebuild(**rkw)
252+
return self._rebuild(*self.args, **rkw)
272253

273254
def _rebuild(self, *args, **kwargs):
274255
if not args:
@@ -331,7 +312,7 @@ def _xreplace(self, subs):
331312
expr = self.expr.xreplace(dsubs)
332313

333314
subs = self._ppsubs + (subs,) # Postponed substitutions
334-
return self._rebuild(subs=subs, expr=expr), True
315+
return self._rebuild(expr, *self.args[1:], subs=subs), True
335316

336317
@cached_property
337318
def _metadata(self):
@@ -405,7 +386,7 @@ def T(self):
405386
else:
406387
adjoint = direct
407388

408-
return self._rebuild(transpose=adjoint)
389+
return self._rebuild(*self.args, transpose=adjoint)
409390

410391
def _eval_at(self, func):
411392
"""
@@ -438,12 +419,12 @@ def _eval_at(self, func):
438419
# in most equation with div(a * u) for example. The expression is re-centered
439420
# at the highest priority index (see _gather_for_diff) to compute the
440421
# derivative at x0.
441-
return self._rebuild(expr=self.expr._gather_for_diff, x0=x0)
422+
return self._rebuild(self.expr._gather_for_diff, *self.args[1:], x0=x0)
442423
else:
443424
# For every other cases, that has more functions or more complexe arithmetic,
444425
# there is not actual way to decide what to do so it’s as safe to use
445426
# the expression as is.
446-
return self._rebuild(x0=x0)
427+
return self._rebuild(*self.args, x0=x0)
447428

448429
def _evaluate(self, **kwargs):
449430
# Evaluate finite-difference.

devito/finite_differences/differentiable.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1161,7 +1161,7 @@ def interp_for_fd(expr, x0, **kwargs):
11611161
@interp_for_fd.register(sympy.Derivative)
11621162
def _(expr, x0, **kwargs):
11631163
x0_expr = {d: v for d, v in x0.items() if d not in expr.dims}
1164-
return expr.func(expr=interp_for_fd(expr.expr, x0_expr, **kwargs))
1164+
return expr.func(interp_for_fd(expr.expr, x0_expr, **kwargs), *expr.args[1:])
11651165

11661166

11671167
@interp_for_fd.register(sympy.Expr)

0 commit comments

Comments
 (0)