Skip to content

Commit fb4b96f

Browse files
authored
Compiler: Add basic initguess callback (#72)
* Compiler: Add basic initguess callback - to be improved
1 parent 735da03 commit fb4b96f

5 files changed

Lines changed: 195 additions & 5 deletions

File tree

devito/petsc/iet/routines.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def __init__(self, **kwargs):
4747
self._matvecs = []
4848
self._formfuncs = []
4949
self._formrhs = []
50+
self._initialguesses = []
5051

5152
self._make_core()
5253
self._efuncs = self._uxreplace_efuncs()
@@ -88,6 +89,10 @@ def formfuncs(self):
8889
def formrhs(self):
8990
return self._formrhs
9091

92+
@property
93+
def initialguesses(self):
94+
return self._initialguesses
95+
9196
@property
9297
def user_struct_callback(self):
9398
return self._user_struct_callback
@@ -97,6 +102,8 @@ def _make_core(self):
97102
self._make_matvec(fielddata, fielddata.matvecs)
98103
self._make_formfunc(fielddata)
99104
self._make_formrhs(fielddata)
105+
if fielddata.initialguess:
106+
self._make_initialguess(fielddata)
100107
self._make_user_struct_callback()
101108

102109
def _make_matvec(self, fielddata, matvecs, prefix='MatMult'):
@@ -483,6 +490,84 @@ def _create_form_rhs_body(self, body, fielddata):
483490

484491
return Uxreplace(subs).visit(formrhs_body)
485492

493+
def _make_initialguess(self, fielddata):
494+
initguess = fielddata.initialguess
495+
sobjs = self.solver_objs
496+
497+
# Compile initital guess `eqns` into an IET via recursive compilation
498+
irs, _ = self.rcompile(
499+
initguess, options={'mpi': False}, sregistry=self.sregistry,
500+
concretize_mapper=self.concretize_mapper
501+
)
502+
body_init_guess = self._create_initial_guess_body(
503+
List(body=irs.uiet.body), fielddata
504+
)
505+
objs = self.objs
506+
cb = PETScCallable(
507+
self.sregistry.make_name(prefix='FormInitialGuess'),
508+
body_init_guess,
509+
retval=objs['err'],
510+
parameters=(sobjs['callbackdm'], objs['xloc'])
511+
)
512+
self._initialguesses.append(cb)
513+
self._efuncs[cb.name] = cb
514+
515+
def _create_initial_guess_body(self, body, fielddata):
516+
linsolve_expr = self.injectsolve.expr.rhs
517+
objs = self.objs
518+
sobjs = self.solver_objs
519+
520+
dmda = sobjs['callbackdm']
521+
ctx = objs['dummyctx']
522+
523+
x_arr = fielddata.arrays['x']
524+
525+
vec_get_array = petsc_call(
526+
'VecGetArray', [objs['xloc'], Byref(x_arr._C_symbol)]
527+
)
528+
529+
dm_get_local_info = petsc_call(
530+
'DMDAGetLocalInfo', [dmda, Byref(linsolve_expr.localinfo)]
531+
)
532+
533+
body = self.timedep.uxreplace_time(body)
534+
535+
fields = self._dummy_fields(body)
536+
self._struct_params.extend(fields)
537+
538+
dm_get_app_context = petsc_call(
539+
'DMGetApplicationContext', [dmda, Byref(ctx._C_symbol)]
540+
)
541+
542+
vec_restore_array = petsc_call(
543+
'VecRestoreArray', [objs['xloc'], Byref(x_arr._C_symbol)]
544+
)
545+
546+
body = body._rebuild(body=body.body + (vec_restore_array,))
547+
548+
stacks = (
549+
vec_get_array,
550+
dm_get_app_context,
551+
dm_get_local_info
552+
)
553+
554+
# Dereference function data in struct
555+
dereference_funcs = [Dereference(i, ctx) for i in
556+
fields if isinstance(i.function, AbstractFunction)]
557+
558+
body = CallableBody(
559+
List(body=[body]),
560+
init=(objs['begin_user'],),
561+
stacks=stacks+tuple(dereference_funcs),
562+
retstmt=(Call('PetscFunctionReturn', arguments=[0]),)
563+
)
564+
565+
# Replace non-function data with pointer to data in struct
566+
subs = {i._C_symbol: FieldFromPointer(i._C_symbol, ctx) for
567+
i in fields if not isinstance(i.function, AbstractFunction)}
568+
569+
return Uxreplace(subs).visit(body)
570+
486571
def _make_user_struct_callback(self):
487572
"""
488573
This is the struct initialised inside the main kernel and
@@ -1250,6 +1335,12 @@ def _execute_solve(self):
12501335

12511336
vec_replace_array = self.timedep.replace_array(target)
12521337

1338+
if self.cbbuilder.initialguesses:
1339+
initguess = self.cbbuilder.initialguesses[0]
1340+
initguess_call = petsc_call(initguess.name, [dmda, sobjs['xlocal']])
1341+
else:
1342+
initguess_call = None
1343+
12531344
dm_local_to_global_x = petsc_call(
12541345
'DMLocalToGlobal', [dmda, sobjs['xlocal'], insert_vals,
12551346
sobjs['xglobal']]
@@ -1268,6 +1359,7 @@ def _execute_solve(self):
12681359
run_solver_calls = (struct_assignment,) + (
12691360
rhs_call,
12701361
) + vec_replace_array + (
1362+
initguess_call,
12711363
dm_local_to_global_x,
12721364
snes_solve,
12731365
dm_global_to_local_x,

devito/petsc/solve.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,18 @@ def generate_field_data(self, eqns, target, arrays):
5757
)
5858
matvecs = [self.build_matvec_eqns(eq, target, arrays) for eq in eqns]
5959

60+
initialguess = [
61+
eq for eq in
62+
(self.make_initial_guess(e, target, arrays) for e in eqns)
63+
if eq is not None
64+
]
65+
6066
return FieldData(
6167
target=target,
6268
matvecs=matvecs,
6369
formfuncs=formfuncs,
6470
formrhs=formrhs,
71+
initialguess=initialguess,
6572
arrays=arrays
6673
)
6774

@@ -99,6 +106,22 @@ def make_rhs(self, eq, b, arrays):
99106
rhs = 0. if isinstance(eq, EssentialBC) else b.subs(self.time_mapper)
100107
return Eq(arrays['b'], rhs, subdomain=eq.subdomain)
101108

109+
def make_initial_guess(self, eq, target, arrays):
110+
"""
111+
Enforce initial guess to satisfy essential BCs.
112+
# TODO: For time-stepping, only enforce these once outside the time loop
113+
and use the previous time-step solution as the initial guess for next time step.
114+
# TODO: Extend this to "coupled".
115+
"""
116+
if isinstance(eq, EssentialBC):
117+
assert eq.lhs == target
118+
return Eq(
119+
arrays['x'], eq.rhs,
120+
subdomain=eq.subdomain
121+
)
122+
else:
123+
return None
124+
102125
def generate_arrays(self, target):
103126
return {
104127
p: PETScArray(name=f'{p}_{target.name}',

devito/petsc/types/types.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def eval(cls, *args):
132132

133133
class FieldData:
134134
def __init__(self, target=None, matvecs=None, formfuncs=None, formrhs=None,
135-
arrays=None, **kwargs):
135+
initialguess=None, arrays=None, **kwargs):
136136
self._target = kwargs.get('target', target)
137137

138138
petsc_precision = dtype_mapper[petsc_variables['PETSC_PRECISION']]
@@ -145,6 +145,7 @@ def __init__(self, target=None, matvecs=None, formfuncs=None, formrhs=None,
145145
self._matvecs = matvecs
146146
self._formfuncs = formfuncs
147147
self._formrhs = formrhs
148+
self._initialguess = initialguess
148149
self._arrays = arrays
149150

150151
@property
@@ -163,6 +164,10 @@ def formfuncs(self):
163164
def formrhs(self):
164165
return self._formrhs
165166

167+
@property
168+
def initialguess(self):
169+
return self._initialguess
170+
166171
@property
167172
def arrays(self):
168173
return self._arrays

examples/seismic/tutorials/13_LSRTM_acoustic.ipynb

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@
304304
},
305305
{
306306
"cell_type": "code",
307-
"execution_count": 6,
307+
"execution_count": null,
308308
"metadata": {},
309309
"outputs": [],
310310
"source": [
@@ -329,12 +329,13 @@
329329
" dm_true = (solver.model.vp.data**(-2) - model0.vp.data**(-2))\n",
330330
" \n",
331331
" objective = 0.\n",
332+
" u0 = None\n",
332333
" for i in range(nshots):\n",
333334
" \n",
334335
" #Observed Data using Born's operator\n",
335336
" geometry.src_positions[0, :] = source_locations[i, :]\n",
336337
"\n",
337-
" _, u0, _ = solver.forward(vp=model0.vp, save=True)\n",
338+
" _, u0, _ = solver.forward(vp=model0.vp, save=True, u=u0)\n",
338339
" \n",
339340
" _, _, _,_ = solver.jacobian(dm_true, vp=model0.vp, rec = d_obs)\n",
340341
" \n",

tests/test_petsc.py

Lines changed: 71 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,16 @@
44

55
from conftest import skipif
66
from devito import (Grid, Function, TimeFunction, Eq, Operator,
7-
configuration, norm, switchconfig)
7+
configuration, norm, switchconfig, SubDomain)
88
from devito.ir.iet import (Call, ElementalFunction,
99
FindNodes, retrieve_iteration_tree)
1010
from devito.types import Constant, LocalCompositeObject
1111
from devito.passes.iet.languages.C import CDataManager
1212
from devito.petsc.types import (DM, Mat, Vec, PetscMPIInt, KSP,
1313
PC, KSPConvergedReason, PETScArray,
1414
LinearSolveExpr, FieldData, MultipleFieldData)
15-
from devito.petsc.solve import PETScSolve, separate_eqn, centre_stencil
15+
from devito.petsc.solve import (PETScSolve, separate_eqn, centre_stencil,
16+
EssentialBC)
1617
from devito.petsc.iet.nodes import Expression
1718
from devito.petsc.initialize import PetscInitialize
1819

@@ -786,6 +787,74 @@ def test_solve_output():
786787
assert np.allclose(u.data, v.data)
787788

788789

790+
@skipif('petsc')
791+
def test_essential_bcs():
792+
"""
793+
Verify that PETScSolve returns the correct output with
794+
essential boundary conditions.
795+
"""
796+
class SubTop(SubDomain):
797+
name = 'subtop'
798+
799+
def define(self, dimensions):
800+
x, y = dimensions
801+
return {x: x, y: ('right', 1)}
802+
sub1 = SubTop()
803+
804+
class SubBottom(SubDomain):
805+
name = 'subbottom'
806+
807+
def define(self, dimensions):
808+
x, y = dimensions
809+
return {x: x, y: ('left', 1)}
810+
sub2 = SubBottom()
811+
812+
class SubLeft(SubDomain):
813+
name = 'subleft'
814+
815+
def define(self, dimensions):
816+
x, y = dimensions
817+
return {x: ('left', 1), y: y}
818+
sub3 = SubLeft()
819+
820+
class SubRight(SubDomain):
821+
name = 'subright'
822+
823+
def define(self, dimensions):
824+
x, y = dimensions
825+
return {x: ('right', 1), y: y}
826+
sub4 = SubRight()
827+
828+
subdomains = (sub1, sub2, sub3, sub4)
829+
grid = Grid(shape=(11, 11), subdomains=subdomains, dtype=np.float64)
830+
831+
u = Function(name='u', grid=grid, space_order=2)
832+
v = Function(name='v', grid=grid, space_order=2)
833+
834+
# Solving Ax=b where A is the identity matrix
835+
v.data[:] = 5.0
836+
eqn = Eq(u, v)
837+
838+
bcs = [EssentialBC(u, 1., subdomain=sub1)] # top
839+
bcs += [EssentialBC(u, 2., subdomain=sub2)] # bottom
840+
bcs += [EssentialBC(u, 3., subdomain=sub3)] # left
841+
bcs += [EssentialBC(u, 4., subdomain=sub4)] # right
842+
843+
petsc = PETScSolve([eqn]+bcs, target=u)
844+
845+
with switchconfig(language='petsc'):
846+
op = Operator(petsc)
847+
op.apply()
848+
849+
# Check u is equal to v on the interior
850+
assert np.allclose(u.data[1:-1, 1:-1], v.data[1:-1, 1:-1])
851+
# Check u satisfies the boundary conditions
852+
assert np.allclose(u.data[1:-1, -1], 1.0) # top
853+
assert np.allclose(u.data[1:-1, 0], 2.0) # bottom
854+
assert np.allclose(u.data[0, 1:-1], 3.0) # left
855+
assert np.allclose(u.data[-1, 1:-1], 4.0) # right
856+
857+
789858
class TestCoupledLinear:
790859
# The coupled interface can be used even for uncoupled problems, meaning
791860
# the equations will be solved within a single matrix system.

0 commit comments

Comments
 (0)