Skip to content

Commit 26159d0

Browse files
committed
dsl: Update Re and Im to Real and Imag, add Conj
1 parent 17438ac commit 26159d0

4 files changed

Lines changed: 49 additions & 25 deletions

File tree

devito/finite_differences/differentiable.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from devito.types.basic import AbstractFunction
2323

2424
__all__ = ['Differentiable', 'DiffDerivative', 'IndexDerivative', 'EvalDerivative',
25-
'Weights', 'Re', 'Im']
25+
'Weights', 'Real', 'Imag', 'Conj']
2626

2727

2828
class Differentiable(sympy.Expr, Evaluable):
@@ -646,7 +646,7 @@ def __str__(self):
646646

647647

648648
class ComplexPart(Differentiable, sympy.core.function.Application):
649-
"""Abstract class for `Re` or `Im` of an expression"""
649+
"""Abstract class for `Real`, `Imag`, or `Conj` of an expression"""
650650

651651
def __new__(cls, *args, **kwargs):
652652
if len(args) != 1:
@@ -668,16 +668,21 @@ def __str__(self):
668668
__repr__ = __str__
669669

670670

671-
class Re(ComplexPart):
671+
class Real(ComplexPart):
672672
"""Get the real part of an expression"""
673673
pass
674674

675675

676-
class Im(ComplexPart):
676+
class Imag(ComplexPart):
677677
"""Get the imaginary part of an expression"""
678678
pass
679679

680680

681+
class Conj(ComplexPart):
682+
"""Get the complex conjugate of an expression"""
683+
pass
684+
685+
681686
class IndexSum(sympy.Expr, Evaluable):
682687

683688
"""

devito/passes/iet/languages/C.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,14 @@ class CPrinter(BasePrinter, C99CodePrinter):
5656
def _print_ImaginaryUnit(self, expr):
5757
return '_Complex_I'
5858

59-
def _print_Re(self, expr):
59+
def _print_Real(self, expr):
6060
return (f'{self.func_prefix(expr)}real{self.func_literal(expr).lower()}'
6161
f'({self._print(expr.args[0])})')
6262

63-
def _print_Im(self, expr):
63+
def _print_Imag(self, expr):
6464
return (f'{self.func_prefix(expr)}imag{self.func_literal(expr).lower()}'
6565
f'({self._print(expr.args[0])})')
66+
67+
def _print_Conj(self, expr):
68+
return (f'conj{self.func_literal(expr).lower()}'
69+
f'({self._print(expr.args[0])})')

devito/passes/iet/languages/CXX.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,12 +104,15 @@ class CXXPrinter(BasePrinter, CXX11CodePrinter):
104104
def _print_ImaginaryUnit(self, expr):
105105
return f'1i{self.prec_literal(expr).lower()}'
106106

107-
def _print_Re(self, expr):
107+
def _print_Real(self, expr):
108108
return f'{self._ns}real({self._print(expr.args[0])})'
109109

110-
def _print_Im(self, expr):
110+
def _print_Imag(self, expr):
111111
return f'{self._ns}imag({self._print(expr.args[0])})'
112112

113+
def _print_Conj(self, expr):
114+
return f'{self._ns}conj({self._print(expr.args[0])})'
115+
113116
def _print_Cast(self, expr):
114117
# The CXX recommended way to cast is to use static_cast
115118
tstr = self._print(expr._C_ctype)

tests/test_symbolics.py

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from sympy import Expr, Symbol
88
from devito import (Constant, Dimension, Grid, Function, solve, TimeFunction, Eq, # noqa
99
Operator, SubDimension, norm, Le, Ge, Gt, Lt, Abs, sin, cos,
10-
Min, Max, Re, Im, switchconfig)
10+
Min, Max, Real, Imag, Conj, switchconfig)
1111
from devito.finite_differences.differentiable import SafeInv, Weights
1212
from devito.ir import Expression, FindNodes, ccode
1313
from devito.symbolics import (retrieve_functions, retrieve_indexed, evalrel, # noqa
@@ -837,15 +837,15 @@ def run_operator(self, eqs, language):
837837
def test_devito_print(self):
838838
f, _, _ = self.setup_basic(np.complex64)
839839

840-
assert str(Re(f)) == 'Re(f(x))'
841-
assert str(Im(f)) == 'Im(f(x))'
840+
assert str(Real(f)) == 'Real(f(x))'
841+
assert str(Imag(f)) == 'Imag(f(x))'
842842

843843
@pytest.mark.parametrize('language', ['C', 'CXX', 'CXXopenmp'])
844844
def test_printing(self, language):
845845
f, f_real, f_imag = self.setup_basic(np.complex64)
846846

847-
eq_re = Eq(f_real, Re(f))
848-
eq_im = Eq(f_imag, Im(f))
847+
eq_re = Eq(f_real, Real(f))
848+
eq_im = Eq(f_imag, Imag(f))
849849

850850
with switchconfig(language=language):
851851
op = Operator([eq_re, eq_im])
@@ -863,8 +863,8 @@ def test_printing(self, language):
863863
def test_trivial(self, language, dtype):
864864
f, f_real, f_imag = self.setup_basic(dtype)
865865

866-
eq_re = Eq(f_real, Re(f+1.))
867-
eq_im = Eq(f_imag, Im(f+1.))
866+
eq_re = Eq(f_real, Real(f+1.))
867+
eq_im = Eq(f_imag, Imag(f+1.))
868868

869869
self.run_operator([eq_re, eq_im], language)
870870

@@ -878,8 +878,8 @@ def test_trivial(self, language, dtype):
878878
def test_trivial_imag(self, language, dtype):
879879
f, f_real, f_imag = self.setup_basic(dtype)
880880

881-
eq_re = Eq(f_real, Re(f+1j))
882-
eq_im = Eq(f_imag, Im(f+1j))
881+
eq_re = Eq(f_real, Real(f+1j))
882+
eq_im = Eq(f_imag, Imag(f+1j))
883883

884884
self.run_operator([eq_re, eq_im], language)
885885

@@ -892,8 +892,8 @@ def test_trivial_imag(self, language, dtype):
892892
def test_deriv(self, language):
893893
f, f_real, f_imag = self.setup_basic(np.complex64)
894894

895-
eq_re = Eq(f_real, Re(f.dx))
896-
eq_im = Eq(f_imag, Im(f.dx))
895+
eq_re = Eq(f_real, Real(f.dx))
896+
eq_im = Eq(f_imag, Imag(f.dx))
897897

898898
self.run_operator([eq_re, eq_im], language)
899899

@@ -904,8 +904,8 @@ def test_deriv(self, language):
904904
def test_outer_deriv(self, language):
905905
f, f_real, f_imag = self.setup_basic(np.complex64)
906906

907-
eq_re = Eq(f_real, Re(f).dx)
908-
eq_im = Eq(f_imag, Im(f).dx)
907+
eq_re = Eq(f_real, Real(f).dx)
908+
eq_im = Eq(f_imag, Imag(f).dx)
909909

910910
self.run_operator([eq_re, eq_im], language)
911911

@@ -928,10 +928,10 @@ def test_mul(self, language):
928928
fh_re = Function(name='fh_re', grid=grid)
929929
fh_im = Function(name='fh_im', grid=grid)
930930

931-
eq_fg_re = Eq(fg_re, Re(f*g))
932-
eq_fg_im = Eq(fg_im, Im(f*g))
933-
eq_fh_re = Eq(fh_re, Re(f*h))
934-
eq_fh_im = Eq(fh_im, Im(f*h))
931+
eq_fg_re = Eq(fg_re, Real(f*g))
932+
eq_fg_im = Eq(fg_im, Imag(f*g))
933+
eq_fh_re = Eq(fh_re, Real(f*h))
934+
eq_fh_im = Eq(fh_im, Imag(f*h))
935935

936936
self.run_operator([eq_fg_re, eq_fg_im, eq_fh_re, eq_fh_im], language)
937937

@@ -940,3 +940,15 @@ def test_mul(self, language):
940940

941941
assert np.all(np.isclose(fh_re.data, -2.))
942942
assert np.all(np.isclose(fh_im.data, 2.))
943+
944+
@pytest.mark.parametrize('language', ['C', 'CXX', 'CXXopenmp'])
945+
def test_conj(self, language):
946+
grid = Grid(shape=(5,))
947+
f = Function(name='f', grid=grid, dtype=np.complex64)
948+
g = Function(name='g', grid=grid, dtype=np.complex64)
949+
950+
f.data[:] = np.arange(5) + 1j*np.arange(5)[::-1]
951+
952+
self.run_operator([Eq(g, Conj(f))], language)
953+
954+
assert np.all(np.isclose(g.data, np.conj(f.data)))

0 commit comments

Comments
 (0)