|
7 | 7 | from sympy import Expr, Symbol |
8 | 8 | from devito import (Constant, Dimension, Grid, Function, solve, TimeFunction, Eq, # noqa |
9 | 9 | Operator, SubDimension, norm, Le, Ge, Gt, Lt, Abs, sin, cos, |
10 | | - Min, Max) |
| 10 | + Min, Max, Re, Im) |
11 | 11 | from devito.finite_differences.differentiable import SafeInv, Weights |
12 | 12 | from devito.ir import Expression, FindNodes, ccode |
13 | 13 | from devito.symbolics import (retrieve_functions, retrieve_indexed, evalrel, # noqa |
@@ -818,3 +818,108 @@ def test_assumptions(self, op, expr, assumptions, expected): |
818 | 818 | assumptions = eval(assumptions) |
819 | 819 | expected = eval(expected) |
820 | 820 | assert evalrel(op, eqn, assumptions) == expected |
| 821 | + |
| 822 | + |
| 823 | +class TestComplexParts: |
| 824 | + def setup_basic(self, dtype): |
| 825 | + grid = Grid(shape=(5,), extent=(4.,)) |
| 826 | + f = Function(name='f', grid=grid, dtype=dtype) |
| 827 | + f.data_with_halo[:] = np.arange(7) + 1j*np.arange(7, 14)[::-1] |
| 828 | + |
| 829 | + f_real = Function(name='f_real', grid=grid) |
| 830 | + f_imag = Function(name='f_imag', grid=grid) |
| 831 | + return f, f_real, f_imag |
| 832 | + |
| 833 | + def test_printing(self): |
| 834 | + f, f_real, f_imag = self.setup_basic(np.complex64) |
| 835 | + |
| 836 | + eq_re = Eq(f_real, Re(f)) |
| 837 | + eq_im = Eq(f_imag, Im(f)) |
| 838 | + |
| 839 | + op = Operator([eq_re, eq_im]) |
| 840 | + |
| 841 | + assert ("#define Im(x) (_Generic((x), float _Complex : (float *) &(x), " |
| 842 | + "double _Complex : (double *) &(x))[1])") in str(op.ccode) |
| 843 | + assert ("#define Re(x) (_Generic((x), float _Complex : (float *) &(x), " |
| 844 | + "double _Complex : (double *) &(x))[0])") in str(op.ccode) |
| 845 | + |
| 846 | + assert "f_real[x + 1] = Re(f[x + 1])" in str(op.ccode) |
| 847 | + assert "f_imag[x + 1] = Im(f[x + 1])" in str(op.ccode) |
| 848 | + |
| 849 | + @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) |
| 850 | + def test_trivial(self, dtype): |
| 851 | + f, f_real, f_imag = self.setup_basic(dtype) |
| 852 | + |
| 853 | + eq_re = Eq(f_real, Re(f+1)) |
| 854 | + eq_im = Eq(f_imag, Im(f+1)) |
| 855 | + |
| 856 | + Operator([eq_re, eq_im])() |
| 857 | + |
| 858 | + rcheck = np.array([2., 3., 4., 5., 6.]) |
| 859 | + icheck = np.array([12., 11., 10., 9., 8.]) |
| 860 | + assert np.all(np.isclose(f_real.data, rcheck)) |
| 861 | + assert np.all(np.isclose(f_imag.data, icheck)) |
| 862 | + |
| 863 | + @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) |
| 864 | + def test_trivial_imag(self, dtype): |
| 865 | + f, f_real, f_imag = self.setup_basic(dtype) |
| 866 | + |
| 867 | + eq_re = Eq(f_real, Re(f+1j)) |
| 868 | + eq_im = Eq(f_imag, Im(f+1j)) |
| 869 | + |
| 870 | + Operator([eq_re, eq_im])() |
| 871 | + |
| 872 | + rcheck = np.array([1., 2., 3., 4., 5.]) |
| 873 | + icheck = np.array([13., 12., 11., 10., 9.]) |
| 874 | + assert np.all(np.isclose(f_real.data, rcheck)) |
| 875 | + assert np.all(np.isclose(f_imag.data, icheck)) |
| 876 | + |
| 877 | + def test_deriv(self): |
| 878 | + f, f_real, f_imag = self.setup_basic(np.complex64) |
| 879 | + |
| 880 | + eq_re = Eq(f_real, Re(f.dx)) |
| 881 | + eq_im = Eq(f_imag, Im(f.dx)) |
| 882 | + |
| 883 | + Operator([eq_re, eq_im])() |
| 884 | + |
| 885 | + assert np.all(np.isclose(f_real.data, 1.)) |
| 886 | + assert np.all(np.isclose(f_imag.data, -1.)) |
| 887 | + |
| 888 | + def test_outer_deriv(self): |
| 889 | + f, f_real, f_imag = self.setup_basic(np.complex64) |
| 890 | + |
| 891 | + eq_re = Eq(f_real, Re(f).dx) |
| 892 | + eq_im = Eq(f_imag, Im(f).dx) |
| 893 | + |
| 894 | + Operator([eq_re, eq_im])() |
| 895 | + |
| 896 | + assert np.all(np.isclose(f_real.data, 1.)) |
| 897 | + assert np.all(np.isclose(f_imag.data, -1.)) |
| 898 | + |
| 899 | + def test_mul(self): |
| 900 | + grid = Grid(shape=(5,)) |
| 901 | + |
| 902 | + f = Function(name='f', grid=grid, dtype=np.complex64) |
| 903 | + g = Function(name='g', grid=grid) |
| 904 | + h = Function(name='h', grid=grid, dtype=np.complex64) |
| 905 | + f.data[:] = 1 + 1j |
| 906 | + g.data[:] = 2 |
| 907 | + h.data[:] = 2j |
| 908 | + |
| 909 | + fg_re = Function(name='fg_re', grid=grid) |
| 910 | + fg_im = Function(name='fg_im', grid=grid) |
| 911 | + fh_re = Function(name='fh_re', grid=grid) |
| 912 | + fh_im = Function(name='fh_im', grid=grid) |
| 913 | + |
| 914 | + eq_fg_re = Eq(fg_re, Re(f*g)) |
| 915 | + eq_fg_im = Eq(fg_im, Im(f*g)) |
| 916 | + eq_fh_re = Eq(fh_re, Re(f*h)) |
| 917 | + eq_fh_im = Eq(fh_im, Im(f*h)) |
| 918 | + |
| 919 | + Operator([eq_fg_re, eq_fg_im, eq_fh_re, eq_fh_im])() |
| 920 | + |
| 921 | + assert np.all(np.isclose(fg_re.data, 2.)) |
| 922 | + assert np.all(np.isclose(fg_im.data, 2.)) |
| 923 | + |
| 924 | + assert np.all(np.isclose(fh_re.data, -2.)) |
| 925 | + assert np.all(np.isclose(fh_im.data, 2.)) |
0 commit comments