@@ -99,20 +99,87 @@ def __new__(cls, expr, *dims, **kwargs):
9999 from warnings import warn
100100 warn ('I removed the `preprocessed` kwarg' )
101101
102+ # Validate the input arguments `expr`, `dims` and `deriv_order`
103+ expr = cls ._validate_expr (expr )
104+ dims = cls ._validate_dims (dims )
105+ deriv_order = cls ._validate_deriv_order (kwargs .get ('deriv_order' ), dims )
106+ # Count the derivatives w.r.t. each variable
107+ dcounter = cls ._count_derivatives (deriv_order , dims )
108+
109+ # It's possible that the expr is a `sympy.Number` at this point, which
110+ # has derivative 0, unless we're taking a 0th derivative.
111+ if isinstance (expr , sympy .Number ):
112+ if any (dcounter .values ()):
113+ return 0
114+ else :
115+ return expr
116+
117+ # Validate the finite difference order `fd_order`
118+ fd_order = cls ._validate_fd_order (kwargs .get ('fd_order' ), expr , dims , dcounter )
119+
120+ # SymPy expects the list of variables w.r.t. which we differentiate to be a list
121+ # of 2-tuples: `(s, count)` where:
122+ # - `s` is the entity to diff w.r.t. and
123+ # - `count` is the order of the derivative
124+ derivatives = [sympy .Tuple (d , o ) for d , o in dcounter .items ()]
125+
126+ # Construct the actual Derivative object
127+ obj = Differentiable .__new__ (cls , expr , * derivatives )
128+ obj ._dims = tuple (dcounter .keys ())
129+
130+ obj ._fd_order = DimensionTuple (
131+ * as_tuple (fd_order ),
132+ getters = obj ._dims
133+ )
134+ obj ._deriv_order = DimensionTuple (
135+ * as_tuple (dcounter .values ()),
136+ getters = obj ._dims
137+ )
138+ obj ._side = kwargs .get ("side" )
139+ obj ._transpose = kwargs .get ("transpose" , direct )
140+ obj ._method = kwargs .get ("method" , 'FD' )
141+ obj ._weights = cls ._process_weights (** kwargs )
142+
143+ ppsubs = kwargs .get ("subs" , kwargs .get ("_ppsubs" , []))
144+ processed = []
145+ if ppsubs :
146+ for i in ppsubs :
147+ try :
148+ processed .append (frozendict (i ))
149+ except AttributeError :
150+ # E.g. `i` is a Transform object
151+ processed .append (i )
152+ obj ._ppsubs = tuple (processed )
153+
154+ obj ._x0 = cls ._process_x0 (obj ._dims , ** kwargs )
155+
156+ return obj
157+
158+ @staticmethod
159+ def _validate_expr (expr ):
160+ """
161+ Validate the provided `expr`. It must be of "differentiable" type or
162+ convertible to "differentiable" type.
163+ """
102164 if type (expr ) is sympy .Derivative :
103165 raise ValueError ("Cannot nest sympy.Derivative with devito.Derivative" )
104166 if not isinstance (expr , Differentiable ):
105167 try :
106168 expr = diffify (expr )
107169 except Exception as e :
108- raise ValueError ("`expr` must be a Differentiable object" ) from e
109-
110- # Validate `dims`. It can be:
111- # - a single Dimension ie: x
112- # - an iterable of Dimensions ie: (x, y)
113- # - a single tuple of Dimension and order ie: (x, 2)
114- # - or an iterable of Dimension, order ie: ((x, 2), (y, 2))
115- # - any combination of the above ie: ((x, 2), y, x, (z, 3))
170+ raise ValueError ("`expr` must be a `Differentiable` type object" ) from e
171+ return expr
172+
173+ @staticmethod
174+ def _validate_dims (dims ):
175+ """
176+ Validate `dims`. It can be:
177+ - a single Dimension ie: x
178+ - an iterable of Dimensions ie: (x, y)
179+ - a single tuple of Dimension and order ie: (x, 2)
180+ - or an iterable of Dimension, order ie: ((x, 2), (y, 2))
181+ - any combination of the above ie: ((x, 2), y, x, (z, 3))
182+ """
116183 if len (dims ) == 0 :
117184 raise ValueError ('Expected Dimension w.r.t. which to differentiate' )
118185 elif len (dims ) == 1 and isinstance (dims [0 ], Iterable ) and len (dims [0 ]) != 2 :
@@ -121,9 +188,16 @@ def __new__(cls, expr, *dims, **kwargs):
121188 elif len (dims ) == 2 and not isinstance (dims [1 ], Iterable ) and is_integer (dims [1 ]):
122189 # special case of single dimension and order
123190 dims = (dims , )
191+ return dims
124192
125- # Use `deriv_order` if specified
126- deriv_order = kwargs .get ('deriv_order' , (1 ,)* len (dims ))
193+ @staticmethod
194+ def _validate_deriv_order (deriv_order , dims ):
195+ """
196+ If provided `deriv_order` must correspond to the provided dimensions.
197+ Requires dims to validate or construct the default.
198+ """
199+ if deriv_order is None :
200+ deriv_order = (1 ,)* len (dims )
127201 if not isinstance (deriv_order , Iterable ):
128202 deriv_order = as_tuple (deriv_order )
129203 if len (deriv_order ) != len (dims ):
@@ -135,10 +209,15 @@ def __new__(cls, expr, *dims, **kwargs):
135209 'Invalid type in `deriv_order`, all elements must be non-negative'
136210 'Python `int`s'
137211 )
212+ return deriv_order
138213
139- # Count the number of derivatives for each dimension
214+ @staticmethod
215+ def _count_derivatives (deriv_order , dims ):
216+ """
217+ Count the number of derivatives for each dimension.
218+ """
140219 dcounter = defaultdict (int )
141- for d , o in zip (dims , deriv_order ):
220+ for d , o in zip (dims , deriv_order , strict = True ):
142221 if isinstance (d , Iterable ):
143222 if not is_integer (d [1 ]) or d [1 ] < 0 :
144223 raise TypeError (
@@ -149,25 +228,27 @@ def __new__(cls, expr, *dims, **kwargs):
149228 dcounter [d [0 ]] += d [1 ]
150229 else :
151230 dcounter [d ] += o
231+ return dcounter
152232
153- # It's possible that the expr is a `sympy.Number` at this point, which
154- # has derivative 0, unless we're taking a 0th derivative.
155- if isinstance (expr , sympy .Number ):
156- if any (dcounter .values ()):
157- return 0
158- else :
159- return expr
160-
161- # Use `fd_order` if specified
162- fd_order = kwargs .get ('fd_order' )
233+ @staticmethod
234+ def _validate_fd_order (fd_order , expr , dims , dcounter ):
235+ """
236+ If provided `fd_order` must correspond to the provided dimensions.
237+ Required `expr`, `dims` and the derivative counter to validate.
238+ If not provided the maximum supported order will be used.
239+ """
163240 if fd_order is not None :
164- # If `fd_order` is specified collect these together
241+ # If `fd_order` is specified validate
165242 fcounter = defaultdict (int )
166- for d , o in zip (dims , as_tuple (fd_order )):
243+ # First create a dictionary mapping variable wrt which to differentiate
244+ # to the `fd_order`
245+ for d , o in zip (dims , as_tuple (fd_order ), strict = True ):
167246 if isinstance (d , Iterable ):
168247 fcounter [d [0 ]] += o
169248 else :
170249 fcounter [d ] += o
250+ # Second validate that the `fd_order` is supported by the space or
251+ # time order
171252 for d , o in fcounter .items ():
172253 if getattr (d , 'is_Time' , False ):
173254 order = expr .time_order
@@ -189,44 +270,7 @@ def __new__(cls, expr, *dims, **kwargs):
189270 else expr .space_order
190271 for d in dcounter .keys ()
191272 )
192-
193- # SymPy expects the list of variables w.r.t. which we differentiate to be a list
194- # of 2-tuples: `(s, count)` where:
195- # - `s` is the entity to diff w.r.t. and
196- # - `count` is the order of the derivative
197- derivatives = [sympy .Tuple (d , o ) for d , o in dcounter .items ()]
198-
199- # Construct the actual Derivative object
200- obj = Differentiable .__new__ (cls , expr , * derivatives )
201- obj ._dims = tuple (dcounter .keys ())
202-
203- obj ._fd_order = DimensionTuple (
204- * as_tuple (fd_order ),
205- getters = obj ._dims
206- )
207- obj ._deriv_order = DimensionTuple (
208- * as_tuple (dcounter .values ()),
209- getters = obj ._dims
210- )
211- obj ._side = kwargs .get ("side" )
212- obj ._transpose = kwargs .get ("transpose" , direct )
213- obj ._method = kwargs .get ("method" , 'FD' )
214- obj ._weights = cls ._process_weights (** kwargs )
215-
216- ppsubs = kwargs .get ("subs" , kwargs .get ("_ppsubs" , []))
217- processed = []
218- if ppsubs :
219- for i in ppsubs :
220- try :
221- processed .append (frozendict (i ))
222- except AttributeError :
223- # E.g. `i` is a Transform object
224- processed .append (i )
225- obj ._ppsubs = tuple (processed )
226-
227- obj ._x0 = cls ._process_x0 (obj ._dims , ** kwargs )
228-
229- return obj
273+ return fd_order
230274
231275 @classmethod
232276 def _process_x0 (cls , dims , ** kwargs ):
@@ -274,7 +318,7 @@ def __call__(self, x0=None, fd_order=None, side=None, method=None, **kwargs):
274318 assert self .ndims == 1
275319 _fd_order = {self .dims [0 ]: fd_order }
276320 except AttributeError :
277- raise TypeError ("fd_order incompatible with dimensions" )
321+ raise TypeError ("fd_order incompatible with dimensions" ) from None
278322
279323 if isinstance (self .expr , Derivative ):
280324 # In case this was called on a perfect cross-derivative `u.dxdy`
@@ -538,48 +582,60 @@ def _eval_fd(self, expr, **kwargs):
538582 return res
539583
540584 def _eval_expand_nest (self , ** hints ):
541- ''' Expands nested derivatives
585+ """
586+ Expands nested derivatives
542587 `Derivative(Derivative(f(x), (x, b)), (x, a))
543588 --> Derivative(f(x), (x, a+b))`
544589 `Derivative(Derivative(f(x), (y, b)), (x, a))
545590 --> Derivative(f(x), (x, a), (y, b))`
546591 Note that this is not always a valid expansion depending on the kwargs
547592 used to construct the derivative.
548- '''
593+ """
549594 if not isinstance (self .expr , self .__class__ ):
550595 return self
551596
597+ # This is necessary as tools.abc.Reconstructable._rebuild will copy
598+ # all kwargs from the self object. Need to enssure that the nest is not
599+ # actually expanded if derivatives are incompatible.
600+ # The nested derivative is evaluated by:
601+ # 1. Chaining together the variables with which to differentiate wrt
552602 new_expr = self .expr .args [0 ]
553603 new_dims = [
554604 (d , ii )
555605 for d , ii in zip (
556606 chain (self .dims , self .expr .dims ),
557- chain (self .deriv_order , self .expr .deriv_order )
607+ chain (self .deriv_order , self .expr .deriv_order ),
608+ strict = True
558609 )
559610 ]
560- # This is necessary as tools.abc.Reconstructable._rebuild will copy
561- # all kwargs from the self object
562- # TODO: This dictionary merge needs to be a lot better
563- # EG: Don't actually expand if derivatives are incompatible
611+
612+ # 2. Count the number of derivatives to take wrt each variable as well as
613+ # the finite difference order to use by iterating over the chained lists of
614+ # variables.
564615 new_deriv_order = tuple (chain (self .deriv_order , self .expr .deriv_order ))
565- # The ` fd_order` may need to be reduced to construct the nested derivative
616+ new_fd_order = tuple ( chain ( self . fd_order , self . expr . fd_order ))
566617 dcounter = defaultdict (int )
567618 fcounter = defaultdict (int )
568- new_fd_order = tuple (chain (self .fd_order , self .expr .fd_order ))
569- for d , do , fo in zip (new_dims , new_deriv_order , new_fd_order ):
619+ for d , do , fo in zip (new_dims , new_deriv_order , new_fd_order , strict = True ):
570620 if isinstance (d , Iterable ):
571621 dcounter [d [0 ]] += d [1 ]
572622 fcounter [d [0 ]] += fo
573623 else :
574624 dcounter [d ] += do
575625 fcounter [d ] += fo
576- for (d , do ), (_ , fo ) in zip (dcounter .items (), fcounter .items ()):
626+
627+ # 3. Validate that the number of derivatives taken and the `fd_order` are
628+ # smaller than or equal to the corresponding space or time order that the
629+ # function supports.
630+ for (d , do ), (_ , fo ) in zip (dcounter .items (), fcounter .items (), strict = True ):
577631 if getattr (d , 'is_Time' , False ):
578632 dim_name = 'time'
579633 order = self .expr .time_order
580634 else :
581635 dim_name = 'space'
582636 order = self .expr .space_order
637+ # The `fd_order` may need to be reduced to construct the nested derivative
638+ # in this case we only emit a warning
583639 if fo > order :
584640 if do > order :
585641 raise ValueError (
@@ -593,45 +649,50 @@ def _eval_expand_nest(self, **hints):
593649 f'fd_order={ order } for the { d } dimension.'
594650 )
595651 fcounter [d ] = order
652+
653+ # 4. Finally, construct the new derivative object with the updated counts
654+ # and kwargs.
596655 new_kwargs = {
597656 'deriv_order' : tuple (dcounter .values ()),
598657 'fd_order' : tuple (fcounter .values ())
599658 }
600659 return self .func (new_expr , * dcounter .items (), ** new_kwargs )
601660
602661 def _eval_expand_mul (self , ** hints ):
603- ''' Expands products, moving independent terms outside the derivative
662+ """
663+ Expands products, moving independent terms outside the derivative
604664 `Derivative(C·f(x)·g(c, y), x)
605665 --> C·g(y)·Derivative(f(x), x)`
606- '''
666+ """
607667 if self .expr .is_Mul :
608668 ind , dep = self .expr .as_independent (* self .dims , as_Add = False )
609- return ind * self .func (dep , * self . args [ 1 :] )
669+ return ind * self .func (dep )
610670 else :
611671 return self
612672
613673 def _eval_expand_add (self , ** hints ):
614- ''' Expands sums, using linearity of derivative
674+ """
675+ Expands sums, using linearity of derivative
615676 `Derivative(f(x) + g(x), x)
616677 --> Derivative(f(x), x) + Derivative(g(x), x)`
617- '''
678+ """
618679 if self .expr .is_Add :
619680 ind , dep = self .expr .as_independent (* self .dims , as_Add = True )
620681 if dep .is_Add :
621682 return Add (* [self .func (s , * self .args [1 :]) for s in dep .args ])
622683 else :
623- return self .func (dep , * self . args [ 1 :] )
684+ return self .func (dep )
624685 else :
625686 return self
626687
627688 def _eval_expand_product_rule (self , ** hints ):
628- ''' Expands products, of functions of the dependent variable
689+ """
690+ Expands products, of functions of the dependent variable
629691 `Derivative(f(x)·g(x), x)
630692 --> Derivative(f(x), x)·g(x) + f(x)·Derivative(g(x), x)`
631693 This is only implemented for first derivatives with an arbitrary number
632694 of multiplicands and second derivatives with two multiplicands. The
633695 resultant expression for higher derivatives and mixed derivatives is much
634696 more difficult to implement.
635- '''
636- # if self.expr.is_Mul:
697+ """
637698 raise NotImplementedError ('Product rule expansion has not been written' )
0 commit comments