Skip to content

Commit f4ee1b9

Browse files
committed
compiler: Initialize and Finalize PETSc once and command line args
1 parent 5505a12 commit f4ee1b9

18 files changed

Lines changed: 225 additions & 63 deletions

File tree

devito/ir/equations/equation.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@
1111
from devito.tools import Pickable, Tag, frozendict
1212
from devito.types import (Eq, Inc, ReduceMax, ReduceMin,
1313
relational_min)
14-
from devito.types.equation import InjectSolveEq
14+
from devito.types.equation import PetscEq
1515

1616
__all__ = ['LoweredEq', 'ClusterizedEq', 'DummyEq', 'OpInc', 'OpMin', 'OpMax',
17-
'identity_mapper', 'OpInjectSolve']
17+
'identity_mapper', 'OpPetsc']
1818

1919

2020
class IREq(sympy.Eq, Pickable):
@@ -105,7 +105,7 @@ def detect(cls, expr):
105105
Inc: OpInc,
106106
ReduceMax: OpMax,
107107
ReduceMin: OpMin,
108-
InjectSolveEq: OpInjectSolve
108+
PetscEq: OpPetsc
109109
}
110110
try:
111111
return reduction_mapper[type(expr)]
@@ -122,7 +122,7 @@ def detect(cls, expr):
122122
OpInc = Operation('+')
123123
OpMax = Operation('max')
124124
OpMin = Operation('min')
125-
OpInjectSolve = Operation('solve')
125+
OpPetsc = Operation('solve')
126126

127127

128128
identity_mapper = {

devito/ir/iet/algorithms.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from devito.ir.iet import (Expression, Increment, Iteration, List, Conditional, SyncSpot,
44
Section, HaloSpot, ExpressionBundle)
55
from devito.tools import timed_pass
6-
from devito.petsc.types import LinearSolveExpr
6+
from devito.petsc.types import MetaData
77
from devito.petsc.iet.utils import petsc_iet_mapper
88

99
__all__ = ['iet_build']
@@ -26,7 +26,7 @@ def iet_build(stree):
2626
for e in i.exprs:
2727
if e.is_Increment:
2828
exprs.append(Increment(e))
29-
elif isinstance(e.rhs, LinearSolveExpr):
29+
elif isinstance(e.rhs, MetaData):
3030
exprs.append(petsc_iet_mapper[e.operation](e, operation=e.operation))
3131
else:
3232
exprs.append(Expression(e, operation=e.operation))

devito/petsc/iet/nodes.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,12 @@
11
from devito.ir.iet import Expression, Callback, FixedArgsCallable, Call
2-
from devito.ir.equations import OpInjectSolve
2+
from devito.ir.equations import OpPetsc
33

44

5-
class LinearSolverExpression(Expression):
5+
class PetscMetaData(Expression):
66
"""
7-
Base class for general expressions required by a
8-
matrix-free linear solve of the form Ax=b.
7+
Base class for general expressions required to run a PETSc solver.
98
"""
10-
pass
11-
12-
13-
class InjectSolveDummy(LinearSolverExpression):
14-
"""
15-
Placeholder expression to run the iterative solver.
16-
"""
17-
def __init__(self, expr, pragmas=None, operation=OpInjectSolve):
9+
def __init__(self, expr, pragmas=None, operation=OpPetsc):
1810
super().__init__(expr, pragmas=pragmas, operation=operation)
1911

2012

devito/petsc/iet/passes.py

Lines changed: 45 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
import cgen as c
2+
import numpy as np
23

34
from devito.passes.iet.engine import iet_pass
4-
from devito.ir.iet import Transformer, MapNodes, Iteration, BlankLine
5+
from devito.ir.iet import (Transformer, MapNodes, Iteration, BlankLine,
6+
FindNodes, Call, CallableBody)
57
from devito.symbolics import Byref, Macro
6-
from devito.petsc.types import (PetscMPIInt, PetscErrorCode)
7-
from devito.petsc.iet.nodes import InjectSolveDummy
8+
from devito.types.basic import DataSymbol
9+
from devito.petsc.types import (PetscMPIInt, PetscErrorCode, Initialize,
10+
Finalize, ArgvSymbol)
11+
from devito.petsc.types.macros import petsc_func_begin_user
12+
from devito.petsc.iet.nodes import PetscMetaData
813
from devito.petsc.utils import core_metadata
914
from devito.petsc.iet.routines import (CallbackBuilder, BaseObjectBuilder, BaseSetup,
1015
Solver, TimeDependent, NonTimeDependent)
@@ -14,14 +19,23 @@
1419
@iet_pass
1520
def lower_petsc(iet, **kwargs):
1621
# Check if PETScSolve was used
17-
injectsolve_mapper = MapNodes(Iteration, InjectSolveDummy,
22+
injectsolve_mapper = MapNodes(Iteration, PetscMetaData,
1823
'groupby').visit(iet)
1924

2025
if not injectsolve_mapper:
2126
return iet, {}
2227

28+
metadata = core_metadata()
29+
30+
data = FindNodes(PetscMetaData).visit(iet)
31+
32+
if any(filter(lambda i: isinstance(i.expr.rhs, Initialize), data)):
33+
return initialize(iet), metadata
34+
35+
if any(filter(lambda i: isinstance(i.expr.rhs, Finalize), data)):
36+
return finalize(iet), metadata
37+
2338
targets = [i.expr.rhs.target for (i,) in injectsolve_mapper.values()]
24-
init = init_petsc(**kwargs)
2539

2640
# Assumption is that all targets have the same grid so can use any target here
2741
objs = build_core_objects(targets[-1], **kwargs)
@@ -47,26 +61,37 @@ def lower_petsc(iet, **kwargs):
4761
iet = Transformer(subs).visit(iet)
4862

4963
body = core + tuple(setup) + (BlankLine,) + iet.body.body
50-
body = iet.body._rebuild(
51-
init=init, body=body,
52-
frees=(petsc_call('PetscFinalize', []),)
53-
)
64+
body = iet.body._rebuild(body=body)
5465
iet = iet._rebuild(body=body)
55-
metadata = core_metadata()
5666
metadata.update({'efuncs': tuple(efuncs.values())})
5767

5868
return iet, metadata
5969

6070

61-
def init_petsc(**kwargs):
62-
# Initialize PETSc -> for now, assuming all solver options have to be
63-
# specified via the parameters dict in PETScSolve
64-
# TODO: Are users going to be able to use PETSc command line arguments?
65-
# In firedrake, they have an options_prefix for each solver, enabling the use
66-
# of command line options
67-
initialize = petsc_call('PetscInitialize', [Null, Null, Null, Null])
71+
def initialize(iet):
72+
# should be int because the correct type for argc is a C int
73+
# and not a int32
74+
argc = DataSymbol(name='argc', dtype=np.int32)
75+
argv = ArgvSymbol(name='argv')
76+
Help = Macro('help')
6877

69-
return petsc_func_begin_user, initialize
78+
help_string = c.Line(r'static char help[] = "This is help text.\n";')
79+
80+
init_body = petsc_call('PetscInitialize', [Byref(argc), Byref(argv), Null, Help])
81+
init_body = CallableBody(
82+
body=(petsc_func_begin_user, help_string, init_body),
83+
retstmt=(Call('PetscFunctionReturn', arguments=[0]),)
84+
)
85+
return iet._rebuild(body=init_body)
86+
87+
88+
def finalize(iet):
89+
finalize_body = petsc_call('PetscFinalize', [])
90+
finalize_body = CallableBody(
91+
body=(petsc_func_begin_user, finalize_body),
92+
retstmt=(Call('PetscFunctionReturn', arguments=[0]),)
93+
)
94+
return iet._rebuild(body=finalize_body)
7095

7196

7297
def make_core_petsc_calls(objs, **kwargs):
@@ -76,10 +101,7 @@ def make_core_petsc_calls(objs, **kwargs):
76101

77102

78103
def build_core_objects(target, **kwargs):
79-
if kwargs['options']['mpi']:
80-
communicator = target.grid.distributor._obj_comm
81-
else:
82-
communicator = 'PETSC_COMM_SELF'
104+
communicator = 'PETSC_COMM_WORLD'
83105

84106
return {
85107
'size': PetscMPIInt(name='size'),
@@ -128,9 +150,6 @@ def __init__(self, injectsolve, objs, iters, **kwargs):
128150
)
129151

130152

153+
# Move these to types folder
131154
Null = Macro('NULL')
132155
void = 'void'
133-
134-
135-
# TODO: Don't use c.Line here?
136-
petsc_func_begin_user = c.Line('PetscFunctionBeginUser;')

devito/petsc/iet/routines.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from devito.petsc.types import PETScArray
1616
from devito.petsc.iet.nodes import (PETScCallable, FormFunctionCallback,
17-
MatVecCallback, InjectSolveDummy)
17+
MatVecCallback, PetscMetaData)
1818
from devito.petsc.iet.utils import petsc_call, petsc_struct
1919
from devito.petsc.utils import solver_mapper
2020
from devito.petsc.types import (DM, CallbackDM, Mat, LocalVec, GlobalVec, KSP, PC,
@@ -768,7 +768,7 @@ def _spatial_loop_nest(self, iters, injectsolve):
768768
spatial_body = []
769769
for tree in retrieve_iteration_tree(iters[0]):
770770
root = filter_iterations(tree, key=lambda i: i.dim.is_Space)[0]
771-
if injectsolve in FindNodes(InjectSolveDummy).visit(root):
771+
if injectsolve in FindNodes(PetscMetaData).visit(root):
772772
spatial_body.append(root)
773773
return spatial_body
774774

devito/petsc/iet/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from devito.petsc.iet.nodes import InjectSolveDummy, PETScCall
2-
from devito.ir.equations import OpInjectSolve
1+
from devito.petsc.iet.nodes import PetscMetaData, PETScCall
2+
from devito.ir.equations import OpPetsc
33

44

55
def petsc_call(specific_call, call_args):
@@ -19,4 +19,4 @@ def petsc_struct(name, fields, pname, liveness='lazy'):
1919

2020
# Mapping special Eq operations to their corresponding IET Expression subclass types.
2121
# These operations correspond to subclasses of Eq utilised within PETScSolve.
22-
petsc_iet_mapper = {OpInjectSolve: InjectSolveDummy}
22+
petsc_iet_mapper = {OpPetsc: PetscMetaData}

devito/petsc/initialize.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import os
2+
import sys
3+
from ctypes import POINTER, cast, c_char
4+
import atexit
5+
6+
from devito import Operator
7+
from devito.types import Symbol
8+
from devito.types.equation import PetscEq
9+
from devito.petsc.types import Initialize, Finalize
10+
11+
global _petsc_initialized
12+
_petsc_initialized = False
13+
14+
15+
def PetscInitialize():
16+
global _petsc_initialized
17+
if not _petsc_initialized:
18+
dummy = Symbol(name='d')
19+
# TODO: Potentially just use cgen + the compiler machinery in Devito
20+
# to generate these "dummy_ops" instead of using the Operator class.
21+
# This would prevent circular imports when initializing during import
22+
# from the PETSc module.
23+
op_init = Operator(
24+
[PetscEq(dummy, Initialize(dummy))],
25+
name='kernel_init', opt='noop'
26+
)
27+
op_finalize = Operator(
28+
[PetscEq(dummy, Finalize(dummy))],
29+
name='kernel_finalize', opt='noop'
30+
)
31+
32+
# `argv_bytes` must be a list so the memory address persists
33+
# `os.fsencode` should be preferred over `string().encode('utf-8')`
34+
# in case there is some system specific encoding in use
35+
argv_bytes = list(map(os.fsencode, sys.argv))
36+
argv_pointer = (POINTER(c_char)*len(sys.argv))(
37+
*map(lambda s: cast(s, POINTER(c_char)), argv_bytes)
38+
)
39+
op_init.apply(argc=len(sys.argv), argv=argv_pointer)
40+
41+
atexit.register(op_finalize.apply)
42+
_petsc_initialized = True

devito/petsc/solve.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from devito.finite_differences.differentiable import Mul
66
from devito.finite_differences.derivative import Derivative
77
from devito.types import Eq, Symbol, SteppingDimension, TimeFunction
8-
from devito.types.equation import InjectSolveEq
8+
from devito.types.equation import PetscEq
99
from devito.operations.solve import eval_time_derivatives
1010
from devito.symbolics import retrieve_functions
1111
from devito.tools import as_tuple
@@ -65,7 +65,7 @@ def PETScSolve(eqns, target, solver_parameters=None, **kwargs):
6565
)
6666
# Placeholder equation for inserting calls to the solver and generating
6767
# correct time loop etc
68-
inject_solve = InjectSolveEq(target, LinearSolveExpr(
68+
inject_solve = PetscEq(target, LinearSolveExpr(
6969
expr=tuple(funcs),
7070
target=target,
7171
solver_parameters=solver_parameters,

devito/petsc/types/macros.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
import cgen as c
2+
3+
4+
# TODO: Don't use c.Line here?
5+
petsc_func_begin_user = c.Line('PetscFunctionBeginUser;')

devito/petsc/types/object.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
from ctypes import POINTER
1+
from ctypes import POINTER, c_char
22

33
from devito.tools import CustomDtype, dtype_to_cstr
44
from devito.types import LocalObject, CCompositeObject, ModuloDimension, TimeDimension
5+
from devito.types.basic import DataSymbol
56
from devito.symbolics import Byref
67

78
from devito.petsc.iet.utils import petsc_call
@@ -209,3 +210,9 @@ class StartPtr(LocalObject):
209210
def __init__(self, name, dtype):
210211
super().__init__(name=name)
211212
self.dtype = CustomDtype(dtype_to_cstr(dtype), modifier=' *')
213+
214+
215+
class ArgvSymbol(DataSymbol):
216+
@property
217+
def _C_ctype(self):
218+
return POINTER(POINTER(c_char))

0 commit comments

Comments
 (0)