Skip to content

Commit 5c3be36

Browse files
committed
compiler: Remove macro generation and start work on CXX
1 parent bde894a commit 5c3be36

6 files changed

Lines changed: 56 additions & 48 deletions

File tree

devito/finite_differences/differentiable.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -658,7 +658,7 @@ def __new__(cls, *args, **kwargs):
658658

659659
if not np.issubdtype(new_args[0].dtype, np.complexfloating):
660660
raise ValueError(f"{cls.__name__} requires a complex dtype,"
661-
f" not {new_args[0].dtype}.")
661+
f" not {new_args[0].dtype.__name__}.")
662662

663663
return super().__new__(cls, *new_args, **kwargs)
664664

devito/ir/cgen/printer.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -217,14 +217,6 @@ def _print_SafeInv(self, expr):
217217
val = self._print(expr.val)
218218
return f'SAFEINV({val}, {base})'
219219

220-
def _print_Re(self, expr):
221-
"""Print an Re as an access into the second entry of a float array."""
222-
return f"Re({self._print(expr.args[0])})"
223-
224-
def _print_Im(self, expr):
225-
"""Print an Im as an access into the second entry of a float array."""
226-
return f"Im({self._print(expr.args[0])})"
227-
228220
def _print_Mod(self, expr):
229221
"""Print a Mod as a C-like %-based operation."""
230222
args = [f'({self._print(a)})' for a in expr.args]

devito/passes/iet/languages/C.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,13 @@ class CPrinter(BasePrinter, C99CodePrinter):
5555

5656
def _print_ImaginaryUnit(self, expr):
5757
return '_Complex_I'
58+
59+
def _print_Re(self, expr):
60+
"""Print an Re as an access into the second entry of a float array."""
61+
return (f'{self.func_prefix(expr)}real{self.func_literal(expr).lower()}'
62+
f'({self._print(expr.args[0])})')
63+
64+
def _print_Im(self, expr):
65+
"""Print an Im as an access into the second entry of a float array."""
66+
return (f'{self.func_prefix(expr)}imag{self.func_literal(expr).lower()}'
67+
f'({self._print(expr.args[0])})')

devito/passes/iet/languages/CXX.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,13 @@ class CXXPrinter(BasePrinter, CXX11CodePrinter):
103103

104104
def _print_ImaginaryUnit(self, expr):
105105
return f'1i{self.prec_literal(expr).lower()}'
106+
# return '1i'
107+
108+
def _print_Re(self, expr):
109+
return f'{self._ns}real({self._print(expr.args[0])})'
110+
111+
def _print_Im(self, expr):
112+
return f'{self._ns}imag({self._print(expr.args[0])})'
106113

107114
def _print_Cast(self, expr):
108115
# The CXX recommended way to cast is to use static_cast

devito/passes/iet/misc.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import sympy
66

77
from devito.finite_differences import Max, Min
8-
from devito.finite_differences.differentiable import SafeInv, Re, Im
8+
from devito.finite_differences.differentiable import SafeInv
99
from devito.logger import warning
1010
from devito.ir import (Any, Forward, DummyExpr, Iteration, List, Prodder,
1111
FindApplications, FindNodes, FindSymbols, Transformer,
@@ -245,22 +245,6 @@ def _(expr, langbb):
245245
f'(((a) < {eps}F || ({b}) < {eps}F) ? (0.0F) : ((1.0F) / (a)))'),), {}
246246

247247

248-
@_lower_macro_math.register(Re)
249-
def _(expr, langbb):
250-
return (('Re(x)',
251-
'(_Generic((x), '
252-
'float _Complex : (float *) &(x), '
253-
'double _Complex : (double *) &(x))[0])'),), {}
254-
255-
256-
@_lower_macro_math.register(Im)
257-
def _(expr, langbb):
258-
return (('Im(x)',
259-
'(_Generic((x), '
260-
'float _Complex : (float *) &(x), '
261-
'double _Complex : (double *) &(x))[1])'),), {}
262-
263-
264248
@iet_pass
265249
def minimize_symbols(iet):
266250
"""

tests/test_symbolics.py

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from sympy import Expr, Symbol
88
from devito import (Constant, Dimension, Grid, Function, solve, TimeFunction, Eq, # noqa
99
Operator, SubDimension, norm, Le, Ge, Gt, Lt, Abs, sin, cos,
10-
Min, Max, Re, Im)
10+
Min, Max, Re, Im, switchconfig)
1111
from devito.finite_differences.differentiable import SafeInv, Weights
1212
from devito.ir import Expression, FindNodes, ccode
1313
from devito.symbolics import (retrieve_functions, retrieve_indexed, evalrel, # noqa
@@ -821,6 +821,7 @@ def test_assumptions(self, op, expr, assumptions, expected):
821821

822822

823823
class TestComplexParts:
824+
# TODO: Add a cxx switchconfig
824825
def setup_basic(self, dtype):
825826
grid = Grid(shape=(5,), extent=(4.,))
826827
f = Function(name='f', grid=grid, dtype=dtype)
@@ -830,73 +831,87 @@ def setup_basic(self, dtype):
830831
f_imag = Function(name='f_imag', grid=grid)
831832
return f, f_real, f_imag
832833

833-
def test_printing(self):
834+
def run_operator(self, eqs, cxx):
835+
if cxx:
836+
with switchconfig(language='CXX'):
837+
Operator(eqs)()
838+
else:
839+
Operator(eqs)()
840+
841+
@pytest.mark.parametrize('cxx', [False, True])
842+
def test_printing(self, cxx):
834843
f, f_real, f_imag = self.setup_basic(np.complex64)
835844

836845
eq_re = Eq(f_real, Re(f))
837846
eq_im = Eq(f_imag, Im(f))
838847

839-
op = Operator([eq_re, eq_im])
840-
841-
assert ("#define Im(x) (_Generic((x), float _Complex : (float *) &(x), "
842-
"double _Complex : (double *) &(x))[1])") in str(op.ccode)
843-
assert ("#define Re(x) (_Generic((x), float _Complex : (float *) &(x), "
844-
"double _Complex : (double *) &(x))[0])") in str(op.ccode)
848+
if cxx:
849+
with switchconfig(language='CXX'):
850+
op = Operator([eq_re, eq_im])
851+
assert "f_real[x + 1] = std::real(f[x + 1])" in str(op.ccode)
852+
assert "f_imag[x + 1] = std::imag(f[x + 1])" in str(op.ccode)
845853

846-
assert "f_real[x + 1] = Re(f[x + 1])" in str(op.ccode)
847-
assert "f_imag[x + 1] = Im(f[x + 1])" in str(op.ccode)
854+
else:
855+
op = Operator([eq_re, eq_im])
856+
assert "f_real[x + 1] = crealf(f[x + 1])" in str(op.ccode)
857+
assert "f_imag[x + 1] = cimagf(f[x + 1])" in str(op.ccode)
848858

859+
@pytest.mark.parametrize('cxx', [False, True])
849860
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
850-
def test_trivial(self, dtype):
861+
def test_trivial(self, cxx, dtype):
851862
f, f_real, f_imag = self.setup_basic(dtype)
852863

853-
eq_re = Eq(f_real, Re(f+1))
854-
eq_im = Eq(f_imag, Im(f+1))
864+
eq_re = Eq(f_real, Re(f+1.))
865+
eq_im = Eq(f_imag, Im(f+1.))
855866

856-
Operator([eq_re, eq_im])()
867+
self.run_operator([eq_re, eq_im], cxx)
857868

858869
rcheck = np.array([2., 3., 4., 5., 6.])
859870
icheck = np.array([12., 11., 10., 9., 8.])
860871
assert np.all(np.isclose(f_real.data, rcheck))
861872
assert np.all(np.isclose(f_imag.data, icheck))
862873

874+
@pytest.mark.parametrize('cxx', [False, True])
863875
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
864-
def test_trivial_imag(self, dtype):
876+
def test_trivial_imag(self, cxx, dtype):
865877
f, f_real, f_imag = self.setup_basic(dtype)
866878

867879
eq_re = Eq(f_real, Re(f+1j))
868880
eq_im = Eq(f_imag, Im(f+1j))
869881

870-
Operator([eq_re, eq_im])()
882+
self.run_operator([eq_re, eq_im], cxx)
871883

872884
rcheck = np.array([1., 2., 3., 4., 5.])
873885
icheck = np.array([13., 12., 11., 10., 9.])
874886
assert np.all(np.isclose(f_real.data, rcheck))
875887
assert np.all(np.isclose(f_imag.data, icheck))
876888

877-
def test_deriv(self):
889+
@pytest.mark.parametrize('cxx', [False, True])
890+
def test_deriv(self, cxx):
878891
f, f_real, f_imag = self.setup_basic(np.complex64)
879892

880893
eq_re = Eq(f_real, Re(f.dx))
881894
eq_im = Eq(f_imag, Im(f.dx))
882895

883-
Operator([eq_re, eq_im])()
896+
self.run_operator([eq_re, eq_im], cxx)
884897

885898
assert np.all(np.isclose(f_real.data, 1.))
886899
assert np.all(np.isclose(f_imag.data, -1.))
887900

888-
def test_outer_deriv(self):
901+
@pytest.mark.parametrize('cxx', [False, True])
902+
def test_outer_deriv(self, cxx):
889903
f, f_real, f_imag = self.setup_basic(np.complex64)
890904

891905
eq_re = Eq(f_real, Re(f).dx)
892906
eq_im = Eq(f_imag, Im(f).dx)
893907

894-
Operator([eq_re, eq_im])()
908+
self.run_operator([eq_re, eq_im], cxx)
895909

896910
assert np.all(np.isclose(f_real.data, 1.))
897911
assert np.all(np.isclose(f_imag.data, -1.))
898912

899-
def test_mul(self):
913+
@pytest.mark.parametrize('cxx', [False, True])
914+
def test_mul(self, cxx):
900915
grid = Grid(shape=(5,))
901916

902917
f = Function(name='f', grid=grid, dtype=np.complex64)
@@ -916,7 +931,7 @@ def test_mul(self):
916931
eq_fh_re = Eq(fh_re, Re(f*h))
917932
eq_fh_im = Eq(fh_im, Im(f*h))
918933

919-
Operator([eq_fg_re, eq_fg_im, eq_fh_re, eq_fh_im])()
934+
self.run_operator([eq_fg_re, eq_fg_im, eq_fh_re, eq_fh_im], cxx)
920935

921936
assert np.all(np.isclose(fg_re.data, 2.))
922937
assert np.all(np.isclose(fg_im.data, 2.))

0 commit comments

Comments
 (0)