Skip to content

Commit 7d0f614

Browse files
committed
tests: Update Re and Im tests to be more comprehensive
1 parent 1ed5c1c commit 7d0f614

1 file changed

Lines changed: 25 additions & 28 deletions

File tree

tests/test_symbolics.py

Lines changed: 25 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -831,87 +831,84 @@ def setup_basic(self, dtype):
831831
f_imag = Function(name='f_imag', grid=grid)
832832
return f, f_real, f_imag
833833

834-
def run_operator(self, eqs, cxx):
835-
if cxx:
836-
with switchconfig(language='CXX'):
837-
Operator(eqs)()
838-
else:
834+
def run_operator(self, eqs, language):
835+
with switchconfig(language=language):
839836
Operator(eqs)()
840837

841-
@pytest.mark.parametrize('cxx', [False, True])
842-
def test_printing(self, cxx):
838+
@pytest.mark.parametrize('language', ['C', 'CXX', 'CXXopenmp'])
839+
def test_printing(self, language):
843840
f, f_real, f_imag = self.setup_basic(np.complex64)
844841

845842
eq_re = Eq(f_real, Re(f))
846843
eq_im = Eq(f_imag, Im(f))
847844

848-
if cxx:
849-
with switchconfig(language='CXX'):
850-
op = Operator([eq_re, eq_im])
851-
assert "f_real[x + 1] = std::real(f[x + 1])" in str(op.ccode)
852-
assert "f_imag[x + 1] = std::imag(f[x + 1])" in str(op.ccode)
845+
with switchconfig(language=language):
846+
op = Operator([eq_re, eq_im])
847+
848+
if language in ('CXX', 'CXXopenmp'):
849+
assert "f_real[x + 1] = std::real(f[x + 1])" in str(op.ccode)
850+
assert "f_imag[x + 1] = std::imag(f[x + 1])" in str(op.ccode)
853851

854852
else:
855-
op = Operator([eq_re, eq_im])
856853
assert "f_real[x + 1] = crealf(f[x + 1])" in str(op.ccode)
857854
assert "f_imag[x + 1] = cimagf(f[x + 1])" in str(op.ccode)
858855

859-
@pytest.mark.parametrize('cxx', [False, True])
856+
@pytest.mark.parametrize('language', ['C', 'CXX', 'CXXopenmp'])
860857
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
861-
def test_trivial(self, cxx, dtype):
858+
def test_trivial(self, language, dtype):
862859
f, f_real, f_imag = self.setup_basic(dtype)
863860

864861
eq_re = Eq(f_real, Re(f+1.))
865862
eq_im = Eq(f_imag, Im(f+1.))
866863

867-
self.run_operator([eq_re, eq_im], cxx)
864+
self.run_operator([eq_re, eq_im], language)
868865

869866
rcheck = np.array([2., 3., 4., 5., 6.])
870867
icheck = np.array([12., 11., 10., 9., 8.])
871868
assert np.all(np.isclose(f_real.data, rcheck))
872869
assert np.all(np.isclose(f_imag.data, icheck))
873870

874-
@pytest.mark.parametrize('cxx', [False, True])
871+
@pytest.mark.parametrize('language', ['C', 'CXX', 'CXXopenmp'])
875872
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
876-
def test_trivial_imag(self, cxx, dtype):
873+
def test_trivial_imag(self, language, dtype):
877874
f, f_real, f_imag = self.setup_basic(dtype)
878875

879876
eq_re = Eq(f_real, Re(f+1j))
880877
eq_im = Eq(f_imag, Im(f+1j))
881878

882-
self.run_operator([eq_re, eq_im], cxx)
879+
self.run_operator([eq_re, eq_im], language)
883880

884881
rcheck = np.array([1., 2., 3., 4., 5.])
885882
icheck = np.array([13., 12., 11., 10., 9.])
886883
assert np.all(np.isclose(f_real.data, rcheck))
887884
assert np.all(np.isclose(f_imag.data, icheck))
888885

889-
@pytest.mark.parametrize('cxx', [False, True])
890-
def test_deriv(self, cxx):
886+
@pytest.mark.parametrize('language', ['C', 'CXX', 'CXXopenmp'])
887+
def test_deriv(self, language):
891888
f, f_real, f_imag = self.setup_basic(np.complex64)
892889

893890
eq_re = Eq(f_real, Re(f.dx))
894891
eq_im = Eq(f_imag, Im(f.dx))
895892

896-
self.run_operator([eq_re, eq_im], cxx)
893+
self.run_operator([eq_re, eq_im], language)
897894

898895
assert np.all(np.isclose(f_real.data, 1.))
899896
assert np.all(np.isclose(f_imag.data, -1.))
900897

901-
@pytest.mark.parametrize('cxx', [False, True])
902-
def test_outer_deriv(self, cxx):
898+
@pytest.mark.parametrize('language', ['C', 'CXX', 'CXXopenmp'])
899+
def test_outer_deriv(self, language):
903900
f, f_real, f_imag = self.setup_basic(np.complex64)
904901

905902
eq_re = Eq(f_real, Re(f).dx)
906903
eq_im = Eq(f_imag, Im(f).dx)
907904

908-
self.run_operator([eq_re, eq_im], cxx)
905+
self.run_operator([eq_re, eq_im], language)
909906

910907
assert np.all(np.isclose(f_real.data, 1.))
911908
assert np.all(np.isclose(f_imag.data, -1.))
912909

913-
@pytest.mark.parametrize('cxx', [False, True])
914-
def test_mul(self, cxx):
910+
@pytest.mark.parametrize('language', ['C', 'CXX', 'CXXopenmp'])
911+
def test_mul(self, language):
915912
grid = Grid(shape=(5,))
916913

917914
f = Function(name='f', grid=grid, dtype=np.complex64)
@@ -931,7 +928,7 @@ def test_mul(self, cxx):
931928
eq_fh_re = Eq(fh_re, Re(f*h))
932929
eq_fh_im = Eq(fh_im, Im(f*h))
933930

934-
self.run_operator([eq_fg_re, eq_fg_im, eq_fh_re, eq_fh_im], cxx)
931+
self.run_operator([eq_fg_re, eq_fg_im, eq_fh_re, eq_fh_im], language)
935932

936933
assert np.all(np.isclose(fg_re.data, 2.))
937934
assert np.all(np.isclose(fg_im.data, 2.))

0 commit comments

Comments
 (0)