Skip to content

Commit 630ae40

Browse files
authored
Merge pull request #2564 from devitocodes/real-imag
dsl: Add Real, Imag, and Conj operators
2 parents a4d4bd4 + 296f0bf commit 630ae40

29 files changed

Lines changed: 400 additions & 98 deletions

.github/workflows/pytest-core-nompi.yml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ jobs:
3030

3131
matrix:
3232
name: [
33-
pytest-ubuntu-py311-gcc11-noomp,
34-
pytest-ubuntu-py312-gcc12-omp,
33+
pytest-ubuntu-py311-gcc11-cxxnoomp,
34+
pytest-ubuntu-py312-gcc12-cxxomp,
3535
pytest-ubuntu-py39-gcc14-omp,
3636
pytest-ubuntu-py310-gcc10-noomp,
3737
pytest-ubuntu-py312-gcc13-omp,
@@ -42,18 +42,18 @@ jobs:
4242
]
4343
set: [base, adjoint]
4444
include:
45-
- name: pytest-ubuntu-py311-gcc11-noomp
45+
- name: pytest-ubuntu-py311-gcc11-cxxnoomp
4646
python-version: '3.11'
4747
os: ubuntu-22.04
4848
arch: "gcc-11"
49-
language: "C"
49+
language: "CXX"
5050
sympy: "1.11"
5151

52-
- name: pytest-ubuntu-py312-gcc12-omp
52+
- name: pytest-ubuntu-py312-gcc12-cxxomp
5353
python-version: '3.12'
5454
os: ubuntu-24.04
5555
arch: "gcc-12"
56-
language: "openmp"
56+
language: "CXXopenmp"
5757
sympy: "1.13"
5858

5959
- name: pytest-ubuntu-py39-gcc14-omp

devito/core/__init__.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,64 +26,82 @@
2626
DeviceNoopOmpOperator, DeviceNoopAccOperator,
2727
DeviceAdvOmpOperator, DeviceAdvAccOperator,
2828
DeviceFsgOmpOperator, DeviceFsgAccOperator,
29-
DeviceCustomOmpOperator, DeviceCustomAccOperator
29+
DeviceCustomOmpOperator, DeviceCustomAccOperator,
30+
DeviceCustomCXXOmpOperator, DeviceNoopCXXOmpOperator,
31+
DeviceAdvCXXOmpOperator, DeviceFsgCXXOmpOperator
3032
)
3133
from devito.operator.registry import operator_registry
3234

3335
# Register CPU Operators
3436
operator_registry.add(Cpu64CustomOperator, Cpu64, 'custom', 'C')
3537
operator_registry.add(Cpu64CustomOperator, Cpu64, 'custom', 'openmp')
38+
operator_registry.add(Cpu64CustomOperator, Cpu64, 'custom', 'Copenmp')
3639
operator_registry.add(Cpu64CustomCXXOperator, Cpu64, 'custom', 'CXX')
3740
operator_registry.add(Cpu64CustomCXXOperator, Cpu64, 'custom', 'CXXopenmp')
3841

3942
operator_registry.add(Cpu64NoopCOperator, Cpu64, 'noop', 'C')
4043
operator_registry.add(Cpu64NoopOmpOperator, Cpu64, 'noop', 'openmp')
44+
operator_registry.add(Cpu64NoopOmpOperator, Cpu64, 'noop', 'Copenmp')
4145
operator_registry.add(Cpu64CXXNoopCOperator, Cpu64, 'noop', 'CXX')
4246
operator_registry.add(Cpu64CXXNoopOmpOperator, Cpu64, 'noop', 'CXXopenmp')
4347

4448
operator_registry.add(Cpu64AdvCOperator, Cpu64, 'advanced', 'C')
4549
operator_registry.add(Cpu64AdvOmpOperator, Cpu64, 'advanced', 'openmp')
50+
operator_registry.add(Cpu64AdvOmpOperator, Cpu64, 'advanced', 'Copenmp')
4651
operator_registry.add(Cpu64AdvCXXOperator, Cpu64, 'advanced', 'CXX')
4752
operator_registry.add(Cpu64AdvCXXOmpOperator, Cpu64, 'advanced', 'CXXopenmp')
4853

4954
operator_registry.add(Cpu64FsgCOperator, Cpu64, 'advanced-fsg', 'C')
5055
operator_registry.add(Cpu64FsgOmpOperator, Cpu64, 'advanced-fsg', 'openmp')
56+
operator_registry.add(Cpu64FsgOmpOperator, Cpu64, 'advanced-fsg', 'Copenmp')
5157
operator_registry.add(Cpu64FsgCXXOperator, Cpu64, 'advanced-fsg', 'CXX')
5258
operator_registry.add(Cpu64FsgCXXOmpOperator, Cpu64, 'advanced-fsg', 'CXXopenmp')
5359

5460
operator_registry.add(Intel64AdvCOperator, Intel64, 'advanced', 'C')
5561
operator_registry.add(Intel64AdvOmpOperator, Intel64, 'advanced', 'openmp')
62+
operator_registry.add(Intel64AdvOmpOperator, Intel64, 'advanced', 'Copenmp')
5663
operator_registry.add(Intel64CXXAdvCOperator, Intel64, 'advanced', 'CXX')
5764
operator_registry.add(Intel64AdvCXXOmpOperator, Intel64, 'advanced', 'CXXopenmp')
5865

5966
operator_registry.add(Intel64FsgCOperator, Intel64, 'advanced-fsg', 'C')
6067
operator_registry.add(Intel64FsgOmpOperator, Intel64, 'advanced-fsg', 'openmp')
68+
operator_registry.add(Intel64FsgOmpOperator, Intel64, 'advanced-fsg', 'Copenmp')
6169
operator_registry.add(Intel64FsgCXXOperator, Intel64, 'advanced-fsg', 'CXX')
6270
operator_registry.add(Intel64FsgCXXOmpOperator, Intel64, 'advanced-fsg', 'CXXopenmp')
6371

6472
operator_registry.add(ArmAdvCOperator, Arm, 'advanced', 'C')
6573
operator_registry.add(ArmAdvOmpOperator, Arm, 'advanced', 'openmp')
74+
operator_registry.add(ArmAdvOmpOperator, Arm, 'advanced', 'Copenmp')
6675
operator_registry.add(ArmAdvCXXOperator, Arm, 'advanced', 'CXX')
6776
operator_registry.add(ArmAdvCXXOmpOperator, Arm, 'advanced', 'CXXopenmp')
6877

6978
operator_registry.add(PowerAdvCOperator, Power, 'advanced', 'C')
7079
operator_registry.add(PowerAdvOmpOperator, Power, 'advanced', 'openmp')
80+
operator_registry.add(PowerAdvOmpOperator, Power, 'advanced', 'Copenmp')
7181
operator_registry.add(PowerCXXAdvCOperator, Power, 'advanced', 'CXX')
7282
operator_registry.add(PowerAdvCXXOmpOperator, Power, 'advanced', 'CXXopenmp')
7383

7484
# Register Device Operators
7585
operator_registry.add(DeviceCustomOmpOperator, Device, 'custom', 'C')
7686
operator_registry.add(DeviceCustomOmpOperator, Device, 'custom', 'openmp')
87+
operator_registry.add(DeviceCustomCXXOmpOperator, Device, 'custom', 'CXX')
88+
operator_registry.add(DeviceCustomCXXOmpOperator, Device, 'custom', 'CXXopenmp')
7789
operator_registry.add(DeviceCustomAccOperator, Device, 'custom', 'openacc')
7890

7991
operator_registry.add(DeviceNoopOmpOperator, Device, 'noop', 'C')
8092
operator_registry.add(DeviceNoopOmpOperator, Device, 'noop', 'openmp')
93+
operator_registry.add(DeviceNoopCXXOmpOperator, Device, 'noop', 'CXX')
94+
operator_registry.add(DeviceNoopCXXOmpOperator, Device, 'noop', 'CXXopenmp')
8195
operator_registry.add(DeviceNoopAccOperator, Device, 'noop', 'openacc')
8296

8397
operator_registry.add(DeviceAdvOmpOperator, Device, 'advanced', 'C')
8498
operator_registry.add(DeviceAdvOmpOperator, Device, 'advanced', 'openmp')
99+
operator_registry.add(DeviceAdvCXXOmpOperator, Device, 'advanced', 'CXX')
100+
operator_registry.add(DeviceAdvCXXOmpOperator, Device, 'advanced', 'CXXopenmp')
85101
operator_registry.add(DeviceAdvAccOperator, Device, 'advanced', 'openacc')
86102

87103
operator_registry.add(DeviceFsgOmpOperator, Device, 'advanced-fsg', 'C')
88104
operator_registry.add(DeviceFsgOmpOperator, Device, 'advanced-fsg', 'openmp')
105+
operator_registry.add(DeviceFsgCXXOmpOperator, Device, 'advanced-fsg', 'CXX')
106+
operator_registry.add(DeviceFsgCXXOmpOperator, Device, 'advanced-fsg', 'CXXopenmp')
89107
operator_registry.add(DeviceFsgAccOperator, Device, 'advanced-fsg', 'openacc')

devito/core/cpu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,7 @@ def _make_iet_passes_mapper(cls, **kwargs):
322322

323323
class Cpu64CustomCXXOperator(Cpu64CustomOperator):
324324

325-
_Target = CXXTarget
325+
_Target = CXXOmpTarget
326326
LINEARIZE = True
327327

328328
# Language level

devito/core/gpu.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,17 @@
1010
from devito.passes.clusters import (Lift, tasking, memcpy_prefetch, blocking,
1111
buffering, cire, cse, factorize, fission, fuse,
1212
optimize_pows)
13-
from devito.passes.iet import (DeviceOmpTarget, DeviceAccTarget, mpiize,
14-
hoist_prodders, linearize, pthreadify,
13+
from devito.passes.iet import (DeviceOmpTarget, DeviceAccTarget, DeviceCXXOmpTarget,
14+
mpiize, hoist_prodders, linearize, pthreadify,
1515
relax_incr_dimensions, check_stability)
1616
from devito.tools import as_tuple, timed_pass
1717

1818
__all__ = ['DeviceNoopOperator', 'DeviceAdvOperator', 'DeviceCustomOperator',
1919
'DeviceNoopOmpOperator', 'DeviceAdvOmpOperator', 'DeviceFsgOmpOperator',
2020
'DeviceCustomOmpOperator', 'DeviceNoopAccOperator', 'DeviceAdvAccOperator',
21-
'DeviceFsgAccOperator', 'DeviceCustomAccOperator']
21+
'DeviceFsgAccOperator', 'DeviceCustomAccOperator', 'DeviceNoopCXXOmpOperator',
22+
'DeviceAdvCXXOmpOperator', 'DeviceFsgCXXOmpOperator',
23+
'DeviceCustomCXXOmpOperator']
2224

2325

2426
class DeviceOperatorMixin:
@@ -364,14 +366,29 @@ class DeviceNoopOmpOperator(DeviceOmpOperatorMixin, DeviceNoopOperator):
364366
pass
365367

366368

369+
class DeviceNoopCXXOmpOperator(DeviceNoopOmpOperator):
370+
_Target = DeviceCXXOmpTarget
371+
LINEARIZE = True
372+
373+
367374
class DeviceAdvOmpOperator(DeviceOmpOperatorMixin, DeviceAdvOperator):
368375
pass
369376

370377

378+
class DeviceAdvCXXOmpOperator(DeviceAdvOmpOperator):
379+
_Target = DeviceCXXOmpTarget
380+
LINEARIZE = True
381+
382+
371383
class DeviceFsgOmpOperator(DeviceOmpOperatorMixin, DeviceFsgOperator):
372384
pass
373385

374386

387+
class DeviceFsgCXXOmpOperator(DeviceFsgOmpOperator):
388+
_Target = DeviceCXXOmpTarget
389+
LINEARIZE = True
390+
391+
375392
class DeviceCustomOmpOperator(DeviceOmpOperatorMixin, DeviceCustomOperator):
376393

377394
_known_passes = DeviceCustomOperator._known_passes + ('openmp',)
@@ -384,6 +401,11 @@ def _make_iet_passes_mapper(cls, **kwargs):
384401
return mapper
385402

386403

404+
class DeviceCustomCXXOmpOperator(DeviceCustomOmpOperator):
405+
_Target = DeviceCXXOmpTarget
406+
LINEARIZE = True
407+
408+
387409
# OpenACC
388410

389411
class DeviceAccOperatorMixin:

devito/finite_differences/differentiable.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,12 @@
1717
from devito.finite_differences.tools import make_shift_x0, coeff_priority
1818
from devito.logger import warning
1919
from devito.tools import (as_tuple, filter_ordered, flatten, frozendict,
20-
infer_dtype, is_integer, split, is_number)
20+
infer_dtype, extract_dtype, is_integer, split, is_number)
2121
from devito.types import Array, DimensionTuple, Evaluable, StencilDimension
2222
from devito.types.basic import AbstractFunction
2323

2424
__all__ = ['Differentiable', 'DiffDerivative', 'IndexDerivative', 'EvalDerivative',
25-
'Weights']
25+
'Weights', 'Real', 'Imag', 'Conj']
2626

2727

2828
class Differentiable(sympy.Expr, Evaluable):
@@ -644,6 +644,46 @@ def __str__(self):
644644
__repr__ = __str__
645645

646646

647+
class ComplexPart(Differentiable, sympy.core.function.Application):
648+
"""Abstract class for `Real`, `Imag`, or `Conj` of an expression"""
649+
_name = None
650+
651+
def __new__(cls, *args, **kwargs):
652+
if len(args) != 1:
653+
raise ValueError(f"{cls.__name__} expects exactly one arg;"
654+
f" {len(args)} were supplied instead.")
655+
656+
return super().__new__(cls, *args, **kwargs)
657+
658+
def __str__(self):
659+
return f"{self.__class__.__name__}({self.args[0]})"
660+
661+
__repr__ = __str__
662+
663+
664+
class RealComplexPart(ComplexPart):
665+
666+
@cached_property
667+
def dtype(self):
668+
dtype = extract_dtype(self)
669+
return dtype(0).real.__class__
670+
671+
672+
class Real(RealComplexPart):
673+
"""Get the real part of an expression"""
674+
_name = 'real'
675+
676+
677+
class Imag(RealComplexPart):
678+
"""Get the imaginary part of an expression"""
679+
_name = 'imag'
680+
681+
682+
class Conj(ComplexPart):
683+
"""Get the complex conjugate of an expression"""
684+
_name = 'conj'
685+
686+
647687
class IndexSum(sympy.Expr, Evaluable):
648688

649689
"""

devito/finite_differences/tools.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -268,11 +268,10 @@ def generate_indices(expr, dim, order, side=None, matvec=None, x0=None, nweights
268268
raise ValueError(f"More weights ({nweights}) provided than the maximum "
269269
f"stencil size ({order + 1}) for order {order} scheme")
270270
elif do > dw:
271+
order = nweights - nweights % 2
271272
warning(f"Less weights ({nweights}) provided than the stencil size"
272273
f"({order + 1}) for order {order} scheme."
273-
" Reducing order to {nweights//2}")
274-
order = nweights - nweights % 2
275-
274+
f" Reducing order to {order}")
276275
# Evaluation point
277276
x0 = sympify(((x0 or {}).get(dim) or expr.indices_ref[dim]))
278277

devito/ir/iet/visitors.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -423,7 +423,10 @@ def visit_PointerCast(self, o):
423423
if o.flat is None:
424424
shape = ''.join(f"[{self.ccode(i)}]" for i in o.castshape)
425425
rshape = f'(*){shape}'
426-
lvalue = c.Value(cstr, f'(*{self._restrict_keyword} {v}){shape}')
426+
if shape:
427+
lvalue = c.Value(cstr, f'(*{self._restrict_keyword} {v}){shape}')
428+
else:
429+
lvalue = c.Value(cstr, f'*{self._restrict_keyword} {v}')
427430
else:
428431
rshape = '*'
429432
lvalue = c.Value(cstr, f'*{v}')

devito/operator/operator.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1337,6 +1337,10 @@ def parse_kwargs(**kwargs):
13371337

13381338
if not opt or isinstance(opt, str):
13391339
mode, options = opt, {}
1340+
# Legacy Operator(..., opt='openmp', ...) support
1341+
if mode == 'openmp':
1342+
mode = 'noop'
1343+
options = {'openmp': True}
13401344
elif isinstance(opt, tuple):
13411345
if len(opt) == 0:
13421346
mode, options = 'noop', {}
@@ -1353,7 +1357,7 @@ def parse_kwargs(**kwargs):
13531357
# `opt`, deprecated kwargs
13541358
kwopenmp = kwargs.get('openmp', options.get('openmp'))
13551359
if kwopenmp is None:
1356-
openmp = kwargs.get('language', configuration['language']) == 'openmp'
1360+
openmp = 'openmp' in kwargs.get('language', configuration['language'])
13571361
else:
13581362
openmp = kwopenmp
13591363

@@ -1399,7 +1403,9 @@ def parse_kwargs(**kwargs):
13991403
kwargs['language'] = language
14001404
elif kwopenmp is not None:
14011405
# Handle deprecated `openmp` kwarg for backward compatibility
1402-
kwargs['language'] = 'openmp' if openmp else 'C'
1406+
omp = {'C': 'openmp', 'CXX': 'CXXopenmp'}.get(configuration['language'],
1407+
'openmp')
1408+
kwargs['language'] = omp if openmp else 'C'
14031409
else:
14041410
kwargs['language'] = configuration['language']
14051411

devito/passes/clusters/aliases.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1354,7 +1354,7 @@ def cost(self):
13541354
# Not just the sum for the individual items' cost! There might be
13551355
# redundancies, which we factor out here...
13561356
counter = generator()
1357-
make = lambda: Symbol(name='dummy%d' % counter(), dtype=np.float32)
1357+
make = lambda _: Symbol(name='dummy%d' % counter(), dtype=np.float32)
13581358

13591359
tot = 0
13601360
for v in as_mapper(self, lambda i: i.ispace).values():

devito/passes/clusters/cse.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from devito.ir import Cluster, Scope, cluster_pass
1515
from devito.symbolics import estimate_cost, q_leaf, q_terminal
1616
from devito.symbolics.manipulation import _uxreplace
17-
from devito.tools import DAG, as_list, as_tuple, frozendict
17+
from devito.tools import DAG, as_list, as_tuple, frozendict, extract_dtype
1818
from devito.types import Eq, Symbol, Temp
1919

2020
__all__ = ['cse']
@@ -78,7 +78,8 @@ def cse(cluster, sregistry=None, options=None, **kwargs):
7878
if cluster.is_fence:
7979
return cluster
8080

81-
make = lambda: CTemp(name=sregistry.make_name(), dtype=dtype)
81+
make_dtype = lambda e: np.promote_types(e.dtype, dtype).type
82+
make = lambda e: CTemp(name=sregistry.make_name(), dtype=make_dtype(e))
8283

8384
exprs = _cse(cluster, make, min_cost=min_cost, mode=mode)
8485

@@ -118,7 +119,7 @@ def _cse(maybe_exprs, make, min_cost=1, mode='basic'):
118119
exprs = maybe_exprs
119120
scope = Scope(maybe_exprs)
120121
else:
121-
exprs = [Eq(make(), e) for e in maybe_exprs]
122+
exprs = [Eq(make(e), e) for e in maybe_exprs]
122123
scope = Scope([])
123124

124125
# Some sub-expressions aren't really "common" -- that's the case of Dimension-
@@ -155,7 +156,7 @@ def _cse(maybe_exprs, make, min_cost=1, mode='basic'):
155156
candidates = [c for c in candidates if c.cost == cost]
156157

157158
# Apply replacements
158-
chosen = [(c, scheduled.get(c.key) or make()) for c in candidates]
159+
chosen = [(c, scheduled.get(c.key) or make(c)) for c in candidates]
159160
exprs = _inject(exprs, chosen, scheduled)
160161

161162
# Drop useless temporaries (e.g., r0=r1)
@@ -275,6 +276,10 @@ def __new__(cls, expr, conditionals=None, sources=()):
275276
def expr(self):
276277
return self[0]
277278

279+
@property
280+
def dtype(self):
281+
return extract_dtype(self.expr)
282+
278283
@property
279284
def conditionals(self):
280285
return self[1]

0 commit comments

Comments
 (0)