1616 MatShellSetOp , PetscMetaData )
1717from devito .petsc .iet .utils import petsc_call , petsc_struct
1818from devito .petsc .utils import solver_mapper
19- from devito .petsc .types import (DM , Mat , LocalVec , GlobalVec , KSP , PC , SNES ,
19+ from devito .petsc .types import (DM , Mat , CallbackVec , Vec , KSP , PC , SNES ,
2020 PetscInt , StartPtr , PointerIS , PointerDM , VecScatter ,
2121 DMCast , JacobianStructCast , JacobianStruct ,
2222 SubMatrixStruct , CallbackDM )
@@ -448,8 +448,13 @@ def _create_form_rhs_body(self, body, fielddata):
448448 'VecRestoreArray' , [sobjs ['blocal' ], Byref (b_arr ._C_symbol )]
449449 )
450450
451+ dm_restore_local_bvec = petsc_call (
452+ 'DMRestoreLocalVector' , [dmda , Byref (sobjs ['blocal' ])]
453+ )
454+
451455 body = body ._rebuild (body = body .body + (
452- dm_local_to_global_begin , dm_local_to_global_end , vec_restore_array
456+ dm_local_to_global_begin , dm_local_to_global_end , vec_restore_array ,
457+ dm_restore_local_bvec
453458 ))
454459
455460 stacks = (
@@ -870,10 +875,10 @@ def _build(self):
870875 targets = self .fielddata .targets
871876 base_dict = {
872877 'Jac' : Mat (sreg .make_name (prefix = 'J' )),
873- 'xglobal' : GlobalVec (sreg .make_name (prefix = 'xglobal' )),
874- 'xlocal' : LocalVec (sreg .make_name (prefix = 'xlocal' )),
875- 'bglobal' : GlobalVec (sreg .make_name (prefix = 'bglobal' )),
876- 'blocal' : LocalVec (sreg .make_name (prefix = 'blocal' )),
878+ 'xglobal' : Vec (sreg .make_name (prefix = 'xglobal' )),
879+ 'xlocal' : Vec (sreg .make_name (prefix = 'xlocal' )),
880+ 'bglobal' : Vec (sreg .make_name (prefix = 'bglobal' )),
881+ 'blocal' : CallbackVec (sreg .make_name (prefix = 'blocal' )),
877882 'ksp' : KSP (sreg .make_name (prefix = 'ksp' )),
878883 'pc' : PC (sreg .make_name (prefix = 'pc' )),
879884 'snes' : SNES (sreg .make_name (prefix = 'snes' )),
@@ -939,9 +944,9 @@ def _extend_build(self, base_dict):
939944 name = f'{ key } ctx' ,
940945 fields = objs ['subctx' ].fields ,
941946 )
942- base_dict [f'{ key } X' ] = LocalVec (f'{ key } X' )
943- base_dict [f'{ key } Y' ] = LocalVec (f'{ key } Y' )
944- base_dict [f'{ key } F' ] = LocalVec (f'{ key } F' )
947+ base_dict [f'{ key } X' ] = CallbackVec (f'{ key } X' )
948+ base_dict [f'{ key } Y' ] = CallbackVec (f'{ key } Y' )
949+ base_dict [f'{ key } F' ] = CallbackVec (f'{ key } F' )
945950
946951 return base_dict
947952
@@ -953,22 +958,22 @@ def _target_dependent(self, base_dict):
953958 base_dict [f'{ name } _ptr' ] = StartPtr (
954959 sreg .make_name (prefix = f'{ name } _ptr' ), t .dtype
955960 )
956- base_dict [f'xlocal{ name } ' ] = LocalVec (
961+ base_dict [f'xlocal{ name } ' ] = CallbackVec (
957962 sreg .make_name (prefix = f'xlocal{ name } ' ), liveness = 'eager'
958963 )
959- base_dict [f'Fglobal{ name } ' ] = LocalVec (
964+ base_dict [f'Fglobal{ name } ' ] = CallbackVec (
960965 sreg .make_name (prefix = f'Fglobal{ name } ' ), liveness = 'eager'
961966 )
962- base_dict [f'Xglobal{ name } ' ] = LocalVec (
967+ base_dict [f'Xglobal{ name } ' ] = CallbackVec (
963968 sreg .make_name (prefix = f'Xglobal{ name } ' )
964969 )
965- base_dict [f'xglobal{ name } ' ] = GlobalVec (
970+ base_dict [f'xglobal{ name } ' ] = Vec (
966971 sreg .make_name (prefix = f'xglobal{ name } ' )
967972 )
968- base_dict [f'blocal{ name } ' ] = LocalVec (
973+ base_dict [f'blocal{ name } ' ] = CallbackVec (
969974 sreg .make_name (prefix = f'blocal{ name } ' ), liveness = 'eager'
970975 )
971- base_dict [f'bglobal{ name } ' ] = GlobalVec (
976+ base_dict [f'bglobal{ name } ' ] = Vec (
972977 sreg .make_name (prefix = f'bglobal{ name } ' )
973978 )
974979 base_dict [f'da{ name } ' ] = DM (
@@ -1021,6 +1026,12 @@ def _setup(self):
10211026 global_x = petsc_call ('DMCreateGlobalVector' ,
10221027 [dmda , Byref (sobjs ['xglobal' ])])
10231028
1029+ local_x = petsc_call ('DMCreateLocalVector' ,
1030+ [dmda , Byref (sobjs ['xlocal' ])])
1031+
1032+ get_local_size = petsc_call ('VecGetSize' ,
1033+ [sobjs ['xlocal' ], Byref (sobjs ['localsize' ])])
1034+
10241035 global_b = petsc_call ('DMCreateGlobalVector' ,
10251036 [dmda , Byref (sobjs ['bglobal' ])])
10261037
@@ -1078,6 +1089,8 @@ def _setup(self):
10781089 snes_set_jac ,
10791090 snes_set_type ,
10801091 global_x ,
1092+ local_x ,
1093+ get_local_size ,
10811094 global_b ,
10821095 snes_get_ksp ,
10831096 ksp_set_tols ,
@@ -1235,9 +1248,6 @@ def _execute_solve(self):
12351248
12361249 rhs_call = petsc_call (rhs_callback .name , [sobjs ['dmda' ], sobjs ['bglobal' ]])
12371250
1238- local_x = petsc_call ('DMCreateLocalVector' ,
1239- [dmda , Byref (sobjs ['xlocal' ])])
1240-
12411251 vec_replace_array = self .timedep .replace_array (target )
12421252
12431253 dm_local_to_global_x = petsc_call (
@@ -1253,13 +1263,15 @@ def _execute_solve(self):
12531263 dmda , sobjs ['xglobal' ], insert_vals , sobjs ['xlocal' ]]
12541264 )
12551265
1266+ vec_reset_array = self .timedep .reset_array (target )
1267+
12561268 run_solver_calls = (struct_assignment ,) + (
12571269 rhs_call ,
1258- local_x
12591270 ) + vec_replace_array + (
12601271 dm_local_to_global_x ,
12611272 snes_solve ,
12621273 dm_global_to_local_x ,
1274+ vec_reset_array ,
12631275 BlankLine ,
12641276 )
12651277 return List (body = run_solver_calls )
@@ -1402,7 +1414,16 @@ def replace_array(self, target):
14021414 target .function ._C_field_data , target .function ._C_symbol
14031415 )
14041416 xlocal = sobjs .get (f'xlocal{ target .name } ' , sobjs ['xlocal' ])
1405- return (petsc_call ('VecReplaceArray' , [xlocal , field_from_ptr ]),)
1417+ return (petsc_call ('VecPlaceArray' , [xlocal , field_from_ptr ]),)
1418+
1419+ def reset_array (self , target ):
1420+ """
1421+ """
1422+ sobjs = self .sobjs
1423+ xlocal = sobjs .get (f'xlocal{ target .name } ' , sobjs ['xlocal' ])
1424+ return (
1425+ petsc_call ('VecResetArray' , [xlocal ])
1426+ )
14061427
14071428 def assign_time_iters (self , struct ):
14081429 return []
@@ -1530,15 +1551,14 @@ def replace_array(self, target):
15301551
15311552 caster = cast (target .dtype , '*' )
15321553 return (
1533- petsc_call ('VecGetSize' , [xlocal , Byref (sobjs ['localsize' ])]),
15341554 DummyExpr (
15351555 start_ptr ,
15361556 caster (
15371557 FieldFromPointer (target ._C_field_data , target ._C_symbol )
15381558 ) + Mul (target_time , sobjs ['localsize' ]),
15391559 init = True
15401560 ),
1541- petsc_call ('VecReplaceArray ' , [xlocal , start_ptr ])
1561+ petsc_call ('VecPlaceArray ' , [xlocal , start_ptr ])
15421562 )
15431563 return super ().replace_array (target )
15441564
0 commit comments