Skip to content

Commit 7489046

Browse files
EdCauntmloubout
authored andcommitted
compiler: Improve ComplexPart printing, adjust CI for coverage
1 parent dd79516 commit 7489046

5 files changed

Lines changed: 32 additions & 53 deletions

File tree

.github/workflows/pytest-core-nompi.yml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ jobs:
3030

3131
matrix:
3232
name: [
33-
pytest-ubuntu-py311-gcc11-noomp,
34-
pytest-ubuntu-py312-gcc12-omp,
33+
pytest-ubuntu-py311-gcc11-cxxnoomp,
34+
pytest-ubuntu-py312-gcc12-cxxomp,
3535
pytest-ubuntu-py39-gcc14-omp,
3636
pytest-ubuntu-py310-gcc10-noomp,
3737
pytest-ubuntu-py312-gcc13-omp,
@@ -42,18 +42,18 @@ jobs:
4242
]
4343
set: [base, adjoint]
4444
include:
45-
- name: pytest-ubuntu-py311-gcc11-noomp
45+
- name: pytest-ubuntu-py311-gcc11-cxxnoomp
4646
python-version: '3.11'
4747
os: ubuntu-22.04
4848
arch: "gcc-11"
49-
language: "C"
49+
language: "CXX"
5050
sympy: "1.11"
5151

52-
- name: pytest-ubuntu-py312-gcc12-omp
52+
- name: pytest-ubuntu-py312-gcc12-cxxomp
5353
python-version: '3.12'
5454
os: ubuntu-24.04
5555
arch: "gcc-12"
56-
language: "openmp"
56+
language: "CXXopenmp"
5757
sympy: "1.13"
5858

5959
- name: pytest-ubuntu-py39-gcc14-omp

devito/finite_differences/differentiable.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -646,6 +646,7 @@ def __str__(self):
646646

647647
class ComplexPart(Differentiable, sympy.core.function.Application):
648648
"""Abstract class for `Real`, `Imag`, or `Conj` of an expression"""
649+
_name = None
649650

650651
def __new__(cls, *args, **kwargs):
651652
if len(args) != 1:
@@ -669,17 +670,17 @@ def __str__(self):
669670

670671
class Real(ComplexPart):
671672
"""Get the real part of an expression"""
672-
pass
673+
_name = 'real'
673674

674675

675676
class Imag(ComplexPart):
676677
"""Get the imaginary part of an expression"""
677-
pass
678+
_name = 'imag'
678679

679680

680681
class Conj(ComplexPart):
681682
"""Get the complex conjugate of an expression"""
682-
pass
683+
_name = 'conj'
683684

684685

685686
class IndexSum(sympy.Expr, Evaluable):

devito/passes/iet/languages/C.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -68,14 +68,10 @@ def _print_ListInitializer(self, expr):
6868
else:
6969
return li
7070

71-
def _print_Real(self, expr):
72-
return (f'{self.func_prefix(expr)}real{self.func_literal(expr).lower()}'
73-
f'({self._print(expr.args[0])})')
74-
75-
def _print_Imag(self, expr):
76-
return (f'{self.func_prefix(expr)}imag{self.func_literal(expr).lower()}'
71+
def _print_ComplexPart(self, expr):
72+
return (f'{self.func_prefix(expr)}{expr._name}{self.func_literal(expr)}'
7773
f'({self._print(expr.args[0])})')
7874

7975
def _print_Conj(self, expr):
80-
return (f'conj{self.func_literal(expr).lower()}'
81-
f'({self._print(expr.args[0])})')
76+
# In C, conj is not preceeded by the func_prefix
77+
return (f'conj{self.func_literal(expr)}({self._print(expr.args[0])})')

devito/passes/iet/languages/CXX.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -107,14 +107,8 @@ class CXXPrinter(BasePrinter, CXX11CodePrinter):
107107
def _print_ImaginaryUnit(self, expr):
108108
return f'1i{self.prec_literal(expr).lower()}'
109109

110-
def _print_Real(self, expr):
111-
return f'{self._ns}real({self._print(expr.args[0])})'
112-
113-
def _print_Imag(self, expr):
114-
return f'{self._ns}imag({self._print(expr.args[0])})'
115-
116-
def _print_Conj(self, expr):
117-
return f'{self._ns}conj({self._print(expr.args[0])})'
110+
def _print_ComplexPart(self, expr):
111+
return f'{self._ns}{expr._name}({self._print(expr.args[0])})'
118112

119113
def _print_Cast(self, expr):
120114
# The CXX recommended way to cast is to use static_cast

tests/test_symbolics.py

Lines changed: 16 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from sympy import Expr, Number, Symbol
88
from devito import (Constant, Dimension, Grid, Function, solve, TimeFunction, Eq, # noqa
99
Operator, SubDimension, norm, Le, Ge, Gt, Lt, Abs, sin, cos,
10-
Min, Max, Real, Imag, Conj, SubDomain)
10+
Min, Max, Real, Imag, Conj, SubDomain, configuration)
1111
from devito.finite_differences.differentiable import SafeInv, Weights, Mul
1212
from devito.ir import Expression, FindNodes, ccode
1313
from devito.symbolics import (retrieve_functions, retrieve_indexed, evalrel, # noqa
@@ -936,90 +936,79 @@ def setup_basic(self, dtype):
936936
f_imag = Function(name='f_imag', grid=grid)
937937
return f, f_real, f_imag
938938

939-
def run_operator(self, eqs, language):
940-
with switchconfig(language=language):
941-
Operator(eqs)()
942-
943939
def test_devito_print(self):
944940
f, _, _ = self.setup_basic(np.complex64)
945941

946942
assert str(Real(f)) == 'Real(f(x))'
947943
assert str(Imag(f)) == 'Imag(f(x))'
948944

949-
@pytest.mark.parametrize('language', ['C', 'CXX', 'CXXopenmp'])
950-
def test_printing(self, language):
945+
def test_printing(self):
951946
f, f_real, f_imag = self.setup_basic(np.complex64)
952947

953948
eq_re = Eq(f_real, Real(f))
954949
eq_im = Eq(f_imag, Imag(f))
955950

956-
with switchconfig(language=language):
957-
op = Operator([eq_re, eq_im])
951+
op = Operator([eq_re, eq_im])
958952

959-
if language in ('CXX', 'CXXopenmp'):
953+
if configuration['language'] in ('CXX', 'CXXopenmp'):
960954
assert "f_real[x + 1] = std::real(f[x + 1])" in str(op.ccode)
961955
assert "f_imag[x + 1] = std::imag(f[x + 1])" in str(op.ccode)
962956

963957
else:
964958
assert "f_real[x + 1] = crealf(f[x + 1])" in str(op.ccode)
965959
assert "f_imag[x + 1] = cimagf(f[x + 1])" in str(op.ccode)
966960

967-
@pytest.mark.parametrize('language', ['C', 'CXX', 'CXXopenmp'])
968961
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
969-
def test_trivial(self, language, dtype):
962+
def test_trivial(self, dtype):
970963
f, f_real, f_imag = self.setup_basic(dtype)
971964

972965
eq_re = Eq(f_real, Real(f+1.))
973966
eq_im = Eq(f_imag, Imag(f+1.))
974967

975-
self.run_operator([eq_re, eq_im], language)
968+
Operator([eq_re, eq_im])()
976969

977970
rcheck = np.array([2., 3., 4., 5., 6.])
978971
icheck = np.array([12., 11., 10., 9., 8.])
979972
assert np.all(np.isclose(f_real.data, rcheck))
980973
assert np.all(np.isclose(f_imag.data, icheck))
981974

982-
@pytest.mark.parametrize('language', ['C', 'CXX', 'CXXopenmp'])
983975
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
984-
def test_trivial_imag(self, language, dtype):
976+
def test_trivial_imag(self, dtype):
985977
f, f_real, f_imag = self.setup_basic(dtype)
986978

987979
eq_re = Eq(f_real, Real(f+1j))
988980
eq_im = Eq(f_imag, Imag(f+1j))
989981

990-
self.run_operator([eq_re, eq_im], language)
982+
Operator([eq_re, eq_im])()
991983

992984
rcheck = np.array([1., 2., 3., 4., 5.])
993985
icheck = np.array([13., 12., 11., 10., 9.])
994986
assert np.all(np.isclose(f_real.data, rcheck))
995987
assert np.all(np.isclose(f_imag.data, icheck))
996988

997-
@pytest.mark.parametrize('language', ['C', 'CXX', 'CXXopenmp'])
998-
def test_deriv(self, language):
989+
def test_deriv(self):
999990
f, f_real, f_imag = self.setup_basic(np.complex64)
1000991

1001992
eq_re = Eq(f_real, Real(f.dx))
1002993
eq_im = Eq(f_imag, Imag(f.dx))
1003994

1004-
self.run_operator([eq_re, eq_im], language)
995+
Operator([eq_re, eq_im])()
1005996

1006997
assert np.all(np.isclose(f_real.data, 1.))
1007998
assert np.all(np.isclose(f_imag.data, -1.))
1008999

1009-
@pytest.mark.parametrize('language', ['C', 'CXX', 'CXXopenmp'])
1010-
def test_outer_deriv(self, language):
1000+
def test_outer_deriv(self):
10111001
f, f_real, f_imag = self.setup_basic(np.complex64)
10121002

10131003
eq_re = Eq(f_real, Real(f).dx)
10141004
eq_im = Eq(f_imag, Imag(f).dx)
10151005

1016-
self.run_operator([eq_re, eq_im], language)
1006+
Operator([eq_re, eq_im])()
10171007

10181008
assert np.all(np.isclose(f_real.data, 1.))
10191009
assert np.all(np.isclose(f_imag.data, -1.))
10201010

1021-
@pytest.mark.parametrize('language', ['C', 'CXX', 'CXXopenmp'])
1022-
def test_mul(self, language):
1011+
def test_mul(self):
10231012
grid = Grid(shape=(5,))
10241013

10251014
f = Function(name='f', grid=grid, dtype=np.complex64)
@@ -1039,22 +1028,21 @@ def test_mul(self, language):
10391028
eq_fh_re = Eq(fh_re, Real(f*h))
10401029
eq_fh_im = Eq(fh_im, Imag(f*h))
10411030

1042-
self.run_operator([eq_fg_re, eq_fg_im, eq_fh_re, eq_fh_im], language)
1031+
Operator([eq_fg_re, eq_fg_im, eq_fh_re, eq_fh_im])()
10431032

10441033
assert np.all(np.isclose(fg_re.data, 2.))
10451034
assert np.all(np.isclose(fg_im.data, 2.))
10461035

10471036
assert np.all(np.isclose(fh_re.data, -2.))
10481037
assert np.all(np.isclose(fh_im.data, 2.))
10491038

1050-
@pytest.mark.parametrize('language', ['C', 'CXX', 'CXXopenmp'])
1051-
def test_conj(self, language):
1039+
def test_conj(self):
10521040
grid = Grid(shape=(5,))
10531041
f = Function(name='f', grid=grid, dtype=np.complex64)
10541042
g = Function(name='g', grid=grid, dtype=np.complex64)
10551043

10561044
f.data[:] = np.arange(5) + 1j*np.arange(5)[::-1]
10571045

1058-
self.run_operator([Eq(g, Conj(f))], language)
1046+
Operator([Eq(g, Conj(f))])()
10591047

10601048
assert np.all(np.isclose(g.data, np.conj(f.data)))

0 commit comments

Comments
 (0)