22from collections .abc import Iterable
33from functools import cached_property
44from itertools import chain
5+ from warnings import warn
56
67import sympy
78
@@ -160,18 +161,24 @@ def __new__(cls, expr, *dims, **kwargs):
160161 # Use `fd_order` if specified
161162 fd_order = kwargs .get ('fd_order' )
162163 if fd_order is not None :
163- _fd_order_specified = True
164164 # If `fd_order` is specified collect these together
165165 fcounter = defaultdict (int )
166166 for d , o in zip (dims , as_tuple (fd_order )):
167- fcounter [d ] += o
168- for d , o in dcounter .items ():
169- order = expr .time_order if getattr (d , 'is_Time' , False ) else expr .space_order
167+ if isinstance (d , Iterable ):
168+ fcounter [d [0 ]] += o
169+ else :
170+ fcounter [d ] += o
171+ for d , o in fcounter .items ():
172+ if getattr (d , 'is_Time' , False ):
173+ order = expr .time_order
174+ else :
175+ order = expr .space_order
170176 if o > order :
171- raise ValueError (f'Function does not support { d } derivative of order { o } ' )
177+ raise ValueError (
178+ f'Function does not support { d } -derivative with `fd_order` { o } '
179+ )
172180 fd_order = fcounter .values ()
173181 else :
174- _fd_order_specified = False
175182 # Default finite difference orders depending on input dimension (.dt or .dx)
176183 fd_order = tuple ([
177184 expr .time_order
@@ -180,16 +187,16 @@ def __new__(cls, expr, *dims, **kwargs):
180187 for d in dcounter .keys ()
181188 ])
182189
183- # SymPy expects the list of variable w.r.t. which we differentiate to be a list
184- # of 2-tuple `(s, count)` where s is the entity to diff wrt and count is the order
185- # of the derivative
190+ # SymPy expects the list of variables w.r.t. which we differentiate to be a list
191+ # of 2-tuples: `(s, count)` where:
192+ # - `s` is the entity to diff w.r.t. and
193+ # - `count` is the order of the derivative
186194 derivatives = [sympy .Tuple (d , o ) for d , o in dcounter .items ()]
187195
188196 # Construct the actual Derivative object
189197 obj = Differentiable .__new__ (cls , expr , * derivatives )
190198 obj ._dims = tuple (dcounter .keys ())
191199
192- obj ._fd_order_specified = _fd_order_specified
193200 obj ._fd_order = DimensionTuple (
194201 * as_tuple (fd_order ),
195202 getters = obj ._dims
@@ -544,17 +551,49 @@ def _eval_expand_nest(self, **hints):
544551 for d , ii in zip (
545552 chain (self .dims , expr .dims ),
546553 chain (self .deriv_order , expr .deriv_order )
547- )]
554+ )
555+ ]
548556 # This is necessary as tools.abc.Reconstructable._rebuild will copy
549557 # all kwargs from the self object
550558 # TODO: This dictionary merge needs to be a lot better
551559 # EG: Don't actually expand if derivatives are incompatible
560+ new_deriv_order = tuple (chain (self .deriv_order , expr .deriv_order ))
561+ # The `fd_order` may need to be reduced to construct the nested derivative
562+ dcounter = defaultdict (int )
563+ fcounter = defaultdict (int )
564+ new_fd_order = tuple (chain (self .fd_order , expr .fd_order ))
565+ for d , do , fo in zip (new_dims , new_deriv_order , new_fd_order ):
566+ if isinstance (d , Iterable ):
567+ dcounter [d [0 ]] += d [1 ]
568+ fcounter [d [0 ]] += fo
569+ else :
570+ dcounter [d ] += do
571+ fcounter [d ] += fo
572+ for (d , do ), (_ , fo ) in zip (dcounter .items (), fcounter .items ()):
573+ if getattr (d , 'is_Time' , False ):
574+ dim_name = 'time'
575+ order = expr .time_order
576+ else :
577+ dim_name = 'space'
578+ order = expr .space_order
579+ if fo > order :
580+ if do > order :
581+ raise ValueError (
582+ f'Nested { do } -derivative constructed which is bigger '
583+ f'than the { dim_name } _order={ order } '
584+ )
585+ else :
586+ warn (
587+ f'Nested derivative constructed with fd_order={ fo } , '
588+ f'but { dim_name } _order={ order } . Adjusting '
589+ f'fd_order={ order } for the { d } dimension.'
590+ )
591+ fcounter [d ] = order
552592 new_kwargs = {
553- 'deriv_order' : tuple (chain (self .deriv_order , expr .deriv_order ))
593+ 'deriv_order' : tuple (dcounter .values ()),
594+ 'fd_order' : tuple (fcounter .values ())
554595 }
555- if self ._fd_order_specified or expr ._fd_order_specified :
556- new_kwargs .update ({'fd_order' : tuple (chain (self .fd_order , expr .fd_order ))})
557- return self .func (new_expr , * new_dims , ** new_kwargs )
596+ return self .func (new_expr , * dcounter .items (), ** new_kwargs )
558597 else :
559598 return self
560599
0 commit comments