Skip to content

Commit d01bdd6

Browse files
EdCauntmloubout
authored andcommitted
tests: Update Re and Im tests to be more comprehensive
1 parent fa4245b commit d01bdd6

4 files changed

Lines changed: 33 additions & 36 deletions

File tree

devito/ir/equations/equation.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,9 @@ def directions(self):
4949
@property
5050
def dtype(self):
5151
try:
52-
rhs_dtype = self.rhs.dtype
52+
return infer_dtype({self.lhs.dtype, self.rhs.dtype} - {None})
5353
except AttributeError:
54-
rhs_dtype = None
55-
56-
return infer_dtype({self.lhs.dtype, rhs_dtype} - {None})
54+
return self.lhs.dtype
5755

5856
@property
5957
def state(self):

devito/passes/iet/languages/C.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,9 @@ def _print_ListInitializer(self, expr):
6969
return li
7070

7171
def _print_Re(self, expr):
72-
"""Print an Re as an access into the second entry of a float array."""
7372
return (f'{self.func_prefix(expr)}real{self.func_literal(expr).lower()}'
7473
f'({self._print(expr.args[0])})')
7574

7675
def _print_Im(self, expr):
77-
"""Print an Im as an access into the second entry of a float array."""
7876
return (f'{self.func_prefix(expr)}imag{self.func_literal(expr).lower()}'
7977
f'({self._print(expr.args[0])})')

devito/passes/iet/languages/CXX.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,6 @@ class CXXPrinter(BasePrinter, CXX11CodePrinter):
106106

107107
def _print_ImaginaryUnit(self, expr):
108108
return f'1i{self.prec_literal(expr).lower()}'
109-
# return '1i'
110109

111110
def _print_Re(self, expr):
112111
return f'{self._ns}real({self._print(expr.args[0])})'

tests/test_symbolics.py

Lines changed: 31 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -927,7 +927,6 @@ def test_customdtype_complex():
927927

928928

929929
class TestComplexParts:
930-
# TODO: Add a cxx switchconfig
931930
def setup_basic(self, dtype):
932931
grid = Grid(shape=(5,), extent=(4.,))
933932
f = Function(name='f', grid=grid, dtype=dtype)
@@ -937,87 +936,90 @@ def setup_basic(self, dtype):
937936
f_imag = Function(name='f_imag', grid=grid)
938937
return f, f_real, f_imag
939938

940-
def run_operator(self, eqs, cxx):
941-
if cxx:
942-
with switchconfig(language='CXX'):
943-
Operator(eqs)()
944-
else:
939+
def run_operator(self, eqs, language):
940+
with switchconfig(language=language):
945941
Operator(eqs)()
946942

947-
@pytest.mark.parametrize('cxx', [False, True])
948-
def test_printing(self, cxx):
943+
def test_devito_print(self):
944+
f, _, _ = self.setup_basic(np.complex64)
945+
946+
assert str(Re(f)) == 'Re(f(x))'
947+
assert str(Im(f)) == 'Im(f(x))'
948+
949+
@pytest.mark.parametrize('language', ['C', 'CXX', 'CXXopenmp'])
950+
def test_printing(self, language):
949951
f, f_real, f_imag = self.setup_basic(np.complex64)
950952

951953
eq_re = Eq(f_real, Re(f))
952954
eq_im = Eq(f_imag, Im(f))
953955

954-
if cxx:
955-
with switchconfig(language='CXX'):
956-
op = Operator([eq_re, eq_im])
957-
assert "f_real[x + 1] = std::real(f[x + 1])" in str(op.ccode)
958-
assert "f_imag[x + 1] = std::imag(f[x + 1])" in str(op.ccode)
956+
with switchconfig(language=language):
957+
op = Operator([eq_re, eq_im])
958+
959+
if language in ('CXX', 'CXXopenmp'):
960+
assert "f_real[x + 1] = std::real(f[x + 1])" in str(op.ccode)
961+
assert "f_imag[x + 1] = std::imag(f[x + 1])" in str(op.ccode)
959962

960963
else:
961-
op = Operator([eq_re, eq_im])
962964
assert "f_real[x + 1] = crealf(f[x + 1])" in str(op.ccode)
963965
assert "f_imag[x + 1] = cimagf(f[x + 1])" in str(op.ccode)
964966

965-
@pytest.mark.parametrize('cxx', [False, True])
967+
@pytest.mark.parametrize('language', ['C', 'CXX', 'CXXopenmp'])
966968
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
967-
def test_trivial(self, cxx, dtype):
969+
def test_trivial(self, language, dtype):
968970
f, f_real, f_imag = self.setup_basic(dtype)
969971

970972
eq_re = Eq(f_real, Re(f+1.))
971973
eq_im = Eq(f_imag, Im(f+1.))
972974

973-
self.run_operator([eq_re, eq_im], cxx)
975+
self.run_operator([eq_re, eq_im], language)
974976

975977
rcheck = np.array([2., 3., 4., 5., 6.])
976978
icheck = np.array([12., 11., 10., 9., 8.])
977979
assert np.all(np.isclose(f_real.data, rcheck))
978980
assert np.all(np.isclose(f_imag.data, icheck))
979981

980-
@pytest.mark.parametrize('cxx', [False, True])
982+
@pytest.mark.parametrize('language', ['C', 'CXX', 'CXXopenmp'])
981983
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
982-
def test_trivial_imag(self, cxx, dtype):
984+
def test_trivial_imag(self, language, dtype):
983985
f, f_real, f_imag = self.setup_basic(dtype)
984986

985987
eq_re = Eq(f_real, Re(f+1j))
986988
eq_im = Eq(f_imag, Im(f+1j))
987989

988-
self.run_operator([eq_re, eq_im], cxx)
990+
self.run_operator([eq_re, eq_im], language)
989991

990992
rcheck = np.array([1., 2., 3., 4., 5.])
991993
icheck = np.array([13., 12., 11., 10., 9.])
992994
assert np.all(np.isclose(f_real.data, rcheck))
993995
assert np.all(np.isclose(f_imag.data, icheck))
994996

995-
@pytest.mark.parametrize('cxx', [False, True])
996-
def test_deriv(self, cxx):
997+
@pytest.mark.parametrize('language', ['C', 'CXX', 'CXXopenmp'])
998+
def test_deriv(self, language):
997999
f, f_real, f_imag = self.setup_basic(np.complex64)
9981000

9991001
eq_re = Eq(f_real, Re(f.dx))
10001002
eq_im = Eq(f_imag, Im(f.dx))
10011003

1002-
self.run_operator([eq_re, eq_im], cxx)
1004+
self.run_operator([eq_re, eq_im], language)
10031005

10041006
assert np.all(np.isclose(f_real.data, 1.))
10051007
assert np.all(np.isclose(f_imag.data, -1.))
10061008

1007-
@pytest.mark.parametrize('cxx', [False, True])
1008-
def test_outer_deriv(self, cxx):
1009+
@pytest.mark.parametrize('language', ['C', 'CXX', 'CXXopenmp'])
1010+
def test_outer_deriv(self, language):
10091011
f, f_real, f_imag = self.setup_basic(np.complex64)
10101012

10111013
eq_re = Eq(f_real, Re(f).dx)
10121014
eq_im = Eq(f_imag, Im(f).dx)
10131015

1014-
self.run_operator([eq_re, eq_im], cxx)
1016+
self.run_operator([eq_re, eq_im], language)
10151017

10161018
assert np.all(np.isclose(f_real.data, 1.))
10171019
assert np.all(np.isclose(f_imag.data, -1.))
10181020

1019-
@pytest.mark.parametrize('cxx', [False, True])
1020-
def test_mul(self, cxx):
1021+
@pytest.mark.parametrize('language', ['C', 'CXX', 'CXXopenmp'])
1022+
def test_mul(self, language):
10211023
grid = Grid(shape=(5,))
10221024

10231025
f = Function(name='f', grid=grid, dtype=np.complex64)
@@ -1037,7 +1039,7 @@ def test_mul(self, cxx):
10371039
eq_fh_re = Eq(fh_re, Re(f*h))
10381040
eq_fh_im = Eq(fh_im, Im(f*h))
10391041

1040-
self.run_operator([eq_fg_re, eq_fg_im, eq_fh_re, eq_fh_im], cxx)
1042+
self.run_operator([eq_fg_re, eq_fg_im, eq_fh_re, eq_fh_im], language)
10411043

10421044
assert np.all(np.isclose(fg_re.data, 2.))
10431045
assert np.all(np.isclose(fg_im.data, 2.))

0 commit comments

Comments
 (0)