Skip to content

Commit f6d719b

Browse files
committed
tests: Update Re and Im tests to be more comprehensive
1 parent c864338 commit f6d719b

4 files changed

Lines changed: 33 additions & 36 deletions

File tree

devito/ir/equations/equation.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,9 @@ def directions(self):
4949
@property
5050
def dtype(self):
5151
try:
52-
rhs_dtype = self.rhs.dtype
52+
return infer_dtype({self.lhs.dtype, self.rhs.dtype} - {None})
5353
except AttributeError:
54-
rhs_dtype = None
55-
56-
return infer_dtype({self.lhs.dtype, rhs_dtype} - {None})
54+
return self.lhs.dtype
5755

5856
@property
5957
def state(self):

devito/passes/iet/languages/C.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,9 @@ def _print_ListInitializer(self, expr):
6969
return li
7070

7171
def _print_Re(self, expr):
72-
"""Print an Re as an access into the second entry of a float array."""
7372
return (f'{self.func_prefix(expr)}real{self.func_literal(expr).lower()}'
7473
f'({self._print(expr.args[0])})')
7574

7675
def _print_Im(self, expr):
77-
"""Print an Im as an access into the second entry of a float array."""
7876
return (f'{self.func_prefix(expr)}imag{self.func_literal(expr).lower()}'
7977
f'({self._print(expr.args[0])})')

devito/passes/iet/languages/CXX.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,6 @@ class CXXPrinter(BasePrinter, CXX11CodePrinter):
106106

107107
def _print_ImaginaryUnit(self, expr):
108108
return f'1i{self.prec_literal(expr).lower()}'
109-
# return '1i'
110109

111110
def _print_Re(self, expr):
112111
return f'{self._ns}real({self._print(expr.args[0])})'

tests/test_symbolics.py

Lines changed: 31 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -915,7 +915,6 @@ def test_print_div():
915915

916916

917917
class TestComplexParts:
918-
# TODO: Add a cxx switchconfig
919918
def setup_basic(self, dtype):
920919
grid = Grid(shape=(5,), extent=(4.,))
921920
f = Function(name='f', grid=grid, dtype=dtype)
@@ -925,87 +924,90 @@ def setup_basic(self, dtype):
925924
f_imag = Function(name='f_imag', grid=grid)
926925
return f, f_real, f_imag
927926

928-
def run_operator(self, eqs, cxx):
929-
if cxx:
930-
with switchconfig(language='CXX'):
931-
Operator(eqs)()
932-
else:
927+
def run_operator(self, eqs, language):
928+
with switchconfig(language=language):
933929
Operator(eqs)()
934930

935-
@pytest.mark.parametrize('cxx', [False, True])
936-
def test_printing(self, cxx):
931+
def test_devito_print(self):
932+
f, _, _ = self.setup_basic(np.complex64)
933+
934+
assert str(Re(f)) == 'Re(f(x))'
935+
assert str(Im(f)) == 'Im(f(x))'
936+
937+
@pytest.mark.parametrize('language', ['C', 'CXX', 'CXXopenmp'])
938+
def test_printing(self, language):
937939
f, f_real, f_imag = self.setup_basic(np.complex64)
938940

939941
eq_re = Eq(f_real, Re(f))
940942
eq_im = Eq(f_imag, Im(f))
941943

942-
if cxx:
943-
with switchconfig(language='CXX'):
944-
op = Operator([eq_re, eq_im])
945-
assert "f_real[x + 1] = std::real(f[x + 1])" in str(op.ccode)
946-
assert "f_imag[x + 1] = std::imag(f[x + 1])" in str(op.ccode)
944+
with switchconfig(language=language):
945+
op = Operator([eq_re, eq_im])
946+
947+
if language in ('CXX', 'CXXopenmp'):
948+
assert "f_real[x + 1] = std::real(f[x + 1])" in str(op.ccode)
949+
assert "f_imag[x + 1] = std::imag(f[x + 1])" in str(op.ccode)
947950

948951
else:
949-
op = Operator([eq_re, eq_im])
950952
assert "f_real[x + 1] = crealf(f[x + 1])" in str(op.ccode)
951953
assert "f_imag[x + 1] = cimagf(f[x + 1])" in str(op.ccode)
952954

953-
@pytest.mark.parametrize('cxx', [False, True])
955+
@pytest.mark.parametrize('language', ['C', 'CXX', 'CXXopenmp'])
954956
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
955-
def test_trivial(self, cxx, dtype):
957+
def test_trivial(self, language, dtype):
956958
f, f_real, f_imag = self.setup_basic(dtype)
957959

958960
eq_re = Eq(f_real, Re(f+1.))
959961
eq_im = Eq(f_imag, Im(f+1.))
960962

961-
self.run_operator([eq_re, eq_im], cxx)
963+
self.run_operator([eq_re, eq_im], language)
962964

963965
rcheck = np.array([2., 3., 4., 5., 6.])
964966
icheck = np.array([12., 11., 10., 9., 8.])
965967
assert np.all(np.isclose(f_real.data, rcheck))
966968
assert np.all(np.isclose(f_imag.data, icheck))
967969

968-
@pytest.mark.parametrize('cxx', [False, True])
970+
@pytest.mark.parametrize('language', ['C', 'CXX', 'CXXopenmp'])
969971
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
970-
def test_trivial_imag(self, cxx, dtype):
972+
def test_trivial_imag(self, language, dtype):
971973
f, f_real, f_imag = self.setup_basic(dtype)
972974

973975
eq_re = Eq(f_real, Re(f+1j))
974976
eq_im = Eq(f_imag, Im(f+1j))
975977

976-
self.run_operator([eq_re, eq_im], cxx)
978+
self.run_operator([eq_re, eq_im], language)
977979

978980
rcheck = np.array([1., 2., 3., 4., 5.])
979981
icheck = np.array([13., 12., 11., 10., 9.])
980982
assert np.all(np.isclose(f_real.data, rcheck))
981983
assert np.all(np.isclose(f_imag.data, icheck))
982984

983-
@pytest.mark.parametrize('cxx', [False, True])
984-
def test_deriv(self, cxx):
985+
@pytest.mark.parametrize('language', ['C', 'CXX', 'CXXopenmp'])
986+
def test_deriv(self, language):
985987
f, f_real, f_imag = self.setup_basic(np.complex64)
986988

987989
eq_re = Eq(f_real, Re(f.dx))
988990
eq_im = Eq(f_imag, Im(f.dx))
989991

990-
self.run_operator([eq_re, eq_im], cxx)
992+
self.run_operator([eq_re, eq_im], language)
991993

992994
assert np.all(np.isclose(f_real.data, 1.))
993995
assert np.all(np.isclose(f_imag.data, -1.))
994996

995-
@pytest.mark.parametrize('cxx', [False, True])
996-
def test_outer_deriv(self, cxx):
997+
@pytest.mark.parametrize('language', ['C', 'CXX', 'CXXopenmp'])
998+
def test_outer_deriv(self, language):
997999
f, f_real, f_imag = self.setup_basic(np.complex64)
9981000

9991001
eq_re = Eq(f_real, Re(f).dx)
10001002
eq_im = Eq(f_imag, Im(f).dx)
10011003

1002-
self.run_operator([eq_re, eq_im], cxx)
1004+
self.run_operator([eq_re, eq_im], language)
10031005

10041006
assert np.all(np.isclose(f_real.data, 1.))
10051007
assert np.all(np.isclose(f_imag.data, -1.))
10061008

1007-
@pytest.mark.parametrize('cxx', [False, True])
1008-
def test_mul(self, cxx):
1009+
@pytest.mark.parametrize('language', ['C', 'CXX', 'CXXopenmp'])
1010+
def test_mul(self, language):
10091011
grid = Grid(shape=(5,))
10101012

10111013
f = Function(name='f', grid=grid, dtype=np.complex64)
@@ -1025,7 +1027,7 @@ def test_mul(self, cxx):
10251027
eq_fh_re = Eq(fh_re, Re(f*h))
10261028
eq_fh_im = Eq(fh_im, Im(f*h))
10271029

1028-
self.run_operator([eq_fg_re, eq_fg_im, eq_fh_re, eq_fh_im], cxx)
1030+
self.run_operator([eq_fg_re, eq_fg_im, eq_fh_re, eq_fh_im], language)
10291031

10301032
assert np.all(np.isclose(fg_re.data, 2.))
10311033
assert np.all(np.isclose(fg_im.data, 2.))

0 commit comments

Comments
 (0)