@@ -47,6 +47,7 @@ def __init__(self, **kwargs):
4747 self ._matvecs = []
4848 self ._formfuncs = []
4949 self ._formrhs = []
50+ self ._initialguesses = []
5051
5152 self ._make_core ()
5253 self ._efuncs = self ._uxreplace_efuncs ()
@@ -88,6 +89,10 @@ def formfuncs(self):
8889 def formrhs (self ):
8990 return self ._formrhs
9091
92+ @property
93+ def initialguesses (self ):
94+ return self ._initialguesses
95+
9196 @property
9297 def user_struct_callback (self ):
9398 return self ._user_struct_callback
@@ -97,6 +102,8 @@ def _make_core(self):
97102 self ._make_matvec (fielddata , fielddata .matvecs )
98103 self ._make_formfunc (fielddata )
99104 self ._make_formrhs (fielddata )
105+ if fielddata .initialguess :
106+ self ._make_initialguess (fielddata )
100107 self ._make_user_struct_callback ()
101108
102109 def _make_matvec (self , fielddata , matvecs , prefix = 'MatMult' ):
@@ -483,6 +490,84 @@ def _create_form_rhs_body(self, body, fielddata):
483490
484491 return Uxreplace (subs ).visit (formrhs_body )
485492
493+ def _make_initialguess (self , fielddata ):
494+ initguess = fielddata .initialguess
495+ sobjs = self .solver_objs
496+
497+ # Compile initital guess `eqns` into an IET via recursive compilation
498+ irs , _ = self .rcompile (
499+ initguess , options = {'mpi' : False }, sregistry = self .sregistry ,
500+ concretize_mapper = self .concretize_mapper
501+ )
502+ body_init_guess = self ._create_initial_guess_body (
503+ List (body = irs .uiet .body ), fielddata
504+ )
505+ objs = self .objs
506+ cb = PETScCallable (
507+ self .sregistry .make_name (prefix = 'FormInitialGuess' ),
508+ body_init_guess ,
509+ retval = objs ['err' ],
510+ parameters = (sobjs ['callbackdm' ], objs ['xloc' ])
511+ )
512+ self ._initialguesses .append (cb )
513+ self ._efuncs [cb .name ] = cb
514+
515+ def _create_initial_guess_body (self , body , fielddata ):
516+ linsolve_expr = self .injectsolve .expr .rhs
517+ objs = self .objs
518+ sobjs = self .solver_objs
519+
520+ dmda = sobjs ['callbackdm' ]
521+ ctx = objs ['dummyctx' ]
522+
523+ x_arr = fielddata .arrays ['x' ]
524+
525+ vec_get_array = petsc_call (
526+ 'VecGetArray' , [objs ['xloc' ], Byref (x_arr ._C_symbol )]
527+ )
528+
529+ dm_get_local_info = petsc_call (
530+ 'DMDAGetLocalInfo' , [dmda , Byref (linsolve_expr .localinfo )]
531+ )
532+
533+ body = self .timedep .uxreplace_time (body )
534+
535+ fields = self ._dummy_fields (body )
536+ self ._struct_params .extend (fields )
537+
538+ dm_get_app_context = petsc_call (
539+ 'DMGetApplicationContext' , [dmda , Byref (ctx ._C_symbol )]
540+ )
541+
542+ vec_restore_array = petsc_call (
543+ 'VecRestoreArray' , [objs ['xloc' ], Byref (x_arr ._C_symbol )]
544+ )
545+
546+ body = body ._rebuild (body = body .body + (vec_restore_array ,))
547+
548+ stacks = (
549+ vec_get_array ,
550+ dm_get_app_context ,
551+ dm_get_local_info
552+ )
553+
554+ # Dereference function data in struct
555+ dereference_funcs = [Dereference (i , ctx ) for i in
556+ fields if isinstance (i .function , AbstractFunction )]
557+
558+ body = CallableBody (
559+ List (body = [body ]),
560+ init = (objs ['begin_user' ],),
561+ stacks = stacks + tuple (dereference_funcs ),
562+ retstmt = (Call ('PetscFunctionReturn' , arguments = [0 ]),)
563+ )
564+
565+ # Replace non-function data with pointer to data in struct
566+ subs = {i ._C_symbol : FieldFromPointer (i ._C_symbol , ctx ) for
567+ i in fields if not isinstance (i .function , AbstractFunction )}
568+
569+ return Uxreplace (subs ).visit (body )
570+
486571 def _make_user_struct_callback (self ):
487572 """
488573 This is the struct initialised inside the main kernel and
@@ -1250,6 +1335,12 @@ def _execute_solve(self):
12501335
12511336 vec_replace_array = self .timedep .replace_array (target )
12521337
1338+ if self .cbbuilder .initialguesses :
1339+ initguess = self .cbbuilder .initialguesses [0 ]
1340+ initguess_call = petsc_call (initguess .name , [dmda , sobjs ['xlocal' ]])
1341+ else :
1342+ initguess_call = None
1343+
12531344 dm_local_to_global_x = petsc_call (
12541345 'DMLocalToGlobal' , [dmda , sobjs ['xlocal' ], insert_vals ,
12551346 sobjs ['xglobal' ]]
@@ -1268,6 +1359,7 @@ def _execute_solve(self):
12681359 run_solver_calls = (struct_assignment ,) + (
12691360 rhs_call ,
12701361 ) + vec_replace_array + (
1362+ initguess_call ,
12711363 dm_local_to_global_x ,
12721364 snes_solve ,
12731365 dm_global_to_local_x ,
0 commit comments