11import cgen as c
2+ import numpy as np
23
34from devito .passes .iet .engine import iet_pass
4- from devito .ir .iet import Transformer , MapNodes , Iteration , BlankLine
5+ from devito .ir .iet import (Transformer , MapNodes , Iteration , BlankLine ,
6+ FindNodes , Call , CallableBody )
57from devito .symbolics import Byref , Macro
6- from devito .petsc .types import (PetscMPIInt , PetscErrorCode )
7- from devito .petsc .iet .nodes import InjectSolveDummy
8+ from devito .types .basic import DataSymbol
9+ from devito .petsc .types import (PetscMPIInt , PetscErrorCode , Initialize ,
10+ Finalize , ArgvSymbol )
11+ from devito .petsc .types .macros import petsc_func_begin_user
12+ from devito .petsc .iet .nodes import PetscMetaData
813from devito .petsc .utils import core_metadata
914from devito .petsc .iet .routines import (CallbackBuilder , BaseObjectBuilder , BaseSetup ,
1015 Solver , TimeDependent , NonTimeDependent )
1419@iet_pass
1520def lower_petsc (iet , ** kwargs ):
1621 # Check if PETScSolve was used
17- injectsolve_mapper = MapNodes (Iteration , InjectSolveDummy ,
22+ injectsolve_mapper = MapNodes (Iteration , PetscMetaData ,
1823 'groupby' ).visit (iet )
1924
2025 if not injectsolve_mapper :
2126 return iet , {}
2227
28+ metadata = core_metadata ()
29+
30+ data = FindNodes (PetscMetaData ).visit (iet )
31+
32+ if any (filter (lambda i : isinstance (i .expr .rhs , Initialize ), data )):
33+ return initialize (iet ), metadata
34+
35+ if any (filter (lambda i : isinstance (i .expr .rhs , Finalize ), data )):
36+ return finalize (iet ), metadata
37+
2338 targets = [i .expr .rhs .target for (i ,) in injectsolve_mapper .values ()]
24- init = init_petsc (** kwargs )
2539
2640 # Assumption is that all targets have the same grid so can use any target here
2741 objs = build_core_objects (targets [- 1 ], ** kwargs )
@@ -47,26 +61,37 @@ def lower_petsc(iet, **kwargs):
4761 iet = Transformer (subs ).visit (iet )
4862
4963 body = core + tuple (setup ) + (BlankLine ,) + iet .body .body
50- body = iet .body ._rebuild (
51- init = init , body = body ,
52- frees = (petsc_call ('PetscFinalize' , []),)
53- )
64+ body = iet .body ._rebuild (body = body )
5465 iet = iet ._rebuild (body = body )
55- metadata = core_metadata ()
5666 metadata .update ({'efuncs' : tuple (efuncs .values ())})
5767
5868 return iet , metadata
5969
6070
61- def init_petsc (** kwargs ):
62- # Initialize PETSc -> for now, assuming all solver options have to be
63- # specified via the parameters dict in PETScSolve
64- # TODO: Are users going to be able to use PETSc command line arguments?
65- # In firedrake, they have an options_prefix for each solver, enabling the use
66- # of command line options
67- initialize = petsc_call ('PetscInitialize' , [Null , Null , Null , Null ])
71+ def initialize (iet ):
72+ # should be int because the correct type for argc is a C int
73+ # and not a int32
74+ argc = DataSymbol (name = 'argc' , dtype = np .int32 )
75+ argv = ArgvSymbol (name = 'argv' )
76+ Help = Macro ('help' )
6877
69- return petsc_func_begin_user , initialize
78+ help_string = c .Line (r'static char help[] = "This is help text.\n";' )
79+
80+ init_body = petsc_call ('PetscInitialize' , [Byref (argc ), Byref (argv ), Null , Help ])
81+ init_body = CallableBody (
82+ body = (petsc_func_begin_user , help_string , init_body ),
83+ retstmt = (Call ('PetscFunctionReturn' , arguments = [0 ]),)
84+ )
85+ return iet ._rebuild (body = init_body )
86+
87+
88+ def finalize (iet ):
89+ finalize_body = petsc_call ('PetscFinalize' , [])
90+ finalize_body = CallableBody (
91+ body = (petsc_func_begin_user , finalize_body ),
92+ retstmt = (Call ('PetscFunctionReturn' , arguments = [0 ]),)
93+ )
94+ return iet ._rebuild (body = finalize_body )
7095
7196
7297def make_core_petsc_calls (objs , ** kwargs ):
@@ -76,10 +101,7 @@ def make_core_petsc_calls(objs, **kwargs):
76101
77102
78103def build_core_objects (target , ** kwargs ):
79- if kwargs ['options' ]['mpi' ]:
80- communicator = target .grid .distributor ._obj_comm
81- else :
82- communicator = 'PETSC_COMM_SELF'
104+ communicator = 'PETSC_COMM_WORLD'
83105
84106 return {
85107 'size' : PetscMPIInt (name = 'size' ),
@@ -128,9 +150,6 @@ def __init__(self, injectsolve, objs, iters, **kwargs):
128150 )
129151
130152
153+ # Move these to types folder
131154Null = Macro ('NULL' )
132155void = 'void'
133-
134-
135- # TODO: Don't use c.Line here?
136- petsc_func_begin_user = c .Line ('PetscFunctionBeginUser;' )
0 commit comments