11import cgen as c
22import numpy as np
3+ from functools import cached_property
34
45from devito .passes .iet .engine import iet_pass
56from devito .ir .iet import (Transformer , MapNodes , Iteration , BlankLine ,
6- FindNodes , Call , CallableBody )
7- from devito .symbolics import Byref , Macro
7+ DummyExpr , CallableBody , List , Call , Callable ,
8+ FindNodes )
9+ from devito .symbolics import Byref , Macro , FieldFromPointer
10+ from devito .types import Symbol , Scalar
811from devito .types .basic import DataSymbol
9- from devito .petsc .types import (PetscMPIInt , PetscErrorCode , Initialize ,
10- Finalize , ArgvSymbol )
12+ from devito .tools import frozendict
13+ from devito .petsc .types import (PetscMPIInt , PetscErrorCode , MultipleFieldData ,
14+ PointerIS , Mat , LocalVec , GlobalVec , CallbackMat , SNES ,
15+ DummyArg , PetscInt , PointerDM , PointerMat , MatReuse ,
16+ CallbackPointerIS , CallbackPointerDM , JacobianStruct ,
17+ SubMatrixStruct , Initialize , Finalize , ArgvSymbol )
1118from devito .petsc .types .macros import petsc_func_begin_user
1219from devito .petsc .iet .nodes import PetscMetaData
1320from devito .petsc .utils import core_metadata
14- from devito .petsc .iet .routines import (CallbackBuilder , BaseObjectBuilder , BaseSetup ,
15- Solver , TimeDependent , NonTimeDependent )
21+ from devito .petsc .iet .routines import (CBBuilder , CCBBuilder , BaseObjectBuilder ,
22+ CoupledObjectBuilder , BaseSetup , CoupledSetup ,
23+ Solver , CoupledSolver , TimeDependent ,
24+ NonTimeDependent )
1625from devito .petsc .iet .utils import petsc_call , petsc_call_mpi
1726
1827
@@ -26,7 +35,6 @@ def lower_petsc(iet, **kwargs):
2635 return iet , {}
2736
2837 metadata = core_metadata ()
29-
3038 data = FindNodes (PetscMetaData ).visit (iet )
3139
3240 if any (filter (lambda i : isinstance (i .expr .rhs , Initialize ), data )):
@@ -35,10 +43,10 @@ def lower_petsc(iet, **kwargs):
3543 if any (filter (lambda i : isinstance (i .expr .rhs , Finalize ), data )):
3644 return finalize (iet ), metadata
3745
38- targets = [ i .expr .rhs .target for (i ,) in injectsolve_mapper .values ()]
39-
40- # Assumption is that all targets have the same grid so can use any target here
41- objs = build_core_objects ( targets [ - 1 ], ** kwargs )
46+ unique_grids = { i .expr .rhs .grid for (i ,) in injectsolve_mapper .values ()}
47+ # Assumption is that all solves are on the same grid
48+ if len ( unique_grids ) > 1 :
49+ raise ValueError ( "All PETScSolves must use the same Grid, but multiple found." )
4250
4351 # Create core PETSc calls (not specific to each PETScSolve)
4452 core = make_core_petsc_calls (objs , ** kwargs )
@@ -54,17 +62,18 @@ def lower_petsc(iet, **kwargs):
5462 setup .extend (builder .solversetup .calls )
5563
5664 # Transform the spatial iteration loop with the calls to execute the solver
57- subs .update (builder .solve .mapper )
65+ subs .update ({ builder .solve .spatial_body : builder . solve . calls } )
5866
5967 efuncs .update (builder .cbbuilder .efuncs )
6068
69+ populate_matrix_context (efuncs , objs )
70+
6171 iet = Transformer (subs ).visit (iet )
6272
6373 body = core + tuple (setup ) + (BlankLine ,) + iet .body .body
6474 body = iet .body ._rebuild (body = body )
6575 iet = iet ._rebuild (body = body )
6676 metadata .update ({'efuncs' : tuple (efuncs .values ())})
67-
6877 return iet , metadata
6978
7079
@@ -100,56 +109,140 @@ def make_core_petsc_calls(objs, **kwargs):
100109 return call_mpi , BlankLine
101110
102111
103- def build_core_objects (target , ** kwargs ):
104- communicator = 'PETSC_COMM_WORLD'
105-
106- return {
107- 'size' : PetscMPIInt (name = 'size' ),
108- 'comm' : communicator ,
109- 'err' : PetscErrorCode (name = 'err' ),
110- 'grid' : target .grid
111- }
112-
113-
114112class Builder :
115113 """
116114 This class is designed to support future extensions, enabling
117115 different combinations of solver types, preconditioning methods,
118116 and other functionalities as needed.
119-
120117 The class will be extended to accommodate different solver types by
121118 returning subclasses of the objects initialised in __init__,
122119 depending on the properties of `injectsolve`.
123120 """
124121 def __init__ (self , injectsolve , objs , iters , ** kwargs ):
122+ self .injectsolve = injectsolve
123+ self .objs = objs
124+ self .iters = iters
125+ self .kwargs = kwargs
126+ self .coupled = isinstance (injectsolve .expr .rhs .fielddata , MultipleFieldData )
127+ self .args = {
128+ 'injectsolve' : self .injectsolve ,
129+ 'objs' : self .objs ,
130+ 'iters' : self .iters ,
131+ ** self .kwargs
132+ }
133+ self .args ['solver_objs' ] = self .objbuilder .solver_objs
134+ self .args ['timedep' ] = self .timedep
135+ self .args ['cbbuilder' ] = self .cbbuilder
136+
137+ @cached_property
138+ def objbuilder (self ):
139+ return (
140+ CoupledObjectBuilder (** self .args )
141+ if self .coupled else
142+ BaseObjectBuilder (** self .args )
143+ )
125144
126- # Determine the time dependency class
127- time_mapper = injectsolve .expr .rhs .time_mapper
128- timedep = TimeDependent if time_mapper else NonTimeDependent
129- self .timedep = timedep (injectsolve , iters , ** kwargs )
145+ @cached_property
146+ def timedep (self ):
147+ time_mapper = self .injectsolve .expr .rhs .time_mapper
148+ timedep_class = TimeDependent if time_mapper else NonTimeDependent
149+ return timedep_class (** self .args )
130150
131- # Objects
132- self . objbuilder = BaseObjectBuilder ( injectsolve , ** kwargs )
133- self .solver_objs = self .objbuilder . solver_objs
151+ @ cached_property
152+ def cbbuilder ( self ):
153+ return CCBBuilder ( ** self .args ) if self .coupled else CBBuilder ( ** self . args )
134154
135- # Callbacks
136- self .cbbuilder = CallbackBuilder (
137- injectsolve , objs , self .solver_objs , timedep = self .timedep ,
138- ** kwargs
139- )
155+ @cached_property
156+ def solversetup (self ):
157+ return CoupledSetup (** self .args ) if self .coupled else BaseSetup (** self .args )
140158
141- # Solver setup
142- self .solversetup = BaseSetup (
143- self .solver_objs , objs , injectsolve , self .cbbuilder
144- )
159+ @cached_property
160+ def solve (self ):
161+ return CoupledSolver (** self .args ) if self .coupled else Solver (** self .args )
145162
146- # Execute the solver
147- self .solve = Solver (
148- self .solver_objs , objs , injectsolve , iters ,
149- self .cbbuilder , timedep = self .timedep
150- )
163+
164+ def populate_matrix_context (efuncs , objs ):
165+ if not objs ['dummyefunc' ] in efuncs .values ():
166+ return
167+
168+ subdms_expr = DummyExpr (
169+ FieldFromPointer (objs ['Subdms' ]._C_symbol , objs ['ljacctx' ]),
170+ objs ['Subdms' ]._C_symbol
171+ )
172+ fields_expr = DummyExpr (
173+ FieldFromPointer (objs ['Fields' ]._C_symbol , objs ['ljacctx' ]),
174+ objs ['Fields' ]._C_symbol
175+ )
176+ body = CallableBody (
177+ List (body = [subdms_expr , fields_expr ]),
178+ init = (objs ['begin_user' ],),
179+ retstmt = tuple ([Call ('PetscFunctionReturn' , arguments = [0 ])])
180+ )
181+ name = 'PopulateMatContext'
182+ efuncs [name ] = Callable (
183+ name , body , objs ['err' ],
184+ parameters = [objs ['ljacctx' ], objs ['Subdms' ], objs ['Fields' ]]
185+ )
151186
152187
153- # Move these to types folder
188+ # TODO: Devito MPI + PETSc testing
189+ # if kwargs['options']['mpi'] -> communicator = grid.distributor._obj_comm
190+ communicator = 'PETSC_COMM_WORLD'
191+ subdms = PointerDM (name = 'subdms' )
192+ fields = PointerIS (name = 'fields' )
193+ submats = PointerMat (name = 'submats' )
194+ rows = PointerIS (name = 'rows' )
195+ cols = PointerIS (name = 'cols' )
196+
197+
198+ # A static dict containing shared symbols and objects that are not
199+ # unique to each PETScSolve.
200+ # Many of these objects are used as arguments in callback functions to make
201+ # the C code cleaner and more modular. This is also a step toward leveraging
202+ # Devito's `reuse_efuncs` functionality, allowing reuse of efuncs when
203+ # they are semantically identical.
204+ objs = frozendict ({
205+ 'size' : PetscMPIInt (name = 'size' ),
206+ 'comm' : communicator ,
207+ 'err' : PetscErrorCode (name = 'err' ),
208+ 'block' : CallbackMat ('block' ),
209+ 'submat_arr' : PointerMat (name = 'submat_arr' ),
210+ 'subblockrows' : PetscInt ('subblockrows' ),
211+ 'subblockcols' : PetscInt ('subblockcols' ),
212+ 'rowidx' : PetscInt ('rowidx' ),
213+ 'colidx' : PetscInt ('colidx' ),
214+ 'J' : Mat ('J' ),
215+ 'X' : GlobalVec ('X' ),
216+ 'xloc' : LocalVec ('xloc' ),
217+ 'Y' : GlobalVec ('Y' ),
218+ 'yloc' : LocalVec ('yloc' ),
219+ 'F' : GlobalVec ('F' ),
220+ 'floc' : LocalVec ('floc' ),
221+ 'B' : GlobalVec ('B' ),
222+ 'nfields' : PetscInt ('nfields' ),
223+ 'irow' : PointerIS (name = 'irow' ),
224+ 'icol' : PointerIS (name = 'icol' ),
225+ 'nsubmats' : Scalar ('nsubmats' , dtype = np .int32 ),
226+ 'matreuse' : MatReuse ('scall' ),
227+ 'snes' : SNES ('snes' ),
228+ 'rows' : rows ,
229+ 'cols' : cols ,
230+ 'Subdms' : subdms ,
231+ 'LocalSubdms' : CallbackPointerDM (name = 'subdms' ),
232+ 'Fields' : fields ,
233+ 'LocalFields' : CallbackPointerIS (name = 'fields' ),
234+ 'Submats' : submats ,
235+ 'ljacctx' : JacobianStruct (
236+ fields = [subdms , fields , submats ], modifier = ' *'
237+ ),
238+ 'subctx' : SubMatrixStruct (fields = [rows , cols ]),
239+ 'Null' : Macro ('NULL' ),
240+ 'dummyctx' : Symbol ('lctx' ),
241+ 'dummyptr' : DummyArg ('dummy' ),
242+ 'dummyefunc' : Symbol ('dummyefunc' ),
243+ 'dof' : PetscInt ('dof' ),
244+ 'begin_user' : c .Line ('PetscFunctionBeginUser;' ),
245+ })
246+
247+ # Move to macros file?
154248Null = Macro ('NULL' )
155- void = 'void'
0 commit comments