Skip to content

Commit 0af998d

Browse files
committed
Coupled (#41)
* compiler/dsl: Add machinery to support coupled solvers
1 parent f4ee1b9 commit 0af998d

18 files changed

Lines changed: 2001 additions & 657 deletions

File tree

.github/workflows/pytest-petsc.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ jobs:
6464
6565
- name: Test with pytest
6666
run: |
67-
${{ env.RUN_CMD }} pytest --cov --cov-config=.coveragerc --cov-report=xml ${{ env.TESTS }}
67+
${{ env.RUN_CMD }} mpiexec -n 1 pytest --cov --cov-config=.coveragerc --cov-report=xml ${{ env.TESTS }}
6868
6969
- name: Upload coverage to Codecov
7070
if: "!contains(matrix.name, 'docker')"

devito/ir/equations/algorithms.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,8 @@ def concretize_subdims(exprs, **kwargs):
191191
"""
192192
sregistry = kwargs.get('sregistry')
193193

194-
mapper = {}
194+
# Update based on changes in #2509
195+
mapper = kwargs.get('concretize_mapper', {})
195196
rebuilt = {} # Rebuilt implicit dims etc which are shared between dimensions
196197

197198
_concretize_subdims(exprs, mapper, rebuilt, sregistry)

devito/ir/iet/nodes.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1046,7 +1046,7 @@ class Dereference(ExprStmt, Node):
10461046
The following cases are supported:
10471047
10481048
* `pointer` is a PointerArray or TempFunction, and `pointee` is an Array.
1049-
* `pointer` is an ArrayObject or CCompositeObject representing a pointer
1049+
* `pointer` is an ArrayObject or LocalCompositeObject representing a pointer
10501050
to a C struct, and `pointee` is a field in `pointer`.
10511051
* `pointer` is a Symbol with its _C_ctype deriving from ct._Pointer, and
10521052
`pointee` is a Symbol representing the dereferenced value.
@@ -1070,8 +1070,7 @@ def functions(self):
10701070
def expr_symbols(self):
10711071
ret = []
10721072
if self.pointer.is_Symbol:
1073-
assert (isinstance(self.pointer, LocalCompositeObject) or
1074-
issubclass(self.pointer._C_ctype, ctypes._Pointer)), \
1073+
assert issubclass(self.pointer._C_ctype, ctypes._Pointer), \
10751074
"Scalar dereference must have a pointer ctype"
10761075
ret.extend([self.pointer._C_symbol, self.pointee._C_symbol])
10771076
elif self.pointer.is_PointerArray or self.pointer.is_TempFunction:
@@ -1080,7 +1079,9 @@ def expr_symbols(self):
10801079
for i in self.pointee.symbolic_shape[1:]))
10811080
ret.extend(self.pointer.free_symbols)
10821081
else:
1083-
ret.extend([self.pointer.indexed, self.pointee._C_symbol])
1082+
assert (isinstance(self.pointer, LocalCompositeObject) or
1083+
issubclass(self.pointer._C_ctype, ctypes._Pointer))
1084+
ret.extend([self.pointer._C_symbol, self.pointee._C_symbol])
10841085
return tuple(filter_ordered(ret))
10851086

10861087
@property

devito/ir/iet/visitors.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
c_restrict_void_p, sorted_priority)
2626
from devito.types.basic import AbstractFunction, AbstractSymbol, Basic
2727
from devito.types import (ArrayObject, CompositeObject, Dimension, Pointer,
28-
IndexedData, DeviceMap)
28+
IndexedData, DeviceMap, LocalCompositeObject)
2929

3030

3131
__all__ = ['FindApplications', 'FindNodes', 'FindSections', 'FindSymbols',
@@ -201,7 +201,7 @@ def _restrict_keyword(self):
201201

202202
def _gen_struct_decl(self, obj, masked=()):
203203
"""
204-
Convert ctypes.Struct -> cgen.Structure.
204+
Convert ctypes.Struct and LocalCompositeObject -> cgen.Structure.
205205
"""
206206
ctype = obj._C_ctype
207207
try:
@@ -213,7 +213,16 @@ def _gen_struct_decl(self, obj, masked=()):
213213
return None
214214
except TypeError:
215215
# E.g., `ctype` is of type `dtypes_lowering.CustomDtype`
216-
return None
216+
if isinstance(obj, LocalCompositeObject):
217+
# TODO: Potentially re-evaluate: Setting ctype to obj allows
218+
# _gen_struct_decl to generate a cgen.Structure from a
219+
# LocalCompositeObject, where obj._C_ctype is a CustomDtype.
220+
# LocalCompositeObject has a __fields__ property,
221+
# which allows the subsequent code in this function to function
222+
# correctly.
223+
ctype = obj
224+
else:
225+
return None
217226

218227
try:
219228
return obj._C_typedecl
@@ -718,8 +727,11 @@ def _operator_typedecls(self, o, mode='all'):
718727
for i in o._func_table.values():
719728
if not i.local:
720729
continue
721-
typedecls.extend([self._gen_struct_decl(j) for j in i.root.parameters
722-
if xfilter(j)])
730+
typedecls.extend([
731+
self._gen_struct_decl(j)
732+
for j in FindSymbols().visit(i.root)
733+
if xfilter(j)
734+
])
723735
typedecls = filter_sorted(typedecls, key=lambda i: i.tpname)
724736

725737
return typedecls

devito/passes/iet/definitions.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -256,9 +256,10 @@ def _alloc_object_array_on_low_lat_mem(self, site, obj, storage):
256256
"""
257257
Allocate an Array of Objects in the low latency memory.
258258
"""
259+
frees = getattr(obj, '_C_free', None)
259260
decl = Definition(obj)
260261

261-
storage.update(obj, site, allocs=decl)
262+
storage.update(obj, site, allocs=decl, frees=frees)
262263

263264
def _alloc_pointed_array_on_high_bw_mem(self, site, obj, storage):
264265
"""
@@ -335,7 +336,7 @@ def _inject_definitions(self, iet, storage):
335336
frees.extend(as_list(cbody.frees) + flatten(v.frees))
336337
frees = sorted(frees, key=lambda x: min(
337338
(obj._C_free_priority for obj in FindSymbols().visit(x)
338-
if obj.is_LocalObject), default=float('inf')
339+
if obj.is_LocalType), default=float('inf')
339340
))
340341

341342
# maps/unmaps

devito/petsc/clusters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def petsc_lift(clusters):
2020
processed = []
2121
for c in clusters:
2222
if isinstance(c.exprs[0].rhs, LinearSolveExpr):
23-
ispace = c.ispace.lift(c.exprs[0].rhs.target.space_dimensions)
23+
ispace = c.ispace.lift(c.exprs[0].rhs.fielddata.space_dimensions)
2424
processed.append(c.rebuild(ispace=ispace))
2525
else:
2626
processed.append(c)

devito/petsc/iet/nodes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ class PETScCallable(FixedArgsCallable):
1414
pass
1515

1616

17-
class MatVecCallback(Callback):
17+
class MatShellSetOp(Callback):
1818
@property
1919
def callback_form(self):
2020
param_types_str = ', '.join([str(t) for t in self.param_types])

devito/petsc/iet/passes.py

Lines changed: 141 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,27 @@
11
import cgen as c
22
import numpy as np
3+
from functools import cached_property
34

45
from devito.passes.iet.engine import iet_pass
56
from devito.ir.iet import (Transformer, MapNodes, Iteration, BlankLine,
6-
FindNodes, Call, CallableBody)
7-
from devito.symbolics import Byref, Macro
7+
DummyExpr, CallableBody, List, Call, Callable,
8+
FindNodes)
9+
from devito.symbolics import Byref, Macro, FieldFromPointer
10+
from devito.types import Symbol, Scalar
811
from devito.types.basic import DataSymbol
9-
from devito.petsc.types import (PetscMPIInt, PetscErrorCode, Initialize,
10-
Finalize, ArgvSymbol)
12+
from devito.tools import frozendict
13+
from devito.petsc.types import (PetscMPIInt, PetscErrorCode, MultipleFieldData,
14+
PointerIS, Mat, LocalVec, GlobalVec, CallbackMat, SNES,
15+
DummyArg, PetscInt, PointerDM, PointerMat, MatReuse,
16+
CallbackPointerIS, CallbackPointerDM, JacobianStruct,
17+
SubMatrixStruct, Initialize, Finalize, ArgvSymbol)
1118
from devito.petsc.types.macros import petsc_func_begin_user
1219
from devito.petsc.iet.nodes import PetscMetaData
1320
from devito.petsc.utils import core_metadata
14-
from devito.petsc.iet.routines import (CallbackBuilder, BaseObjectBuilder, BaseSetup,
15-
Solver, TimeDependent, NonTimeDependent)
21+
from devito.petsc.iet.routines import (CBBuilder, CCBBuilder, BaseObjectBuilder,
22+
CoupledObjectBuilder, BaseSetup, CoupledSetup,
23+
Solver, CoupledSolver, TimeDependent,
24+
NonTimeDependent)
1625
from devito.petsc.iet.utils import petsc_call, petsc_call_mpi
1726

1827

@@ -26,7 +35,6 @@ def lower_petsc(iet, **kwargs):
2635
return iet, {}
2736

2837
metadata = core_metadata()
29-
3038
data = FindNodes(PetscMetaData).visit(iet)
3139

3240
if any(filter(lambda i: isinstance(i.expr.rhs, Initialize), data)):
@@ -35,10 +43,10 @@ def lower_petsc(iet, **kwargs):
3543
if any(filter(lambda i: isinstance(i.expr.rhs, Finalize), data)):
3644
return finalize(iet), metadata
3745

38-
targets = [i.expr.rhs.target for (i,) in injectsolve_mapper.values()]
39-
40-
# Assumption is that all targets have the same grid so can use any target here
41-
objs = build_core_objects(targets[-1], **kwargs)
46+
unique_grids = {i.expr.rhs.grid for (i,) in injectsolve_mapper.values()}
47+
# Assumption is that all solves are on the same grid
48+
if len(unique_grids) > 1:
49+
raise ValueError("All PETScSolves must use the same Grid, but multiple found.")
4250

4351
# Create core PETSc calls (not specific to each PETScSolve)
4452
core = make_core_petsc_calls(objs, **kwargs)
@@ -54,17 +62,18 @@ def lower_petsc(iet, **kwargs):
5462
setup.extend(builder.solversetup.calls)
5563

5664
# Transform the spatial iteration loop with the calls to execute the solver
57-
subs.update(builder.solve.mapper)
65+
subs.update({builder.solve.spatial_body: builder.solve.calls})
5866

5967
efuncs.update(builder.cbbuilder.efuncs)
6068

69+
populate_matrix_context(efuncs, objs)
70+
6171
iet = Transformer(subs).visit(iet)
6272

6373
body = core + tuple(setup) + (BlankLine,) + iet.body.body
6474
body = iet.body._rebuild(body=body)
6575
iet = iet._rebuild(body=body)
6676
metadata.update({'efuncs': tuple(efuncs.values())})
67-
6877
return iet, metadata
6978

7079

@@ -100,56 +109,140 @@ def make_core_petsc_calls(objs, **kwargs):
100109
return call_mpi, BlankLine
101110

102111

103-
def build_core_objects(target, **kwargs):
104-
communicator = 'PETSC_COMM_WORLD'
105-
106-
return {
107-
'size': PetscMPIInt(name='size'),
108-
'comm': communicator,
109-
'err': PetscErrorCode(name='err'),
110-
'grid': target.grid
111-
}
112-
113-
114112
class Builder:
115113
"""
116114
This class is designed to support future extensions, enabling
117115
different combinations of solver types, preconditioning methods,
118116
and other functionalities as needed.
119-
120117
The class will be extended to accommodate different solver types by
121118
returning subclasses of the objects initialised in __init__,
122119
depending on the properties of `injectsolve`.
123120
"""
124121
def __init__(self, injectsolve, objs, iters, **kwargs):
122+
self.injectsolve = injectsolve
123+
self.objs = objs
124+
self.iters = iters
125+
self.kwargs = kwargs
126+
self.coupled = isinstance(injectsolve.expr.rhs.fielddata, MultipleFieldData)
127+
self.args = {
128+
'injectsolve': self.injectsolve,
129+
'objs': self.objs,
130+
'iters': self.iters,
131+
**self.kwargs
132+
}
133+
self.args['solver_objs'] = self.objbuilder.solver_objs
134+
self.args['timedep'] = self.timedep
135+
self.args['cbbuilder'] = self.cbbuilder
136+
137+
@cached_property
138+
def objbuilder(self):
139+
return (
140+
CoupledObjectBuilder(**self.args)
141+
if self.coupled else
142+
BaseObjectBuilder(**self.args)
143+
)
125144

126-
# Determine the time dependency class
127-
time_mapper = injectsolve.expr.rhs.time_mapper
128-
timedep = TimeDependent if time_mapper else NonTimeDependent
129-
self.timedep = timedep(injectsolve, iters, **kwargs)
145+
@cached_property
146+
def timedep(self):
147+
time_mapper = self.injectsolve.expr.rhs.time_mapper
148+
timedep_class = TimeDependent if time_mapper else NonTimeDependent
149+
return timedep_class(**self.args)
130150

131-
# Objects
132-
self.objbuilder = BaseObjectBuilder(injectsolve, **kwargs)
133-
self.solver_objs = self.objbuilder.solver_objs
151+
@cached_property
152+
def cbbuilder(self):
153+
return CCBBuilder(**self.args) if self.coupled else CBBuilder(**self.args)
134154

135-
# Callbacks
136-
self.cbbuilder = CallbackBuilder(
137-
injectsolve, objs, self.solver_objs, timedep=self.timedep,
138-
**kwargs
139-
)
155+
@cached_property
156+
def solversetup(self):
157+
return CoupledSetup(**self.args) if self.coupled else BaseSetup(**self.args)
140158

141-
# Solver setup
142-
self.solversetup = BaseSetup(
143-
self.solver_objs, objs, injectsolve, self.cbbuilder
144-
)
159+
@cached_property
160+
def solve(self):
161+
return CoupledSolver(**self.args) if self.coupled else Solver(**self.args)
145162

146-
# Execute the solver
147-
self.solve = Solver(
148-
self.solver_objs, objs, injectsolve, iters,
149-
self.cbbuilder, timedep=self.timedep
150-
)
163+
164+
def populate_matrix_context(efuncs, objs):
165+
if not objs['dummyefunc'] in efuncs.values():
166+
return
167+
168+
subdms_expr = DummyExpr(
169+
FieldFromPointer(objs['Subdms']._C_symbol, objs['ljacctx']),
170+
objs['Subdms']._C_symbol
171+
)
172+
fields_expr = DummyExpr(
173+
FieldFromPointer(objs['Fields']._C_symbol, objs['ljacctx']),
174+
objs['Fields']._C_symbol
175+
)
176+
body = CallableBody(
177+
List(body=[subdms_expr, fields_expr]),
178+
init=(objs['begin_user'],),
179+
retstmt=tuple([Call('PetscFunctionReturn', arguments=[0])])
180+
)
181+
name = 'PopulateMatContext'
182+
efuncs[name] = Callable(
183+
name, body, objs['err'],
184+
parameters=[objs['ljacctx'], objs['Subdms'], objs['Fields']]
185+
)
151186

152187

153-
# Move these to types folder
188+
# TODO: Devito MPI + PETSc testing
189+
# if kwargs['options']['mpi'] -> communicator = grid.distributor._obj_comm
190+
communicator = 'PETSC_COMM_WORLD'
191+
subdms = PointerDM(name='subdms')
192+
fields = PointerIS(name='fields')
193+
submats = PointerMat(name='submats')
194+
rows = PointerIS(name='rows')
195+
cols = PointerIS(name='cols')
196+
197+
198+
# A static dict containing shared symbols and objects that are not
199+
# unique to each PETScSolve.
200+
# Many of these objects are used as arguments in callback functions to make
201+
# the C code cleaner and more modular. This is also a step toward leveraging
202+
# Devito's `reuse_efuncs` functionality, allowing reuse of efuncs when
203+
# they are semantically identical.
204+
objs = frozendict({
205+
'size': PetscMPIInt(name='size'),
206+
'comm': communicator,
207+
'err': PetscErrorCode(name='err'),
208+
'block': CallbackMat('block'),
209+
'submat_arr': PointerMat(name='submat_arr'),
210+
'subblockrows': PetscInt('subblockrows'),
211+
'subblockcols': PetscInt('subblockcols'),
212+
'rowidx': PetscInt('rowidx'),
213+
'colidx': PetscInt('colidx'),
214+
'J': Mat('J'),
215+
'X': GlobalVec('X'),
216+
'xloc': LocalVec('xloc'),
217+
'Y': GlobalVec('Y'),
218+
'yloc': LocalVec('yloc'),
219+
'F': GlobalVec('F'),
220+
'floc': LocalVec('floc'),
221+
'B': GlobalVec('B'),
222+
'nfields': PetscInt('nfields'),
223+
'irow': PointerIS(name='irow'),
224+
'icol': PointerIS(name='icol'),
225+
'nsubmats': Scalar('nsubmats', dtype=np.int32),
226+
'matreuse': MatReuse('scall'),
227+
'snes': SNES('snes'),
228+
'rows': rows,
229+
'cols': cols,
230+
'Subdms': subdms,
231+
'LocalSubdms': CallbackPointerDM(name='subdms'),
232+
'Fields': fields,
233+
'LocalFields': CallbackPointerIS(name='fields'),
234+
'Submats': submats,
235+
'ljacctx': JacobianStruct(
236+
fields=[subdms, fields, submats], modifier=' *'
237+
),
238+
'subctx': SubMatrixStruct(fields=[rows, cols]),
239+
'Null': Macro('NULL'),
240+
'dummyctx': Symbol('lctx'),
241+
'dummyptr': DummyArg('dummy'),
242+
'dummyefunc': Symbol('dummyefunc'),
243+
'dof': PetscInt('dof'),
244+
'begin_user': c.Line('PetscFunctionBeginUser;'),
245+
})
246+
247+
# Move to macros file?
154248
Null = Macro('NULL')
155-
void = 'void'

0 commit comments

Comments
 (0)