Skip to content

Commit 4de0237

Browse files
committed
misc: More reviewer comments
1 parent 8014a41 commit 4de0237

2 files changed

Lines changed: 145 additions & 84 deletions

File tree

devito/finite_differences/derivative.py

Lines changed: 144 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -99,20 +99,87 @@ def __new__(cls, expr, *dims, **kwargs):
9999
from warnings import warn
100100
warn('I removed the `preprocessed` kwarg')
101101

102+
# Validate the input arguments `expr`, `dims` and `deriv_order`
103+
expr = cls._validate_expr(expr)
104+
dims = cls._validate_dims(dims)
105+
deriv_order = cls._validate_deriv_order(kwargs.get('deriv_order'), dims)
106+
# Count the derivatives w.r.t. each variable
107+
dcounter = cls._count_derivatives(deriv_order, dims)
108+
109+
# It's possible that the expr is a `sympy.Number` at this point, which
110+
# has derivative 0, unless we're taking a 0th derivative.
111+
if isinstance(expr, sympy.Number):
112+
if any(dcounter.values()):
113+
return 0
114+
else:
115+
return expr
116+
117+
# Validate the finite difference order `fd_order`
118+
fd_order = cls._validate_fd_order(kwargs.get('fd_order'), expr, dims, dcounter)
119+
120+
# SymPy expects the list of variables w.r.t. which we differentiate to be a list
121+
# of 2-tuples: `(s, count)` where:
122+
# - `s` is the entity to diff w.r.t. and
123+
# - `count` is the order of the derivative
124+
derivatives = [sympy.Tuple(d, o) for d, o in dcounter.items()]
125+
126+
# Construct the actual Derivative object
127+
obj = Differentiable.__new__(cls, expr, *derivatives)
128+
obj._dims = tuple(dcounter.keys())
129+
130+
obj._fd_order = DimensionTuple(
131+
*as_tuple(fd_order),
132+
getters=obj._dims
133+
)
134+
obj._deriv_order = DimensionTuple(
135+
*as_tuple(dcounter.values()),
136+
getters=obj._dims
137+
)
138+
obj._side = kwargs.get("side")
139+
obj._transpose = kwargs.get("transpose", direct)
140+
obj._method = kwargs.get("method", 'FD')
141+
obj._weights = cls._process_weights(**kwargs)
142+
143+
ppsubs = kwargs.get("subs", kwargs.get("_ppsubs", []))
144+
processed = []
145+
if ppsubs:
146+
for i in ppsubs:
147+
try:
148+
processed.append(frozendict(i))
149+
except AttributeError:
150+
# E.g. `i` is a Transform object
151+
processed.append(i)
152+
obj._ppsubs = tuple(processed)
153+
154+
obj._x0 = cls._process_x0(obj._dims, **kwargs)
155+
156+
return obj
157+
158+
@staticmethod
159+
def _validate_expr(expr):
160+
"""
161+
Validate the provided `expr`. It must be of "differentiable" type or
162+
convertible to "differentiable" type.
163+
"""
102164
if type(expr) is sympy.Derivative:
103165
raise ValueError("Cannot nest sympy.Derivative with devito.Derivative")
104166
if not isinstance(expr, Differentiable):
105167
try:
106168
expr = diffify(expr)
107169
except Exception as e:
108-
raise ValueError("`expr` must be a Differentiable object") from e
109-
110-
# Validate `dims`. It can be:
111-
# - a single Dimension ie: x
112-
# - an iterable of Dimensions ie: (x, y)
113-
# - a single tuple of Dimension and order ie: (x, 2)
114-
# - or an iterable of Dimension, order ie: ((x, 2), (y, 2))
115-
# - any combination of the above ie: ((x, 2), y, x, (z, 3))
170+
raise ValueError("`expr` must be a `Differentiable` type object") from e
171+
return expr
172+
173+
@staticmethod
174+
def _validate_dims(dims):
175+
"""
176+
Validate `dims`. It can be:
177+
- a single Dimension ie: x
178+
- an iterable of Dimensions ie: (x, y)
179+
- a single tuple of Dimension and order ie: (x, 2)
180+
- or an iterable of Dimension, order ie: ((x, 2), (y, 2))
181+
- any combination of the above ie: ((x, 2), y, x, (z, 3))
182+
"""
116183
if len(dims) == 0:
117184
raise ValueError('Expected Dimension w.r.t. which to differentiate')
118185
elif len(dims) == 1 and isinstance(dims[0], Iterable) and len(dims[0]) != 2:
@@ -121,9 +188,16 @@ def __new__(cls, expr, *dims, **kwargs):
121188
elif len(dims) == 2 and not isinstance(dims[1], Iterable) and is_integer(dims[1]):
122189
# special case of single dimension and order
123190
dims = (dims, )
191+
return dims
124192

125-
# Use `deriv_order` if specified
126-
deriv_order = kwargs.get('deriv_order', (1,)*len(dims))
193+
@staticmethod
194+
def _validate_deriv_order(deriv_order, dims):
195+
"""
196+
If provided `deriv_order` must correspond to the provided dimensions.
197+
Requires dims to validate or construct the default.
198+
"""
199+
if deriv_order is None:
200+
deriv_order = (1,)*len(dims)
127201
if not isinstance(deriv_order, Iterable):
128202
deriv_order = as_tuple(deriv_order)
129203
if len(deriv_order) != len(dims):
@@ -135,10 +209,15 @@ def __new__(cls, expr, *dims, **kwargs):
135209
'Invalid type in `deriv_order`, all elements must be non-negative'
136210
'Python `int`s'
137211
)
212+
return deriv_order
138213

139-
# Count the number of derivatives for each dimension
214+
@staticmethod
215+
def _count_derivatives(deriv_order, dims):
216+
"""
217+
Count the number of derivatives for each dimension.
218+
"""
140219
dcounter = defaultdict(int)
141-
for d, o in zip(dims, deriv_order):
220+
for d, o in zip(dims, deriv_order, strict=True):
142221
if isinstance(d, Iterable):
143222
if not is_integer(d[1]) or d[1] < 0:
144223
raise TypeError(
@@ -149,25 +228,27 @@ def __new__(cls, expr, *dims, **kwargs):
149228
dcounter[d[0]] += d[1]
150229
else:
151230
dcounter[d] += o
231+
return dcounter
152232

153-
# It's possible that the expr is a `sympy.Number` at this point, which
154-
# has derivative 0, unless we're taking a 0th derivative.
155-
if isinstance(expr, sympy.Number):
156-
if any(dcounter.values()):
157-
return 0
158-
else:
159-
return expr
160-
161-
# Use `fd_order` if specified
162-
fd_order = kwargs.get('fd_order')
233+
@staticmethod
234+
def _validate_fd_order(fd_order, expr, dims, dcounter):
235+
"""
236+
If provided `fd_order` must correspond to the provided dimensions.
237+
Required `expr`, `dims` and the derivative counter to validate.
238+
If not provided the maximum supported order will be used.
239+
"""
163240
if fd_order is not None:
164-
# If `fd_order` is specified collect these together
241+
# If `fd_order` is specified validate
165242
fcounter = defaultdict(int)
166-
for d, o in zip(dims, as_tuple(fd_order)):
243+
# First create a dictionary mapping variable wrt which to differentiate
244+
# to the `fd_order`
245+
for d, o in zip(dims, as_tuple(fd_order), strict=True):
167246
if isinstance(d, Iterable):
168247
fcounter[d[0]] += o
169248
else:
170249
fcounter[d] += o
250+
# Second validate that the `fd_order` is supported by the space or
251+
# time order
171252
for d, o in fcounter.items():
172253
if getattr(d, 'is_Time', False):
173254
order = expr.time_order
@@ -189,44 +270,7 @@ def __new__(cls, expr, *dims, **kwargs):
189270
else expr.space_order
190271
for d in dcounter.keys()
191272
)
192-
193-
# SymPy expects the list of variables w.r.t. which we differentiate to be a list
194-
# of 2-tuples: `(s, count)` where:
195-
# - `s` is the entity to diff w.r.t. and
196-
# - `count` is the order of the derivative
197-
derivatives = [sympy.Tuple(d, o) for d, o in dcounter.items()]
198-
199-
# Construct the actual Derivative object
200-
obj = Differentiable.__new__(cls, expr, *derivatives)
201-
obj._dims = tuple(dcounter.keys())
202-
203-
obj._fd_order = DimensionTuple(
204-
*as_tuple(fd_order),
205-
getters=obj._dims
206-
)
207-
obj._deriv_order = DimensionTuple(
208-
*as_tuple(dcounter.values()),
209-
getters=obj._dims
210-
)
211-
obj._side = kwargs.get("side")
212-
obj._transpose = kwargs.get("transpose", direct)
213-
obj._method = kwargs.get("method", 'FD')
214-
obj._weights = cls._process_weights(**kwargs)
215-
216-
ppsubs = kwargs.get("subs", kwargs.get("_ppsubs", []))
217-
processed = []
218-
if ppsubs:
219-
for i in ppsubs:
220-
try:
221-
processed.append(frozendict(i))
222-
except AttributeError:
223-
# E.g. `i` is a Transform object
224-
processed.append(i)
225-
obj._ppsubs = tuple(processed)
226-
227-
obj._x0 = cls._process_x0(obj._dims, **kwargs)
228-
229-
return obj
273+
return fd_order
230274

231275
@classmethod
232276
def _process_x0(cls, dims, **kwargs):
@@ -274,7 +318,7 @@ def __call__(self, x0=None, fd_order=None, side=None, method=None, **kwargs):
274318
assert self.ndims == 1
275319
_fd_order = {self.dims[0]: fd_order}
276320
except AttributeError:
277-
raise TypeError("fd_order incompatible with dimensions")
321+
raise TypeError("fd_order incompatible with dimensions") from None
278322

279323
if isinstance(self.expr, Derivative):
280324
# In case this was called on a perfect cross-derivative `u.dxdy`
@@ -538,48 +582,60 @@ def _eval_fd(self, expr, **kwargs):
538582
return res
539583

540584
def _eval_expand_nest(self, **hints):
541-
''' Expands nested derivatives
585+
"""
586+
Expands nested derivatives
542587
`Derivative(Derivative(f(x), (x, b)), (x, a))
543588
--> Derivative(f(x), (x, a+b))`
544589
`Derivative(Derivative(f(x), (y, b)), (x, a))
545590
--> Derivative(f(x), (x, a), (y, b))`
546591
Note that this is not always a valid expansion depending on the kwargs
547592
used to construct the derivative.
548-
'''
593+
"""
549594
if not isinstance(self.expr, self.__class__):
550595
return self
551596

597+
# This is necessary as tools.abc.Reconstructable._rebuild will copy
598+
# all kwargs from the self object. Need to enssure that the nest is not
599+
# actually expanded if derivatives are incompatible.
600+
# The nested derivative is evaluated by:
601+
# 1. Chaining together the variables with which to differentiate wrt
552602
new_expr = self.expr.args[0]
553603
new_dims = [
554604
(d, ii)
555605
for d, ii in zip(
556606
chain(self.dims, self.expr.dims),
557-
chain(self.deriv_order, self.expr.deriv_order)
607+
chain(self.deriv_order, self.expr.deriv_order),
608+
strict=True
558609
)
559610
]
560-
# This is necessary as tools.abc.Reconstructable._rebuild will copy
561-
# all kwargs from the self object
562-
# TODO: This dictionary merge needs to be a lot better
563-
# EG: Don't actually expand if derivatives are incompatible
611+
612+
# 2. Count the number of derivatives to take wrt each variable as well as
613+
# the finite difference order to use by iterating over the chained lists of
614+
# variables.
564615
new_deriv_order = tuple(chain(self.deriv_order, self.expr.deriv_order))
565-
# The `fd_order` may need to be reduced to construct the nested derivative
616+
new_fd_order = tuple(chain(self.fd_order, self.expr.fd_order))
566617
dcounter = defaultdict(int)
567618
fcounter = defaultdict(int)
568-
new_fd_order = tuple(chain(self.fd_order, self.expr.fd_order))
569-
for d, do, fo in zip(new_dims, new_deriv_order, new_fd_order):
619+
for d, do, fo in zip(new_dims, new_deriv_order, new_fd_order, strict=True):
570620
if isinstance(d, Iterable):
571621
dcounter[d[0]] += d[1]
572622
fcounter[d[0]] += fo
573623
else:
574624
dcounter[d] += do
575625
fcounter[d] += fo
576-
for (d, do), (_, fo) in zip(dcounter.items(), fcounter.items()):
626+
627+
# 3. Validate that the number of derivatives taken and the `fd_order` are
628+
# smaller than or equal to the corresponding space or time order that the
629+
# function supports.
630+
for (d, do), (_, fo) in zip(dcounter.items(), fcounter.items(), strict=True):
577631
if getattr(d, 'is_Time', False):
578632
dim_name = 'time'
579633
order = self.expr.time_order
580634
else:
581635
dim_name = 'space'
582636
order = self.expr.space_order
637+
# The `fd_order` may need to be reduced to construct the nested derivative
638+
# in this case we only emit a warning
583639
if fo > order:
584640
if do > order:
585641
raise ValueError(
@@ -593,45 +649,50 @@ def _eval_expand_nest(self, **hints):
593649
f'fd_order={order} for the {d} dimension.'
594650
)
595651
fcounter[d] = order
652+
653+
# 4. Finally, construct the new derivative object with the updated counts
654+
# and kwargs.
596655
new_kwargs = {
597656
'deriv_order': tuple(dcounter.values()),
598657
'fd_order': tuple(fcounter.values())
599658
}
600659
return self.func(new_expr, *dcounter.items(), **new_kwargs)
601660

602661
def _eval_expand_mul(self, **hints):
603-
''' Expands products, moving independent terms outside the derivative
662+
"""
663+
Expands products, moving independent terms outside the derivative
604664
`Derivative(C·f(x)·g(c, y), x)
605665
--> C·g(y)·Derivative(f(x), x)`
606-
'''
666+
"""
607667
if self.expr.is_Mul:
608668
ind, dep = self.expr.as_independent(*self.dims, as_Add=False)
609-
return ind*self.func(dep, *self.args[1:])
669+
return ind*self.func(dep)
610670
else:
611671
return self
612672

613673
def _eval_expand_add(self, **hints):
614-
''' Expands sums, using linearity of derivative
674+
"""
675+
Expands sums, using linearity of derivative
615676
`Derivative(f(x) + g(x), x)
616677
--> Derivative(f(x), x) + Derivative(g(x), x)`
617-
'''
678+
"""
618679
if self.expr.is_Add:
619680
ind, dep = self.expr.as_independent(*self.dims, as_Add=True)
620681
if dep.is_Add:
621682
return Add(*[self.func(s, *self.args[1:]) for s in dep.args])
622683
else:
623-
return self.func(dep, *self.args[1:])
684+
return self.func(dep)
624685
else:
625686
return self
626687

627688
def _eval_expand_product_rule(self, **hints):
628-
''' Expands products, of functions of the dependent variable
689+
"""
690+
Expands products, of functions of the dependent variable
629691
`Derivative(f(x)·g(x), x)
630692
--> Derivative(f(x), x)·g(x) + f(x)·Derivative(g(x), x)`
631693
This is only implemented for first derivatives with an arbitrary number
632694
of multiplicands and second derivatives with two multiplicands. The
633695
resultant expression for higher derivatives and mixed derivatives is much
634696
more difficult to implement.
635-
'''
636-
# if self.expr.is_Mul:
697+
"""
637698
raise NotImplementedError('Product rule expansion has not been written')

devito/finite_differences/differentiable.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1094,7 +1094,7 @@ def interp_for_fd(expr, x0, **kwargs):
10941094
@interp_for_fd.register(sympy.Derivative)
10951095
def _(expr, x0, **kwargs):
10961096
x0_expr = {d: v for d, v in x0.items() if d not in expr.dims}
1097-
return expr.func(interp_for_fd(expr.expr, x0_expr, **kwargs), *expr.args[1:])
1097+
return expr.func(interp_for_fd(expr.expr, x0_expr, **kwargs))
10981098

10991099

11001100
@interp_for_fd.register(sympy.Expr)

0 commit comments

Comments
 (0)