33import numpy as np
44from sympy import S
55
6- from devito .finite_differences import IndexDerivative
6+ from devito .finite_differences import IndexDerivative , Weights
77from devito .ir import Backward , Forward , Interval , IterationSpace , Queue
88from devito .passes .clusters .misc import fuse
99from devito .symbolics import BasicWrapperMixin , reuse_if_untouched , uxreplace
@@ -91,17 +91,39 @@ def _core(expr, c, ispace, weights, reusables, mapper, **kwargs):
9191
9292
9393@_core .register (Symbol )
94- @_core .register (Indexed )
9594@_core .register (BasicWrapperMixin )
9695def _ (expr , c , ispace , weights , reusables , mapper , ** kwargs ):
9796 return expr , []
9897
9998
99+ @_core .register (Indexed )
100+ def _ (expr , c , ispace , weights , reusables , mapper , ** kwargs ):
101+ if not isinstance (expr .function , Weights ):
102+ return expr , []
103+
104+ # Lower or reuse a previously lowered Weights array
105+ sregistry = kwargs ['sregistry' ]
106+ subs_user = kwargs ['subs' ]
107+
108+ w0 = expr .function
109+ k = tuple (w0 .weights )
110+ try :
111+ w = weights [k ]
112+ except KeyError :
113+ name = sregistry .make_name (prefix = 'w' )
114+ dtype = infer_dtype ([w0 .dtype , c .dtype ]) # At least np.float32
115+ initvalue = tuple (i .subs (subs_user ) for i in k )
116+ w = weights [k ] = w0 ._rebuild (name = name , dtype = dtype , initvalue = initvalue )
117+
118+ rebuilt = expr ._subs (w0 .indexed , w .indexed )
119+
120+ return rebuilt , []
121+
122+
100123@_core .register (IndexDerivative )
101124def _ (expr , c , ispace , weights , reusables , mapper , ** kwargs ):
102125 sregistry = kwargs ['sregistry' ]
103126 options = kwargs ['options' ]
104- subs_user = kwargs ['subs' ]
105127
106128 try :
107129 cbk0 = deriv_schedule_registry [options ['deriv-schedule' ]]
@@ -114,18 +136,10 @@ def _(expr, c, ispace, weights, reusables, mapper, **kwargs):
114136
115137 # Create the concrete Weights array, or reuse an already existing one
116138 # if possible
117- name = sregistry .make_name (prefix = 'w' )
118- w0 = ideriv .weights .function
119- dtype = infer_dtype ([w0 .dtype , c .dtype ]) # At least np.float32
120- k = tuple (w0 .weights )
121- try :
122- w = weights [k ]
123- except KeyError :
124- initvalue = tuple (i .subs (subs_user ) for i in k )
125- w = weights [k ] = w0 ._rebuild (name = name , dtype = dtype , initvalue = initvalue )
139+ w , _ = _core (ideriv .weights , c , ispace , weights , reusables , mapper , ** kwargs )
126140
127141 # Replace the abstract Weights array with the concrete one
128- subs = {w0 . indexed : w .indexed }
142+ subs = {ideriv . weights . base : w .base }
129143 init = uxreplace (init , subs )
130144 ideriv = uxreplace (ideriv , subs )
131145
@@ -152,13 +166,13 @@ def _(expr, c, ispace, weights, reusables, mapper, **kwargs):
152166 ispace1 = IterationSpace .union (ispace , ispace0 , relations = extra )
153167
154168 # The Symbol that will hold the result of the IndexDerivative computation
155- # NOTE: created before recurring so that we ultimately get a sound ordering
169+ # NOTE: created before recursing so that we ultimately get a sound ordering
156170 try :
157171 s = reusables .pop ()
158- assert np .can_cast (s .dtype , dtype )
172+ assert np .can_cast (s .dtype , w . dtype )
159173 except KeyError :
160174 name = sregistry .make_name (prefix = 'r' )
161- s = Symbol (name = name , dtype = dtype )
175+ s = Symbol (name = name , dtype = w . dtype )
162176
163177 # Go inside `expr` and recursively lower any nested IndexDerivatives
164178 expr , processed = _core (expr , c , ispace1 , weights , reusables , mapper , ** kwargs )
0 commit comments