77from sympy import Expr , Symbol
88from devito import (Constant , Dimension , Grid , Function , solve , TimeFunction , Eq , # noqa
99 Operator , SubDimension , norm , Le , Ge , Gt , Lt , Abs , sin , cos ,
10- Min , Max , Re , Im )
10+ Min , Max , Re , Im , switchconfig )
1111from devito .finite_differences .differentiable import SafeInv , Weights
1212from devito .ir import Expression , FindNodes , ccode
1313from devito .symbolics import (retrieve_functions , retrieve_indexed , evalrel , # noqa
@@ -821,6 +821,7 @@ def test_assumptions(self, op, expr, assumptions, expected):
821821
822822
823823class TestComplexParts :
824+ # TODO: Add a cxx switchconfig
824825 def setup_basic (self , dtype ):
825826 grid = Grid (shape = (5 ,), extent = (4. ,))
826827 f = Function (name = 'f' , grid = grid , dtype = dtype )
@@ -830,73 +831,87 @@ def setup_basic(self, dtype):
830831 f_imag = Function (name = 'f_imag' , grid = grid )
831832 return f , f_real , f_imag
832833
833- def test_printing (self ):
834+ def run_operator (self , eqs , cxx ):
835+ if cxx :
836+ with switchconfig (language = 'CXX' ):
837+ Operator (eqs )()
838+ else :
839+ Operator (eqs )()
840+
841+ @pytest .mark .parametrize ('cxx' , [False , True ])
842+ def test_printing (self , cxx ):
834843 f , f_real , f_imag = self .setup_basic (np .complex64 )
835844
836845 eq_re = Eq (f_real , Re (f ))
837846 eq_im = Eq (f_imag , Im (f ))
838847
839- op = Operator ([eq_re , eq_im ])
840-
841- assert ("#define Im(x) (_Generic((x), float _Complex : (float *) &(x), "
842- "double _Complex : (double *) &(x))[1])" ) in str (op .ccode )
843- assert ("#define Re(x) (_Generic((x), float _Complex : (float *) &(x), "
844- "double _Complex : (double *) &(x))[0])" ) in str (op .ccode )
848+ if cxx :
849+ with switchconfig (language = 'CXX' ):
850+ op = Operator ([eq_re , eq_im ])
851+ assert "f_real[x + 1] = std::real(f[x + 1])" in str (op .ccode )
852+ assert "f_imag[x + 1] = std::imag(f[x + 1])" in str (op .ccode )
845853
846- assert "f_real[x + 1] = Re(f[x + 1])" in str (op .ccode )
847- assert "f_imag[x + 1] = Im(f[x + 1])" in str (op .ccode )
854+ else :
855+ op = Operator ([eq_re , eq_im ])
856+ assert "f_real[x + 1] = crealf(f[x + 1])" in str (op .ccode )
857+ assert "f_imag[x + 1] = cimagf(f[x + 1])" in str (op .ccode )
848858
859+ @pytest .mark .parametrize ('cxx' , [False , True ])
849860 @pytest .mark .parametrize ('dtype' , [np .complex64 , np .complex128 ])
850- def test_trivial (self , dtype ):
861+ def test_trivial (self , cxx , dtype ):
851862 f , f_real , f_imag = self .setup_basic (dtype )
852863
853- eq_re = Eq (f_real , Re (f + 1 ))
854- eq_im = Eq (f_imag , Im (f + 1 ))
864+ eq_re = Eq (f_real , Re (f + 1. ))
865+ eq_im = Eq (f_imag , Im (f + 1. ))
855866
856- Operator ([eq_re , eq_im ])( )
867+ self . run_operator ([eq_re , eq_im ], cxx )
857868
858869 rcheck = np .array ([2. , 3. , 4. , 5. , 6. ])
859870 icheck = np .array ([12. , 11. , 10. , 9. , 8. ])
860871 assert np .all (np .isclose (f_real .data , rcheck ))
861872 assert np .all (np .isclose (f_imag .data , icheck ))
862873
874+ @pytest .mark .parametrize ('cxx' , [False , True ])
863875 @pytest .mark .parametrize ('dtype' , [np .complex64 , np .complex128 ])
864- def test_trivial_imag (self , dtype ):
876+ def test_trivial_imag (self , cxx , dtype ):
865877 f , f_real , f_imag = self .setup_basic (dtype )
866878
867879 eq_re = Eq (f_real , Re (f + 1j ))
868880 eq_im = Eq (f_imag , Im (f + 1j ))
869881
870- Operator ([eq_re , eq_im ])( )
882+ self . run_operator ([eq_re , eq_im ], cxx )
871883
872884 rcheck = np .array ([1. , 2. , 3. , 4. , 5. ])
873885 icheck = np .array ([13. , 12. , 11. , 10. , 9. ])
874886 assert np .all (np .isclose (f_real .data , rcheck ))
875887 assert np .all (np .isclose (f_imag .data , icheck ))
876888
877- def test_deriv (self ):
889+ @pytest .mark .parametrize ('cxx' , [False , True ])
890+ def test_deriv (self , cxx ):
878891 f , f_real , f_imag = self .setup_basic (np .complex64 )
879892
880893 eq_re = Eq (f_real , Re (f .dx ))
881894 eq_im = Eq (f_imag , Im (f .dx ))
882895
883- Operator ([eq_re , eq_im ])( )
896+ self . run_operator ([eq_re , eq_im ], cxx )
884897
885898 assert np .all (np .isclose (f_real .data , 1. ))
886899 assert np .all (np .isclose (f_imag .data , - 1. ))
887900
888- def test_outer_deriv (self ):
901+ @pytest .mark .parametrize ('cxx' , [False , True ])
902+ def test_outer_deriv (self , cxx ):
889903 f , f_real , f_imag = self .setup_basic (np .complex64 )
890904
891905 eq_re = Eq (f_real , Re (f ).dx )
892906 eq_im = Eq (f_imag , Im (f ).dx )
893907
894- Operator ([eq_re , eq_im ])( )
908+ self . run_operator ([eq_re , eq_im ], cxx )
895909
896910 assert np .all (np .isclose (f_real .data , 1. ))
897911 assert np .all (np .isclose (f_imag .data , - 1. ))
898912
899- def test_mul (self ):
913+ @pytest .mark .parametrize ('cxx' , [False , True ])
914+ def test_mul (self , cxx ):
900915 grid = Grid (shape = (5 ,))
901916
902917 f = Function (name = 'f' , grid = grid , dtype = np .complex64 )
@@ -916,7 +931,7 @@ def test_mul(self):
916931 eq_fh_re = Eq (fh_re , Re (f * h ))
917932 eq_fh_im = Eq (fh_im , Im (f * h ))
918933
919- Operator ([eq_fg_re , eq_fg_im , eq_fh_re , eq_fh_im ])( )
934+ self . run_operator ([eq_fg_re , eq_fg_im , eq_fh_re , eq_fh_im ], cxx )
920935
921936 assert np .all (np .isclose (fg_re .data , 2. ))
922937 assert np .all (np .isclose (fg_im .data , 2. ))
0 commit comments