Skip to content

Commit 6cdb7d0

Browse files
authored
compiler: Temp fix for memory leaks. (#70)
- Temp fix until VecReplaceArray works properly in conjunction with PetscMemoryAllocator
1 parent f61d4dd commit 6cdb7d0

7 files changed

Lines changed: 98 additions & 75 deletions

File tree

devito/mpi/distributed.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def devito_mpi_finalize():
9191
"""
9292
Finalize MPI, if initialized by Devito.
9393
"""
94-
global init_by_devito
94+
global init_by_devito # noqa
9595
if init_by_devito and MPI.Is_initialized() and not MPI.Is_finalized():
9696
MPI.Finalize()
9797

devito/petsc/iet/passes.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from devito.types.basic import DataSymbol
1212
from devito.tools import frozendict
1313
from devito.petsc.types import (PetscMPIInt, PetscErrorCode, MultipleFieldData,
14-
PointerIS, Mat, LocalVec, GlobalVec, CallbackMat, SNES,
14+
PointerIS, Mat, CallbackVec, Vec, CallbackMat, SNES,
1515
DummyArg, PetscInt, PointerDM, PointerMat, MatReuse,
1616
CallbackPointerIS, CallbackPointerDM, JacobianStruct,
1717
SubMatrixStruct, Initialize, Finalize, ArgvSymbol)
@@ -40,14 +40,13 @@ def lower_petsc(iet, **kwargs):
4040
f"{petsc_languages}, but got '{kwargs['language']}'"
4141
)
4242

43-
metadata = core_metadata()
4443
data = FindNodes(PetscMetaData).visit(iet)
4544

4645
if any(filter(lambda i: isinstance(i.expr.rhs, Initialize), data)):
47-
return initialize(iet), metadata
46+
return initialize(iet), core_metadata()
4847

4948
if any(filter(lambda i: isinstance(i.expr.rhs, Finalize), data)):
50-
return finalize(iet), metadata
49+
return finalize(iet), core_metadata()
5150

5251
unique_grids = {i.expr.rhs.grid for (i,) in injectsolve_mapper.values()}
5352
# Assumption is that all solves are on the same grid
@@ -79,7 +78,7 @@ def lower_petsc(iet, **kwargs):
7978
body = core + tuple(setup) + (BlankLine,) + iet.body.body
8079
body = iet.body._rebuild(body=body)
8180
iet = iet._rebuild(body=body)
82-
metadata.update({'efuncs': tuple(efuncs.values())})
81+
metadata = {**core_metadata(), 'efuncs': tuple(efuncs.values())}
8382
return iet, metadata
8483

8584

@@ -218,13 +217,13 @@ def populate_matrix_context(efuncs, objs):
218217
'rowidx': PetscInt('rowidx'),
219218
'colidx': PetscInt('colidx'),
220219
'J': Mat('J'),
221-
'X': GlobalVec('X'),
222-
'xloc': LocalVec('xloc'),
223-
'Y': GlobalVec('Y'),
224-
'yloc': LocalVec('yloc'),
225-
'F': GlobalVec('F'),
226-
'floc': LocalVec('floc'),
227-
'B': GlobalVec('B'),
220+
'X': Vec('X'),
221+
'xloc': CallbackVec('xloc'),
222+
'Y': Vec('Y'),
223+
'yloc': CallbackVec('yloc'),
224+
'F': Vec('F'),
225+
'floc': CallbackVec('floc'),
226+
'B': Vec('B'),
228227
'nfields': PetscInt('nfields'),
229228
'irow': PointerIS(name='irow'),
230229
'icol': PointerIS(name='icol'),

devito/petsc/iet/routines.py

Lines changed: 42 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
MatShellSetOp, PetscMetaData)
1717
from devito.petsc.iet.utils import petsc_call, petsc_struct
1818
from devito.petsc.utils import solver_mapper
19-
from devito.petsc.types import (DM, Mat, LocalVec, GlobalVec, KSP, PC, SNES,
19+
from devito.petsc.types import (DM, Mat, CallbackVec, Vec, KSP, PC, SNES,
2020
PetscInt, StartPtr, PointerIS, PointerDM, VecScatter,
2121
DMCast, JacobianStructCast, JacobianStruct,
2222
SubMatrixStruct, CallbackDM)
@@ -448,8 +448,13 @@ def _create_form_rhs_body(self, body, fielddata):
448448
'VecRestoreArray', [sobjs['blocal'], Byref(b_arr._C_symbol)]
449449
)
450450

451+
dm_restore_local_bvec = petsc_call(
452+
'DMRestoreLocalVector', [dmda, Byref(sobjs['blocal'])]
453+
)
454+
451455
body = body._rebuild(body=body.body + (
452-
dm_local_to_global_begin, dm_local_to_global_end, vec_restore_array
456+
dm_local_to_global_begin, dm_local_to_global_end, vec_restore_array,
457+
dm_restore_local_bvec
453458
))
454459

455460
stacks = (
@@ -870,10 +875,10 @@ def _build(self):
870875
targets = self.fielddata.targets
871876
base_dict = {
872877
'Jac': Mat(sreg.make_name(prefix='J')),
873-
'xglobal': GlobalVec(sreg.make_name(prefix='xglobal')),
874-
'xlocal': LocalVec(sreg.make_name(prefix='xlocal')),
875-
'bglobal': GlobalVec(sreg.make_name(prefix='bglobal')),
876-
'blocal': LocalVec(sreg.make_name(prefix='blocal')),
878+
'xglobal': Vec(sreg.make_name(prefix='xglobal')),
879+
'xlocal': Vec(sreg.make_name(prefix='xlocal')),
880+
'bglobal': Vec(sreg.make_name(prefix='bglobal')),
881+
'blocal': CallbackVec(sreg.make_name(prefix='blocal')),
877882
'ksp': KSP(sreg.make_name(prefix='ksp')),
878883
'pc': PC(sreg.make_name(prefix='pc')),
879884
'snes': SNES(sreg.make_name(prefix='snes')),
@@ -939,9 +944,9 @@ def _extend_build(self, base_dict):
939944
name=f'{key}ctx',
940945
fields=objs['subctx'].fields,
941946
)
942-
base_dict[f'{key}X'] = LocalVec(f'{key}X')
943-
base_dict[f'{key}Y'] = LocalVec(f'{key}Y')
944-
base_dict[f'{key}F'] = LocalVec(f'{key}F')
947+
base_dict[f'{key}X'] = CallbackVec(f'{key}X')
948+
base_dict[f'{key}Y'] = CallbackVec(f'{key}Y')
949+
base_dict[f'{key}F'] = CallbackVec(f'{key}F')
945950

946951
return base_dict
947952

@@ -953,22 +958,22 @@ def _target_dependent(self, base_dict):
953958
base_dict[f'{name}_ptr'] = StartPtr(
954959
sreg.make_name(prefix=f'{name}_ptr'), t.dtype
955960
)
956-
base_dict[f'xlocal{name}'] = LocalVec(
961+
base_dict[f'xlocal{name}'] = CallbackVec(
957962
sreg.make_name(prefix=f'xlocal{name}'), liveness='eager'
958963
)
959-
base_dict[f'Fglobal{name}'] = LocalVec(
964+
base_dict[f'Fglobal{name}'] = CallbackVec(
960965
sreg.make_name(prefix=f'Fglobal{name}'), liveness='eager'
961966
)
962-
base_dict[f'Xglobal{name}'] = LocalVec(
967+
base_dict[f'Xglobal{name}'] = CallbackVec(
963968
sreg.make_name(prefix=f'Xglobal{name}')
964969
)
965-
base_dict[f'xglobal{name}'] = GlobalVec(
970+
base_dict[f'xglobal{name}'] = Vec(
966971
sreg.make_name(prefix=f'xglobal{name}')
967972
)
968-
base_dict[f'blocal{name}'] = LocalVec(
973+
base_dict[f'blocal{name}'] = CallbackVec(
969974
sreg.make_name(prefix=f'blocal{name}'), liveness='eager'
970975
)
971-
base_dict[f'bglobal{name}'] = GlobalVec(
976+
base_dict[f'bglobal{name}'] = Vec(
972977
sreg.make_name(prefix=f'bglobal{name}')
973978
)
974979
base_dict[f'da{name}'] = DM(
@@ -1021,6 +1026,12 @@ def _setup(self):
10211026
global_x = petsc_call('DMCreateGlobalVector',
10221027
[dmda, Byref(sobjs['xglobal'])])
10231028

1029+
local_x = petsc_call('DMCreateLocalVector',
1030+
[dmda, Byref(sobjs['xlocal'])])
1031+
1032+
get_local_size = petsc_call('VecGetSize',
1033+
[sobjs['xlocal'], Byref(sobjs['localsize'])])
1034+
10241035
global_b = petsc_call('DMCreateGlobalVector',
10251036
[dmda, Byref(sobjs['bglobal'])])
10261037

@@ -1078,6 +1089,8 @@ def _setup(self):
10781089
snes_set_jac,
10791090
snes_set_type,
10801091
global_x,
1092+
local_x,
1093+
get_local_size,
10811094
global_b,
10821095
snes_get_ksp,
10831096
ksp_set_tols,
@@ -1235,9 +1248,6 @@ def _execute_solve(self):
12351248

12361249
rhs_call = petsc_call(rhs_callback.name, [sobjs['dmda'], sobjs['bglobal']])
12371250

1238-
local_x = petsc_call('DMCreateLocalVector',
1239-
[dmda, Byref(sobjs['xlocal'])])
1240-
12411251
vec_replace_array = self.timedep.replace_array(target)
12421252

12431253
dm_local_to_global_x = petsc_call(
@@ -1253,13 +1263,15 @@ def _execute_solve(self):
12531263
dmda, sobjs['xglobal'], insert_vals, sobjs['xlocal']]
12541264
)
12551265

1266+
vec_reset_array = self.timedep.reset_array(target)
1267+
12561268
run_solver_calls = (struct_assignment,) + (
12571269
rhs_call,
1258-
local_x
12591270
) + vec_replace_array + (
12601271
dm_local_to_global_x,
12611272
snes_solve,
12621273
dm_global_to_local_x,
1274+
vec_reset_array,
12631275
BlankLine,
12641276
)
12651277
return List(body=run_solver_calls)
@@ -1402,7 +1414,16 @@ def replace_array(self, target):
14021414
target.function._C_field_data, target.function._C_symbol
14031415
)
14041416
xlocal = sobjs.get(f'xlocal{target.name}', sobjs['xlocal'])
1405-
return (petsc_call('VecReplaceArray', [xlocal, field_from_ptr]),)
1417+
return (petsc_call('VecPlaceArray', [xlocal, field_from_ptr]),)
1418+
1419+
def reset_array(self, target):
1420+
"""
1421+
"""
1422+
sobjs = self.sobjs
1423+
xlocal = sobjs.get(f'xlocal{target.name}', sobjs['xlocal'])
1424+
return (
1425+
petsc_call('VecResetArray', [xlocal])
1426+
)
14061427

14071428
def assign_time_iters(self, struct):
14081429
return []
@@ -1530,15 +1551,14 @@ def replace_array(self, target):
15301551

15311552
caster = cast(target.dtype, '*')
15321553
return (
1533-
petsc_call('VecGetSize', [xlocal, Byref(sobjs['localsize'])]),
15341554
DummyExpr(
15351555
start_ptr,
15361556
caster(
15371557
FieldFromPointer(target._C_field_data, target._C_symbol)
15381558
) + Mul(target_time, sobjs['localsize']),
15391559
init=True
15401560
),
1541-
petsc_call('VecReplaceArray', [xlocal, start_ptr])
1561+
petsc_call('VecPlaceArray', [xlocal, start_ptr])
15421562
)
15431563
return super().replace_array(target)
15441564

devito/petsc/solve.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -70,29 +70,33 @@ def build_function_eqns(self, eq, target, arrays):
7070
formfunc = self.make_formfunc(eq, F_target, arrays, targets)
7171
formrhs = self.make_rhs(eq, b, arrays)
7272

73-
return tuple(expr.subs(self.time_mapper) for expr in (formfunc, formrhs))
73+
return (formfunc, formrhs)
7474

7575
def build_matvec_eqns(self, eq, target, arrays):
7676
b, F_target, targets = separate_eqn(eq, target)
7777
if not F_target:
7878
return None
7979
matvec = self.make_matvec(eq, F_target, arrays, targets)
80-
return matvec.subs(self.time_mapper)
80+
return matvec
8181

8282
def make_matvec(self, eq, F_target, arrays, targets):
83-
rhs = arrays['x'] if isinstance(eq, EssentialBC) else F_target.subs(
84-
targets_to_arrays(arrays['x'], targets)
85-
)
83+
if isinstance(eq, EssentialBC):
84+
rhs = arrays['x']
85+
else:
86+
rhs = F_target.subs(targets_to_arrays(arrays['x'], targets))
87+
rhs = rhs.subs(self.time_mapper)
8688
return Eq(arrays['y'], rhs, subdomain=eq.subdomain)
8789

8890
def make_formfunc(self, eq, F_target, arrays, targets):
89-
rhs = 0. if isinstance(eq, EssentialBC) else F_target.subs(
90-
targets_to_arrays(arrays['x'], targets)
91-
)
91+
if isinstance(eq, EssentialBC):
92+
rhs = 0.
93+
else:
94+
rhs = F_target.subs(targets_to_arrays(arrays['x'], targets))
95+
rhs = rhs.subs(self.time_mapper)
9296
return Eq(arrays['f'], rhs, subdomain=eq.subdomain)
9397

9498
def make_rhs(self, eq, b, arrays):
95-
rhs = 0. if isinstance(eq, EssentialBC) else b
99+
rhs = 0. if isinstance(eq, EssentialBC) else b.subs(self.time_mapper)
96100
return Eq(arrays['b'], rhs, subdomain=eq.subdomain)
97101

98102
def generate_arrays(self, target):

devito/petsc/types/object.py

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -66,28 +66,14 @@ def _C_free_priority(self):
6666
return 2
6767

6868

69-
class LocalVec(LocalObject):
69+
class CallbackVec(LocalObject):
7070
"""
71-
PETSc local vector object (Vec).
72-
A local vector has ghost locations that contain values that are
73-
owned by other MPI ranks.
71+
PETSc vector object (Vec).
7472
"""
7573
dtype = CustomDtype('Vec')
7674

7775

78-
class CallbackGlobalVec(LocalVec):
79-
"""
80-
PETSc global vector object (Vec). For example, used for coupled
81-
solves inside the `WholeFormFunc` callback.
82-
"""
83-
84-
85-
class GlobalVec(LocalVec):
86-
"""
87-
PETSc global vector object (Vec).
88-
A global vector is a parallel vector that has no duplicate values
89-
between MPI ranks. A global vector has no ghost locations.
90-
"""
76+
class Vec(CallbackVec):
9177
@property
9278
def _C_free(self):
9379
return petsc_call('VecDestroy', [Byref(self.function)])

devito/petsc/utils.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -76,12 +76,25 @@ def get_petsc_variables():
7676

7777
petsc_variables = get_petsc_variables()
7878

79-
# TODO: Check to see whether Petsc is compiled with
80-
# 32-bit or 64-bit integers
81-
# TODO: Check whether PetscScalar is a float or double
82-
# and only map the right one
83-
petsc_type_mappings = {ctypes.c_int: 'PetscInt',
84-
ctypes.c_float: 'PetscScalar',
85-
ctypes.c_double: 'PetscScalar'}
79+
80+
def get_petsc_type_mappings():
81+
try:
82+
petsc_precision = petsc_variables['PETSC_PRECISION']
83+
except KeyError:
84+
mapper = {}
85+
else:
86+
petsc_scalar = 'PetscScalar'
87+
# TODO: Check to see whether Petsc is compiled with
88+
# 32-bit or 64-bit integers
89+
mapper = {ctypes.c_int: 'PetscInt'}
90+
91+
if petsc_precision == 'single':
92+
mapper[ctypes.c_float] = petsc_scalar
93+
elif petsc_precision == 'double':
94+
mapper[ctypes.c_double] = petsc_scalar
95+
return mapper
96+
97+
98+
petsc_type_mappings = get_petsc_type_mappings()
8699

87100
petsc_languages = ['petsc']

tests/test_petsc.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
FindNodes, retrieve_iteration_tree)
1010
from devito.types import Constant, LocalCompositeObject
1111
from devito.passes.iet.languages.C import CDataManager
12-
from devito.petsc.types import (DM, Mat, LocalVec, PetscMPIInt, KSP,
12+
from devito.petsc.types import (DM, Mat, Vec, PetscMPIInt, KSP,
1313
PC, KSPConvergedReason, PETScArray,
1414
LinearSolveExpr, FieldData, MultipleFieldData)
1515
from devito.petsc.solve import PETScSolve, separate_eqn, centre_stencil
@@ -33,7 +33,7 @@ def test_petsc_local_object():
3333
"""
3434
lo0 = DM('da', stencil_width=1)
3535
lo1 = Mat('A')
36-
lo2 = LocalVec('x')
36+
lo2 = Vec('x')
3737
lo3 = PetscMPIInt('size')
3838
lo4 = KSP('ksp')
3939
lo5 = PC('pc')
@@ -590,7 +590,7 @@ def test_apply():
590590

591591
pn = Function(name='pn', grid=grid, space_order=2)
592592
rhs = Function(name='rhs', grid=grid, space_order=2)
593-
mu = Constant(name='mu', value=2.0)
593+
mu = Constant(name='mu', value=2.0, dtype=np.float64)
594594

595595
eqn = Eq(pn.laplace*mu, rhs, subdomain=grid.interior)
596596

@@ -604,7 +604,7 @@ def test_apply():
604604
op.apply()
605605

606606
# Verify that users can override `mu`
607-
mu_new = Constant(name='mu_new', value=4.0)
607+
mu_new = Constant(name='mu_new', value=4.0, dtype=np.float64)
608608
op.apply(mu=mu_new)
609609

610610

@@ -627,9 +627,10 @@ def test_petsc_frees():
627627
# Check the frees appear in the following order
628628
assert str(frees[0]) == 'PetscCall(VecDestroy(&(bglobal0)));'
629629
assert str(frees[1]) == 'PetscCall(VecDestroy(&(xglobal0)));'
630-
assert str(frees[2]) == 'PetscCall(MatDestroy(&(J0)));'
631-
assert str(frees[3]) == 'PetscCall(SNESDestroy(&(snes0)));'
632-
assert str(frees[4]) == 'PetscCall(DMDestroy(&(da0)));'
630+
assert str(frees[2]) == 'PetscCall(VecDestroy(&(xlocal0)));'
631+
assert str(frees[3]) == 'PetscCall(MatDestroy(&(J0)));'
632+
assert str(frees[4]) == 'PetscCall(SNESDestroy(&(snes0)));'
633+
assert str(frees[5]) == 'PetscCall(DMDestroy(&(da0)));'
633634

634635

635636
@skipif('petsc')

0 commit comments

Comments
 (0)