|
7 | 7 | from sympy import Expr, Number, Symbol |
8 | 8 | from devito import (Constant, Dimension, Grid, Function, solve, TimeFunction, Eq, # noqa |
9 | 9 | Operator, SubDimension, norm, Le, Ge, Gt, Lt, Abs, sin, cos, |
10 | | - Min, Max, SubDomain) |
| 10 | + Min, Max, Re, Im, SubDomain) |
11 | 11 | from devito.finite_differences.differentiable import SafeInv, Weights, Mul |
12 | 12 | from devito.ir import Expression, FindNodes, ccode |
13 | 13 | from devito.symbolics import (retrieve_functions, retrieve_indexed, evalrel, # noqa |
@@ -912,3 +912,123 @@ def test_print_div(): |
912 | 912 | b = SizeOf(np.int64) |
913 | 913 | cstr = ccode(a / b) |
914 | 914 | assert cstr == 'sizeof(int)/sizeof(long)' |
| 915 | + |
| 916 | + |
| 917 | +class TestComplexParts: |
| 918 | + # TODO: Add a cxx switchconfig |
| 919 | + def setup_basic(self, dtype): |
| 920 | + grid = Grid(shape=(5,), extent=(4.,)) |
| 921 | + f = Function(name='f', grid=grid, dtype=dtype) |
| 922 | + f.data_with_halo[:] = np.arange(7) + 1j*np.arange(7, 14)[::-1] |
| 923 | + |
| 924 | + f_real = Function(name='f_real', grid=grid) |
| 925 | + f_imag = Function(name='f_imag', grid=grid) |
| 926 | + return f, f_real, f_imag |
| 927 | + |
| 928 | + def run_operator(self, eqs, cxx): |
| 929 | + if cxx: |
| 930 | + with switchconfig(language='CXX'): |
| 931 | + Operator(eqs)() |
| 932 | + else: |
| 933 | + Operator(eqs)() |
| 934 | + |
| 935 | + @pytest.mark.parametrize('cxx', [False, True]) |
| 936 | + def test_printing(self, cxx): |
| 937 | + f, f_real, f_imag = self.setup_basic(np.complex64) |
| 938 | + |
| 939 | + eq_re = Eq(f_real, Re(f)) |
| 940 | + eq_im = Eq(f_imag, Im(f)) |
| 941 | + |
| 942 | + if cxx: |
| 943 | + with switchconfig(language='CXX'): |
| 944 | + op = Operator([eq_re, eq_im]) |
| 945 | + assert "f_real[x + 1] = std::real(f[x + 1])" in str(op.ccode) |
| 946 | + assert "f_imag[x + 1] = std::imag(f[x + 1])" in str(op.ccode) |
| 947 | + |
| 948 | + else: |
| 949 | + op = Operator([eq_re, eq_im]) |
| 950 | + assert "f_real[x + 1] = crealf(f[x + 1])" in str(op.ccode) |
| 951 | + assert "f_imag[x + 1] = cimagf(f[x + 1])" in str(op.ccode) |
| 952 | + |
| 953 | + @pytest.mark.parametrize('cxx', [False, True]) |
| 954 | + @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) |
| 955 | + def test_trivial(self, cxx, dtype): |
| 956 | + f, f_real, f_imag = self.setup_basic(dtype) |
| 957 | + |
| 958 | + eq_re = Eq(f_real, Re(f+1.)) |
| 959 | + eq_im = Eq(f_imag, Im(f+1.)) |
| 960 | + |
| 961 | + self.run_operator([eq_re, eq_im], cxx) |
| 962 | + |
| 963 | + rcheck = np.array([2., 3., 4., 5., 6.]) |
| 964 | + icheck = np.array([12., 11., 10., 9., 8.]) |
| 965 | + assert np.all(np.isclose(f_real.data, rcheck)) |
| 966 | + assert np.all(np.isclose(f_imag.data, icheck)) |
| 967 | + |
| 968 | + @pytest.mark.parametrize('cxx', [False, True]) |
| 969 | + @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) |
| 970 | + def test_trivial_imag(self, cxx, dtype): |
| 971 | + f, f_real, f_imag = self.setup_basic(dtype) |
| 972 | + |
| 973 | + eq_re = Eq(f_real, Re(f+1j)) |
| 974 | + eq_im = Eq(f_imag, Im(f+1j)) |
| 975 | + |
| 976 | + self.run_operator([eq_re, eq_im], cxx) |
| 977 | + |
| 978 | + rcheck = np.array([1., 2., 3., 4., 5.]) |
| 979 | + icheck = np.array([13., 12., 11., 10., 9.]) |
| 980 | + assert np.all(np.isclose(f_real.data, rcheck)) |
| 981 | + assert np.all(np.isclose(f_imag.data, icheck)) |
| 982 | + |
| 983 | + @pytest.mark.parametrize('cxx', [False, True]) |
| 984 | + def test_deriv(self, cxx): |
| 985 | + f, f_real, f_imag = self.setup_basic(np.complex64) |
| 986 | + |
| 987 | + eq_re = Eq(f_real, Re(f.dx)) |
| 988 | + eq_im = Eq(f_imag, Im(f.dx)) |
| 989 | + |
| 990 | + self.run_operator([eq_re, eq_im], cxx) |
| 991 | + |
| 992 | + assert np.all(np.isclose(f_real.data, 1.)) |
| 993 | + assert np.all(np.isclose(f_imag.data, -1.)) |
| 994 | + |
| 995 | + @pytest.mark.parametrize('cxx', [False, True]) |
| 996 | + def test_outer_deriv(self, cxx): |
| 997 | + f, f_real, f_imag = self.setup_basic(np.complex64) |
| 998 | + |
| 999 | + eq_re = Eq(f_real, Re(f).dx) |
| 1000 | + eq_im = Eq(f_imag, Im(f).dx) |
| 1001 | + |
| 1002 | + self.run_operator([eq_re, eq_im], cxx) |
| 1003 | + |
| 1004 | + assert np.all(np.isclose(f_real.data, 1.)) |
| 1005 | + assert np.all(np.isclose(f_imag.data, -1.)) |
| 1006 | + |
| 1007 | + @pytest.mark.parametrize('cxx', [False, True]) |
| 1008 | + def test_mul(self, cxx): |
| 1009 | + grid = Grid(shape=(5,)) |
| 1010 | + |
| 1011 | + f = Function(name='f', grid=grid, dtype=np.complex64) |
| 1012 | + g = Function(name='g', grid=grid) |
| 1013 | + h = Function(name='h', grid=grid, dtype=np.complex64) |
| 1014 | + f.data[:] = 1 + 1j |
| 1015 | + g.data[:] = 2 |
| 1016 | + h.data[:] = 2j |
| 1017 | + |
| 1018 | + fg_re = Function(name='fg_re', grid=grid) |
| 1019 | + fg_im = Function(name='fg_im', grid=grid) |
| 1020 | + fh_re = Function(name='fh_re', grid=grid) |
| 1021 | + fh_im = Function(name='fh_im', grid=grid) |
| 1022 | + |
| 1023 | + eq_fg_re = Eq(fg_re, Re(f*g)) |
| 1024 | + eq_fg_im = Eq(fg_im, Im(f*g)) |
| 1025 | + eq_fh_re = Eq(fh_re, Re(f*h)) |
| 1026 | + eq_fh_im = Eq(fh_im, Im(f*h)) |
| 1027 | + |
| 1028 | + self.run_operator([eq_fg_re, eq_fg_im, eq_fh_re, eq_fh_im], cxx) |
| 1029 | + |
| 1030 | + assert np.all(np.isclose(fg_re.data, 2.)) |
| 1031 | + assert np.all(np.isclose(fg_im.data, 2.)) |
| 1032 | + |
| 1033 | + assert np.all(np.isclose(fh_re.data, -2.)) |
| 1034 | + assert np.all(np.isclose(fh_im.data, 2.)) |
0 commit comments