Skip to content

Commit f9d1e34

Browse files
committed
compiler: Add PETSc module
1 parent adbf3bf commit f9d1e34

34 files changed

Lines changed: 3067 additions & 39 deletions

.github/workflows/pytest-petsc.yml

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
name: CI-petsc
2+
3+
concurrency:
4+
group: ${{ github.workflow }}-${{ github.ref }}
5+
cancel-in-progress: true
6+
7+
on:
8+
# Trigger the workflow on push or pull request,
9+
# but only for the master branch
10+
push:
11+
branches:
12+
- master
13+
pull_request:
14+
branches:
15+
- master
16+
- FieldFromPointer
17+
18+
jobs:
19+
pytest:
20+
name: ${{ matrix.name }}-${{ matrix.set }}
21+
runs-on: "${{ matrix.os }}"
22+
23+
env:
24+
DOCKER_BUILDKIT: "1"
25+
DEVITO_ARCH: "${{ matrix.arch }}"
26+
DEVITO_LANGUAGE: ${{ matrix.language }}
27+
28+
strategy:
29+
# Prevent all build to stop if a single one fails
30+
fail-fast: false
31+
32+
matrix:
33+
name: [
34+
pytest-docker-py39-gcc-noomp
35+
]
36+
include:
37+
- name: pytest-docker-py39-gcc-noomp
38+
python-version: '3.9'
39+
os: ubuntu-latest
40+
arch: "gcc"
41+
language: "C"
42+
sympy: "1.12"
43+
44+
steps:
45+
- name: Checkout devito
46+
uses: actions/checkout@v4
47+
48+
- name: Build docker image
49+
run: |
50+
docker build . --file docker/Dockerfile.devito --tag devito_img --build-arg base=zoeleibowitz/bases:cpu-${{ matrix.arch }} --build-arg petscinstall=petsc
51+
52+
- name: Set run prefix
53+
run: |
54+
echo "RUN_CMD=docker run --rm -t -e CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }} --name testrun devito_img" >> $GITHUB_ENV
55+
id: set-run
56+
57+
- name: Set tests
58+
run : |
59+
echo "TESTS=tests/test_petsc.py" >> $GITHUB_ENV
60+
id: set-tests
61+
62+
- name: Check configuration
63+
run: |
64+
${{ env.RUN_CMD }} python3 -c "from devito import configuration; print(''.join(['%s: %s \n' % (k, v) for (k, v) in configuration.items()]))"
65+
66+
- name: Test with pytest
67+
run: |
68+
${{ env.RUN_CMD }} pytest --cov --cov-config=.coveragerc --cov-report=xml ${{ env.TESTS }}
69+
70+
- name: Upload coverage to Codecov
71+
if: "!contains(matrix.name, 'docker')"
72+
uses: codecov/codecov-action@v4
73+
with:
74+
token: ${{ secrets.CODECOV_TOKEN }}
75+
name: ${{ matrix.name }}

conftest.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from devito.ir.iet import (FindNodes, FindSymbols, Iteration, ParallelBlock,
1515
retrieve_iteration_tree)
1616
from devito.tools import as_tuple
17+
from devito.petsc.utils import get_petsc_dir, get_petsc_arch
1718

1819
try:
1920
from mpi4py import MPI # noqa
@@ -33,7 +34,7 @@ def skipif(items, whole_module=False):
3334
accepted = set()
3435
accepted.update({'device', 'device-C', 'device-openmp', 'device-openacc',
3536
'device-aomp', 'cpu64-icc', 'cpu64-icx', 'cpu64-nvc',
36-
'noadvisor', 'cpu64-arm', 'cpu64-icpx', 'chkpnt'})
37+
'noadvisor', 'cpu64-arm', 'cpu64-icpx', 'chkpnt', 'petsc'})
3738
accepted.update({'nodevice'})
3839
unknown = sorted(set(items) - accepted)
3940
if unknown:
@@ -93,6 +94,19 @@ def skipif(items, whole_module=False):
9394
if i == 'chkpnt' and Revolver is NoopRevolver:
9495
skipit = "pyrevolve not installed"
9596
break
97+
if i == 'petsc':
98+
petsc_dir = get_petsc_dir()
99+
petsc_arch = get_petsc_arch()
100+
if petsc_dir is None or petsc_arch is None:
101+
skipit = "PETSC_DIR or PETSC_ARCH are not set"
102+
break
103+
else:
104+
petsc_installed = os.path.join(
105+
petsc_dir, petsc_arch, 'include', 'petscconf.h'
106+
)
107+
if not os.path.isfile(petsc_installed):
108+
skipit = "PETSc is not installed"
109+
break
96110

97111
if skipit is False:
98112
return pytest.mark.skipif(False, reason='')

devito/ir/equations/equation.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@
99
Stencil, detect_io, detect_accesses)
1010
from devito.symbolics import IntDiv, limits_mapper, uxreplace
1111
from devito.tools import Pickable, Tag, frozendict
12-
from devito.types import Eq, Inc, ReduceMax, ReduceMin, relational_min
12+
from devito.types import (Eq, Inc, ReduceMax, ReduceMin,
13+
relational_min)
14+
from devito.types.equation import InjectSolveEq
1315

1416
__all__ = ['LoweredEq', 'ClusterizedEq', 'DummyEq', 'OpInc', 'OpMin', 'OpMax',
15-
'identity_mapper']
17+
'identity_mapper', 'OpInjectSolve']
1618

1719

1820
class IREq(sympy.Eq, Pickable):
@@ -102,7 +104,8 @@ def detect(cls, expr):
102104
reduction_mapper = {
103105
Inc: OpInc,
104106
ReduceMax: OpMax,
105-
ReduceMin: OpMin
107+
ReduceMin: OpMin,
108+
InjectSolveEq: OpInjectSolve
106109
}
107110
try:
108111
return reduction_mapper[type(expr)]
@@ -119,6 +122,7 @@ def detect(cls, expr):
119122
OpInc = Operation('+')
120123
OpMax = Operation('max')
121124
OpMin = Operation('min')
125+
OpInjectSolve = Operation('solve')
122126

123127

124128
identity_mapper = {

devito/ir/iet/algorithms.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from devito.ir.iet import (Expression, Increment, Iteration, List, Conditional, SyncSpot,
44
Section, HaloSpot, ExpressionBundle)
55
from devito.tools import timed_pass
6+
from devito.petsc.types import LinearSolveExpr
7+
from devito.petsc.iet.utils import petsc_iet_mapper
68

79
__all__ = ['iet_build']
810

@@ -24,6 +26,8 @@ def iet_build(stree):
2426
for e in i.exprs:
2527
if e.is_Increment:
2628
exprs.append(Increment(e))
29+
elif isinstance(e.rhs, LinearSolveExpr):
30+
exprs.append(petsc_iet_mapper[e.operation](e, operation=e.operation))
2731
else:
2832
exprs.append(Expression(e, operation=e.operation))
2933
body = ExpressionBundle(i.ispace, i.ops, i.traffic, body=exprs)

devito/ir/iet/efunc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from functools import cached_property
22

3-
from devito.ir.iet.nodes import Call, Callable
3+
from devito.ir.iet.nodes import Call, Callable, FixedArgsCallable
44
from devito.ir.iet.utils import derive_parameters
55
from devito.symbolics import uxreplace
66
from devito.tools import as_tuple
@@ -131,7 +131,7 @@ class AsyncCall(Call):
131131
pass
132132

133133

134-
class ThreadCallable(Callable):
134+
class ThreadCallable(FixedArgsCallable):
135135

136136
"""
137137
A Callable executed asynchronously by a thread.

devito/ir/iet/nodes.py

Lines changed: 54 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
ctypes_to_cstr)
2222
from devito.types.basic import (AbstractFunction, AbstractSymbol, Basic, Indexed,
2323
Symbol)
24-
from devito.types.object import AbstractObject, LocalObject
24+
from devito.types.object import AbstractObject, LocalObject, LocalCompositeObject
2525

2626
__all__ = ['Node', 'MultiTraversable', 'Block', 'Expression', 'Callable',
2727
'Call', 'ExprStmt', 'Conditional', 'Iteration', 'List', 'Section',
@@ -30,7 +30,7 @@
3030
'Increment', 'Return', 'While', 'ListMajor', 'ParallelIteration',
3131
'ParallelBlock', 'Dereference', 'Lambda', 'SyncSpot', 'Pragma',
3232
'DummyExpr', 'BlankLine', 'ParallelTree', 'BusyWait', 'UsingNamespace',
33-
'Using', 'CallableBody', 'Transfer']
33+
'Using', 'CallableBody', 'Transfer', 'Callback', 'FixedArgsCallable']
3434

3535
# First-class IET nodes
3636

@@ -759,6 +759,15 @@ def defines(self):
759759
return self.all_parameters
760760

761761

762+
class FixedArgsCallable(Callable):
763+
764+
"""
765+
A Callable class that enforces a fixed function signature.
766+
"""
767+
768+
pass
769+
770+
762771
class CallableBody(MultiTraversable):
763772

764773
"""
@@ -1037,8 +1046,8 @@ class Dereference(ExprStmt, Node):
10371046
The following cases are supported:
10381047
10391048
* `pointer` is a PointerArray or TempFunction, and `pointee` is an Array.
1040-
* `pointer` is an ArrayObject representing a pointer to a C struct, and
1041-
`pointee` is a field in `pointer`.
1049+
* `pointer` is an ArrayObject or CCompositeObject representing a pointer
1050+
to a C struct, and `pointee` is a field in `pointer`.
10421051
* `pointer` is a Symbol with its _C_ctype deriving from ct._Pointer, and
10431052
`pointee` is a Symbol representing the dereferenced value.
10441053
"""
@@ -1061,7 +1070,8 @@ def functions(self):
10611070
def expr_symbols(self):
10621071
ret = []
10631072
if self.pointer.is_Symbol:
1064-
assert issubclass(self.pointer._C_ctype, ctypes._Pointer), \
1073+
assert (isinstance(self.pointer, LocalCompositeObject) or
1074+
issubclass(self.pointer._C_ctype, ctypes._Pointer)), \
10651075
"Scalar dereference must have a pointer ctype"
10661076
ret.extend([self.pointer._C_symbol, self.pointee._C_symbol])
10671077
elif self.pointer.is_PointerArray or self.pointer.is_TempFunction:
@@ -1136,6 +1146,45 @@ def defines(self):
11361146
return tuple(self.parameters)
11371147

11381148

1149+
class Callback(Call):
1150+
"""
1151+
Base class for special callback types.
1152+
1153+
Parameters
1154+
----------
1155+
name : str
1156+
The name of the callback.
1157+
retval : str
1158+
The return type of the callback.
1159+
param_types : str or list of str
1160+
The return type for each argument of the callback.
1161+
1162+
Notes
1163+
-----
1164+
- The reason Callback is an IET type rather than a SymPy type is
1165+
due to the fact that, when represented at the SymPy level, the IET
1166+
engine fails to bind the callback to a specific Call. Consequently,
1167+
errors occur during the creation of the call graph.
1168+
"""
1169+
# TODO: Create a common base class for Call and Callback to avoid
1170+
# having arguments=None here
1171+
def __init__(self, name, retval=None, param_types=None, arguments=None):
1172+
super().__init__(name=name)
1173+
self.retval = retval
1174+
self.param_types = as_tuple(param_types)
1175+
1176+
@property
1177+
def callback_form(self):
1178+
"""
1179+
A string representation of the callback form.
1180+
1181+
Notes
1182+
-----
1183+
To be overridden by subclasses.
1184+
"""
1185+
return
1186+
1187+
11391188
class Section(List):
11401189

11411190
"""

devito/ir/iet/visitors.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ def _gen_value(self, obj, mode=1, masked=()):
265265
strtype = f'{strtype}{self._restrict_keyword}'
266266
strtype = ' '.join(qualifiers + [strtype])
267267

268-
if obj.is_LocalObject and obj._C_modifier is not None and mode == 2:
268+
if obj.is_LocalType and obj._C_modifier is not None and mode == 2:
269269
strtype += obj._C_modifier
270270

271271
strname = obj._C_name
@@ -644,6 +644,9 @@ def visit_Lambda(self, o):
644644
top = c.Line(f"[{', '.join(captures)}]({', '.join(decls)}){''.join(extra)}")
645645
return LambdaCollection([top, c.Block(body)])
646646

647+
def visit_Callback(self, o, nested_call=False):
648+
return CallbackArg(o)
649+
647650
def visit_HaloSpot(self, o):
648651
body = flatten(self._visit(i) for i in o.children)
649652
return c.Collection(body)
@@ -1469,3 +1472,12 @@ def sorted_efuncs(efuncs):
14691472
CommCallable: 1
14701473
}
14711474
return sorted_priority(efuncs, priority)
1475+
1476+
1477+
class CallbackArg(c.Generable):
1478+
1479+
def __init__(self, callback):
1480+
self.callback = callback
1481+
1482+
def generate(self):
1483+
yield self.callback.callback_form

devito/operator/operator.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@
3535
from devito.types import (Buffer, Evaluable, host_layer, device_layer,
3636
disk_layer)
3737
from devito.types.dimension import Thickness
38-
38+
from devito.petsc.iet.passes import lower_petsc
39+
from devito.petsc.clusters import petsc_preprocess
3940

4041
__all__ = ['Operator']
4142

@@ -264,6 +265,9 @@ def _lower(cls, expressions, **kwargs):
264265
kwargs.setdefault('langbb', cls._Target.langbb())
265266
kwargs.setdefault('printer', cls._Target.Printer)
266267

268+
# TODO: To be updated based on changes in #2509
269+
kwargs.setdefault('concretize_mapper', {})
270+
267271
expressions = as_tuple(expressions)
268272

269273
# Enable recursive lowering
@@ -381,6 +385,9 @@ def _lower_clusters(cls, expressions, profiler=None, **kwargs):
381385
# Build a sequence of Clusters from a sequence of Eqs
382386
clusters = clusterize(expressions, **kwargs)
383387

388+
# Preprocess clusters for PETSc lowering
389+
clusters = petsc_preprocess(clusters)
390+
384391
# Operation count before specialization
385392
init_ops = sum(estimate_cost(c.exprs) for c in clusters if c.is_dense)
386393

@@ -478,6 +485,9 @@ def _lower_iet(cls, uiet, profiler=None, **kwargs):
478485

479486
# Lower IET to a target-specific IET
480487
graph = Graph(iet, **kwargs)
488+
489+
lower_petsc(graph, **kwargs)
490+
481491
graph = cls._specialize_iet(graph, **kwargs)
482492

483493
# Instrument the IET for C-level profiling
@@ -512,7 +522,7 @@ def dimensions(self):
512522

513523
# During compilation other Dimensions may have been produced
514524
dimensions = FindSymbols('dimensions').visit(self)
515-
ret.update(d for d in dimensions if d.is_PerfKnob)
525+
ret.update(dimensions)
516526

517527
ret = tuple(sorted(ret, key=attrgetter('name')))
518528

0 commit comments

Comments
 (0)