Skip to content

Commit 4f3bd88

Browse files
authored
Use VecCreateMPIWithArray (#77)
1 parent bdade2c commit 4f3bd88

4 files changed

Lines changed: 131 additions & 38 deletions

File tree

devito/ir/iet/visitors.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,11 +214,11 @@ def _gen_struct_decl(self, obj, masked=()):
214214
except TypeError:
215215
# E.g., `ctype` is of type `dtypes_lowering.CustomDtype`
216216
if isinstance(obj, LocalCompositeObject):
217-
# TODO: Potentially re-evaluate: Setting ctype to obj allows
217+
# TODO: re-evaluate: Setting ctype to obj allows
218218
# _gen_struct_decl to generate a cgen.Structure from a
219219
# LocalCompositeObject, where obj._C_ctype is a CustomDtype.
220220
# LocalCompositeObject has a __fields__ property,
221-
# which allows the subsequent code in this function to function
221+
# which allows the subsequent code in this function to work
222222
# correctly.
223223
ctype = obj
224224
else:

devito/petsc/iet/routines.py

Lines changed: 126 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from collections import OrderedDict
22
from functools import cached_property
3+
import math
34

45
from devito.ir.iet import (Call, FindSymbols, List, Uxreplace, CallableBody,
56
Dereference, DummyExpr, BlankLine, Callable, FindNodes,
@@ -1111,8 +1112,19 @@ def _setup(self):
11111112
global_x = petsc_call('DMCreateGlobalVector',
11121113
[dmda, Byref(sobjs['xglobal'])])
11131114

1114-
local_x = petsc_call('DMCreateLocalVector',
1115-
[dmda, Byref(sobjs['xlocal'])])
1115+
target = self.fielddata.target
1116+
field_from_ptr = FieldFromPointer(
1117+
target.function._C_field_data, target.function._C_symbol
1118+
)
1119+
1120+
local_size = math.prod(
1121+
v for v, dim in zip(target.shape_allocated, target.dimensions) if dim.is_Space
1122+
)
1123+
local_x = petsc_call('VecCreateMPIWithArray',
1124+
['PETSC_COMM_WORLD', 1, local_size, 'PETSC_DECIDE',
1125+
field_from_ptr, Byref(sobjs['xlocal'])])
1126+
1127+
# TODO: potentially also need to set the DM and local/global map to xlocal
11161128

11171129
get_local_size = petsc_call('VecGetSize',
11181130
[sobjs['xlocal'], Byref(sobjs['localsize'])])
@@ -1247,11 +1259,87 @@ class CoupledSetup(BaseSetup):
12471259
def snes_ctx(self):
12481260
return Byref(self.solver_objs['jacctx'])
12491261

1250-
def _extend_setup(self):
1262+
def _setup(self):
1263+
# TODO: minimise code duplication with superclass
12511264
objs = self.objs
12521265
sobjs = self.solver_objs
12531266

12541267
dmda = sobjs['dmda']
1268+
1269+
solver_params = self.injectsolve.expr.rhs.solver_parameters
1270+
1271+
snes_create = petsc_call('SNESCreate', [objs['comm'], Byref(sobjs['snes'])])
1272+
1273+
snes_set_dm = petsc_call('SNESSetDM', [sobjs['snes'], dmda])
1274+
1275+
create_matrix = petsc_call('DMCreateMatrix', [dmda, Byref(sobjs['Jac'])])
1276+
1277+
# NOTE: Assuming all solves are linear for now
1278+
snes_set_type = petsc_call('SNESSetType', [sobjs['snes'], 'SNESKSPONLY'])
1279+
1280+
snes_set_jac = petsc_call(
1281+
'SNESSetJacobian', [sobjs['snes'], sobjs['Jac'],
1282+
sobjs['Jac'], 'MatMFFDComputeJacobian', objs['Null']]
1283+
)
1284+
1285+
global_x = petsc_call('DMCreateGlobalVector',
1286+
[dmda, Byref(sobjs['xglobal'])])
1287+
1288+
local_x = petsc_call('DMCreateLocalVector', [dmda, Byref(sobjs['xlocal'])])
1289+
1290+
get_local_size = petsc_call('VecGetSize',
1291+
[sobjs['xlocal'], Byref(sobjs['localsize'])])
1292+
1293+
global_b = petsc_call('DMCreateGlobalVector',
1294+
[dmda, Byref(sobjs['bglobal'])])
1295+
1296+
snes_get_ksp = petsc_call('SNESGetKSP',
1297+
[sobjs['snes'], Byref(sobjs['ksp'])])
1298+
1299+
ksp_set_tols = petsc_call(
1300+
'KSPSetTolerances', [sobjs['ksp'], solver_params['ksp_rtol'],
1301+
solver_params['ksp_atol'], solver_params['ksp_divtol'],
1302+
solver_params['ksp_max_it']]
1303+
)
1304+
1305+
ksp_set_type = petsc_call(
1306+
'KSPSetType', [sobjs['ksp'], solver_mapper[solver_params['ksp_type']]]
1307+
)
1308+
1309+
ksp_get_pc = petsc_call(
1310+
'KSPGetPC', [sobjs['ksp'], Byref(sobjs['pc'])]
1311+
)
1312+
1313+
# Even though the default will be jacobi, set to PCNONE for now
1314+
pc_set_type = petsc_call('PCSetType', [sobjs['pc'], 'PCNONE'])
1315+
1316+
ksp_set_from_ops = petsc_call('KSPSetFromOptions', [sobjs['ksp']])
1317+
1318+
matvec = self.cbbuilder.main_matvec_callback
1319+
matvec_operation = petsc_call(
1320+
'MatShellSetOperation',
1321+
[sobjs['Jac'], 'MATOP_MULT', MatShellSetOp(matvec.name, void, void)]
1322+
)
1323+
formfunc = self.cbbuilder.main_formfunc_callback
1324+
formfunc_operation = petsc_call(
1325+
'SNESSetFunction',
1326+
[sobjs['snes'], objs['Null'], FormFunctionCallback(formfunc.name, void, void),
1327+
self.snes_ctx]
1328+
)
1329+
1330+
dmda_calls = self._create_dmda_calls(dmda)
1331+
1332+
mainctx = sobjs['userctx']
1333+
1334+
call_struct_callback = petsc_call(
1335+
self.cbbuilder.user_struct_callback.name, [Byref(mainctx)]
1336+
)
1337+
1338+
# TODO: maybe don't need to explictly set this
1339+
mat_set_dm = petsc_call('MatSetDM', [sobjs['Jac'], dmda])
1340+
1341+
calls_set_app_ctx = petsc_call('DMSetApplicationContext', [dmda, Byref(mainctx)])
1342+
12551343
create_field_decomp = petsc_call(
12561344
'DMCreateFieldDecomposition',
12571345
[dmda, Byref(sobjs['nfields']), objs['Null'], Byref(sobjs['fields']),
@@ -1297,13 +1385,34 @@ def _extend_setup(self):
12971385
[sobjs[f'da{t.name}'], Byref(sobjs[f'bglobal{t.name}'])]
12981386
) for t in targets]
12991387

1300-
return (
1388+
coupled_setup = dmda_calls + (
1389+
snes_create,
1390+
snes_set_dm,
1391+
create_matrix,
1392+
snes_set_jac,
1393+
snes_set_type,
1394+
global_x,
1395+
local_x,
1396+
get_local_size,
1397+
global_b,
1398+
snes_get_ksp,
1399+
ksp_set_tols,
1400+
ksp_set_type,
1401+
ksp_get_pc,
1402+
pc_set_type,
1403+
ksp_set_from_ops,
1404+
matvec_operation,
1405+
formfunc_operation,
1406+
call_struct_callback,
1407+
mat_set_dm,
1408+
calls_set_app_ctx,
13011409
create_field_decomp,
13021410
matop_create_submats_op,
13031411
call_coupled_struct_callback,
13041412
shell_set_ctx,
1305-
create_submats
1306-
) + tuple(deref_dms) + tuple(xglobals) + tuple(bglobals)
1413+
create_submats) + \
1414+
tuple(deref_dms) + tuple(xglobals) + tuple(bglobals)
1415+
return coupled_setup
13071416

13081417

13091418
class Solver:
@@ -1333,7 +1442,7 @@ def _execute_solve(self):
13331442

13341443
rhs_call = petsc_call(rhs_callback.name, [sobjs['dmda'], sobjs['bglobal']])
13351444

1336-
vec_replace_array = self.timedep.replace_array(target)
1445+
vec_place_array = self.timedep.place_array(target)
13371446

13381447
if self.cbbuilder.initialguesses:
13391448
initguess = self.cbbuilder.initialguesses[0]
@@ -1358,7 +1467,7 @@ def _execute_solve(self):
13581467

13591468
run_solver_calls = (struct_assignment,) + (
13601469
rhs_call,
1361-
) + vec_replace_array + (
1470+
) + vec_place_array + (
13621471
initguess_call,
13631472
dm_local_to_global_x,
13641473
snes_solve,
@@ -1415,7 +1524,7 @@ def _execute_solve(self):
14151524
pre_solve += (
14161525
petsc_call(c.name, [dm, target_bglob]),
14171526
petsc_call('DMCreateLocalVector', [dm, Byref(target_xloc)]),
1418-
self.timedep.replace_array(t),
1527+
self.timedep.place_array(t),
14191528
petsc_call(
14201529
'DMLocalToGlobal',
14211530
[dm, target_xloc, insert_vals, target_xglob]
@@ -1485,23 +1594,7 @@ def _origin_to_moddim_mapper(self, iters):
14851594
def uxreplace_time(self, body):
14861595
return body
14871596

1488-
def replace_array(self, target):
1489-
"""
1490-
VecReplaceArray() is a PETSc function that allows replacing the array
1491-
of a `Vec` with a user provided array.
1492-
https://petsc.org/release/manualpages/Vec/VecReplaceArray/
1493-
1494-
This function is used to replace the array of the PETSc solution `Vec`
1495-
with the array from the `Function` object representing the target.
1496-
1497-
Examples
1498-
--------
1499-
>>> target
1500-
f1(x, y)
1501-
>>> call = replace_array(target)
1502-
>>> print(call)
1503-
PetscCall(VecReplaceArray(xlocal0,f1_vec->data));
1504-
"""
1597+
def place_array(self, target):
15051598
sobjs = self.sobjs
15061599

15071600
field_from_ptr = FieldFromPointer(
@@ -1608,29 +1701,27 @@ def _origin_to_moddim_mapper(self, iters):
16081701
mapper[d] = d
16091702
return mapper
16101703

1611-
def replace_array(self, target):
1704+
def place_array(self, target):
16121705
"""
16131706
In the case that the actual target is time-dependent e.g a `TimeFunction`,
16141707
a pointer to the first element in the array that will be updated during
1615-
the time step is passed to VecReplaceArray().
1708+
the time step is passed to VecPlaceArray().
16161709
16171710
Examples
16181711
--------
16191712
>>> target
16201713
f1(time + dt, x, y)
1621-
>>> calls = replace_array(target)
1714+
>>> calls = place_array(target)
16221715
>>> print(List(body=calls))
1623-
PetscCall(VecGetSize(xlocal0,&(localsize0)));
16241716
float * f1_ptr0 = (time + 1)*localsize0 + (float*)(f1_vec->data);
1625-
PetscCall(VecReplaceArray(xlocal0,f1_ptr0));
1717+
PetscCall(VecPlaceArray(xlocal0,f1_ptr0));
16261718
16271719
>>> target
16281720
f1(t + dt, x, y)
1629-
>>> calls = replace_array(target)
1721+
>>> calls = place_array(target)
16301722
>>> print(List(body=calls))
1631-
PetscCall(VecGetSize(xlocal0,&(localsize0)));
16321723
float * f1_ptr0 = t1*localsize0 + (float*)(f1_vec->data);
1633-
PetscCall(VecReplaceArray(xlocal0,f1_ptr0));
1724+
PetscCall(VecPlaceArray(xlocal0,f1_ptr0));
16341725
"""
16351726
sobjs = self.sobjs
16361727

@@ -1654,7 +1745,7 @@ def replace_array(self, target):
16541745
),
16551746
petsc_call('VecPlaceArray', [xlocal, start_ptr])
16561747
)
1657-
return super().replace_array(target)
1748+
return super().place_array(target)
16581749

16591750
def assign_time_iters(self, struct):
16601751
"""

devito/petsc/types/object.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def dofs(self):
3636
def _C_free(self):
3737
return petsc_call('DMDestroy', [Byref(self.function)])
3838

39-
# TODO: This is growing out of hand so switch to an enumeration or something?
39+
# TODO: Switch to an enumeration?
4040
@property
4141
def _C_free_priority(self):
4242
return 4

tests/test_symbolics.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,8 @@ def test_extended_sympy_arithmetic():
319319
o = Object(name='o', dtype=c_void_p)
320320
bar = FieldFromPointer('bar', o)
321321
# TODO: Edit/fix/update according to PR #2513
322+
# The order changed due to adding the dtype property
323+
# to FieldFromPointer
322324
assert ccode(-1 + bar) == 'o->bar - 1'
323325

324326

0 commit comments

Comments
 (0)