11from collections import OrderedDict
22from functools import cached_property
3+ import math
34
45from devito .ir .iet import (Call , FindSymbols , List , Uxreplace , CallableBody ,
56 Dereference , DummyExpr , BlankLine , Callable , FindNodes ,
@@ -1111,8 +1112,19 @@ def _setup(self):
11111112 global_x = petsc_call ('DMCreateGlobalVector' ,
11121113 [dmda , Byref (sobjs ['xglobal' ])])
11131114
1114- local_x = petsc_call ('DMCreateLocalVector' ,
1115- [dmda , Byref (sobjs ['xlocal' ])])
1115+ target = self .fielddata .target
1116+ field_from_ptr = FieldFromPointer (
1117+ target .function ._C_field_data , target .function ._C_symbol
1118+ )
1119+
1120+ local_size = math .prod (
1121+ v for v , dim in zip (target .shape_allocated , target .dimensions ) if dim .is_Space
1122+ )
1123+ local_x = petsc_call ('VecCreateMPIWithArray' ,
1124+ ['PETSC_COMM_WORLD' , 1 , local_size , 'PETSC_DECIDE' ,
1125+ field_from_ptr , Byref (sobjs ['xlocal' ])])
1126+
1127+ # TODO: potentially also need to set the DM and local/global map to xlocal
11161128
11171129 get_local_size = petsc_call ('VecGetSize' ,
11181130 [sobjs ['xlocal' ], Byref (sobjs ['localsize' ])])
@@ -1247,11 +1259,87 @@ class CoupledSetup(BaseSetup):
12471259 def snes_ctx (self ):
12481260 return Byref (self .solver_objs ['jacctx' ])
12491261
1250- def _extend_setup (self ):
1262+ def _setup (self ):
1263+ # TODO: minimise code duplication with superclass
12511264 objs = self .objs
12521265 sobjs = self .solver_objs
12531266
12541267 dmda = sobjs ['dmda' ]
1268+
1269+ solver_params = self .injectsolve .expr .rhs .solver_parameters
1270+
1271+ snes_create = petsc_call ('SNESCreate' , [objs ['comm' ], Byref (sobjs ['snes' ])])
1272+
1273+ snes_set_dm = petsc_call ('SNESSetDM' , [sobjs ['snes' ], dmda ])
1274+
1275+ create_matrix = petsc_call ('DMCreateMatrix' , [dmda , Byref (sobjs ['Jac' ])])
1276+
1277+ # NOTE: Assuming all solves are linear for now
1278+ snes_set_type = petsc_call ('SNESSetType' , [sobjs ['snes' ], 'SNESKSPONLY' ])
1279+
1280+ snes_set_jac = petsc_call (
1281+ 'SNESSetJacobian' , [sobjs ['snes' ], sobjs ['Jac' ],
1282+ sobjs ['Jac' ], 'MatMFFDComputeJacobian' , objs ['Null' ]]
1283+ )
1284+
1285+ global_x = petsc_call ('DMCreateGlobalVector' ,
1286+ [dmda , Byref (sobjs ['xglobal' ])])
1287+
1288+ local_x = petsc_call ('DMCreateLocalVector' , [dmda , Byref (sobjs ['xlocal' ])])
1289+
1290+ get_local_size = petsc_call ('VecGetSize' ,
1291+ [sobjs ['xlocal' ], Byref (sobjs ['localsize' ])])
1292+
1293+ global_b = petsc_call ('DMCreateGlobalVector' ,
1294+ [dmda , Byref (sobjs ['bglobal' ])])
1295+
1296+ snes_get_ksp = petsc_call ('SNESGetKSP' ,
1297+ [sobjs ['snes' ], Byref (sobjs ['ksp' ])])
1298+
1299+ ksp_set_tols = petsc_call (
1300+ 'KSPSetTolerances' , [sobjs ['ksp' ], solver_params ['ksp_rtol' ],
1301+ solver_params ['ksp_atol' ], solver_params ['ksp_divtol' ],
1302+ solver_params ['ksp_max_it' ]]
1303+ )
1304+
1305+ ksp_set_type = petsc_call (
1306+ 'KSPSetType' , [sobjs ['ksp' ], solver_mapper [solver_params ['ksp_type' ]]]
1307+ )
1308+
1309+ ksp_get_pc = petsc_call (
1310+ 'KSPGetPC' , [sobjs ['ksp' ], Byref (sobjs ['pc' ])]
1311+ )
1312+
1313+ # Even though the default will be jacobi, set to PCNONE for now
1314+ pc_set_type = petsc_call ('PCSetType' , [sobjs ['pc' ], 'PCNONE' ])
1315+
1316+ ksp_set_from_ops = petsc_call ('KSPSetFromOptions' , [sobjs ['ksp' ]])
1317+
1318+ matvec = self .cbbuilder .main_matvec_callback
1319+ matvec_operation = petsc_call (
1320+ 'MatShellSetOperation' ,
1321+ [sobjs ['Jac' ], 'MATOP_MULT' , MatShellSetOp (matvec .name , void , void )]
1322+ )
1323+ formfunc = self .cbbuilder .main_formfunc_callback
1324+ formfunc_operation = petsc_call (
1325+ 'SNESSetFunction' ,
1326+ [sobjs ['snes' ], objs ['Null' ], FormFunctionCallback (formfunc .name , void , void ),
1327+ self .snes_ctx ]
1328+ )
1329+
1330+ dmda_calls = self ._create_dmda_calls (dmda )
1331+
1332+ mainctx = sobjs ['userctx' ]
1333+
1334+ call_struct_callback = petsc_call (
1335+ self .cbbuilder .user_struct_callback .name , [Byref (mainctx )]
1336+ )
1337+
1338+ # TODO: maybe don't need to explictly set this
1339+ mat_set_dm = petsc_call ('MatSetDM' , [sobjs ['Jac' ], dmda ])
1340+
1341+ calls_set_app_ctx = petsc_call ('DMSetApplicationContext' , [dmda , Byref (mainctx )])
1342+
12551343 create_field_decomp = petsc_call (
12561344 'DMCreateFieldDecomposition' ,
12571345 [dmda , Byref (sobjs ['nfields' ]), objs ['Null' ], Byref (sobjs ['fields' ]),
@@ -1297,13 +1385,34 @@ def _extend_setup(self):
12971385 [sobjs [f'da{ t .name } ' ], Byref (sobjs [f'bglobal{ t .name } ' ])]
12981386 ) for t in targets ]
12991387
1300- return (
1388+ coupled_setup = dmda_calls + (
1389+ snes_create ,
1390+ snes_set_dm ,
1391+ create_matrix ,
1392+ snes_set_jac ,
1393+ snes_set_type ,
1394+ global_x ,
1395+ local_x ,
1396+ get_local_size ,
1397+ global_b ,
1398+ snes_get_ksp ,
1399+ ksp_set_tols ,
1400+ ksp_set_type ,
1401+ ksp_get_pc ,
1402+ pc_set_type ,
1403+ ksp_set_from_ops ,
1404+ matvec_operation ,
1405+ formfunc_operation ,
1406+ call_struct_callback ,
1407+ mat_set_dm ,
1408+ calls_set_app_ctx ,
13011409 create_field_decomp ,
13021410 matop_create_submats_op ,
13031411 call_coupled_struct_callback ,
13041412 shell_set_ctx ,
1305- create_submats
1306- ) + tuple (deref_dms ) + tuple (xglobals ) + tuple (bglobals )
1413+ create_submats ) + \
1414+ tuple (deref_dms ) + tuple (xglobals ) + tuple (bglobals )
1415+ return coupled_setup
13071416
13081417
13091418class Solver :
@@ -1333,7 +1442,7 @@ def _execute_solve(self):
13331442
13341443 rhs_call = petsc_call (rhs_callback .name , [sobjs ['dmda' ], sobjs ['bglobal' ]])
13351444
1336- vec_replace_array = self .timedep .replace_array (target )
1445+ vec_place_array = self .timedep .place_array (target )
13371446
13381447 if self .cbbuilder .initialguesses :
13391448 initguess = self .cbbuilder .initialguesses [0 ]
@@ -1358,7 +1467,7 @@ def _execute_solve(self):
13581467
13591468 run_solver_calls = (struct_assignment ,) + (
13601469 rhs_call ,
1361- ) + vec_replace_array + (
1470+ ) + vec_place_array + (
13621471 initguess_call ,
13631472 dm_local_to_global_x ,
13641473 snes_solve ,
@@ -1415,7 +1524,7 @@ def _execute_solve(self):
14151524 pre_solve += (
14161525 petsc_call (c .name , [dm , target_bglob ]),
14171526 petsc_call ('DMCreateLocalVector' , [dm , Byref (target_xloc )]),
1418- self .timedep .replace_array (t ),
1527+ self .timedep .place_array (t ),
14191528 petsc_call (
14201529 'DMLocalToGlobal' ,
14211530 [dm , target_xloc , insert_vals , target_xglob ]
@@ -1485,23 +1594,7 @@ def _origin_to_moddim_mapper(self, iters):
14851594 def uxreplace_time (self , body ):
14861595 return body
14871596
1488- def replace_array (self , target ):
1489- """
1490- VecReplaceArray() is a PETSc function that allows replacing the array
1491- of a `Vec` with a user provided array.
1492- https://petsc.org/release/manualpages/Vec/VecReplaceArray/
1493-
1494- This function is used to replace the array of the PETSc solution `Vec`
1495- with the array from the `Function` object representing the target.
1496-
1497- Examples
1498- --------
1499- >>> target
1500- f1(x, y)
1501- >>> call = replace_array(target)
1502- >>> print(call)
1503- PetscCall(VecReplaceArray(xlocal0,f1_vec->data));
1504- """
1597+ def place_array (self , target ):
15051598 sobjs = self .sobjs
15061599
15071600 field_from_ptr = FieldFromPointer (
@@ -1608,29 +1701,27 @@ def _origin_to_moddim_mapper(self, iters):
16081701 mapper [d ] = d
16091702 return mapper
16101703
1611- def replace_array (self , target ):
1704+ def place_array (self , target ):
16121705 """
16131706 In the case that the actual target is time-dependent e.g a `TimeFunction`,
16141707 a pointer to the first element in the array that will be updated during
1615- the time step is passed to VecReplaceArray ().
1708+ the time step is passed to VecPlaceArray ().
16161709
16171710 Examples
16181711 --------
16191712 >>> target
16201713 f1(time + dt, x, y)
1621- >>> calls = replace_array (target)
1714+ >>> calls = place_array (target)
16221715 >>> print(List(body=calls))
1623- PetscCall(VecGetSize(xlocal0,&(localsize0)));
16241716 float * f1_ptr0 = (time + 1)*localsize0 + (float*)(f1_vec->data);
1625- PetscCall(VecReplaceArray (xlocal0,f1_ptr0));
1717+ PetscCall(VecPlaceArray (xlocal0,f1_ptr0));
16261718
16271719 >>> target
16281720 f1(t + dt, x, y)
1629- >>> calls = replace_array (target)
1721+ >>> calls = place_array (target)
16301722 >>> print(List(body=calls))
1631- PetscCall(VecGetSize(xlocal0,&(localsize0)));
16321723 float * f1_ptr0 = t1*localsize0 + (float*)(f1_vec->data);
1633- PetscCall(VecReplaceArray (xlocal0,f1_ptr0));
1724+ PetscCall(VecPlaceArray (xlocal0,f1_ptr0));
16341725 """
16351726 sobjs = self .sobjs
16361727
@@ -1654,7 +1745,7 @@ def replace_array(self, target):
16541745 ),
16551746 petsc_call ('VecPlaceArray' , [xlocal , start_ptr ])
16561747 )
1657- return super ().replace_array (target )
1748+ return super ().place_array (target )
16581749
16591750 def assign_time_iters (self , struct ):
16601751 """
0 commit comments