Skip to content

Commit 9f1d076

Browse files
committed
dsl: Update Re and Im to Real and Imag, add Conj
1 parent f6d719b commit 9f1d076

4 files changed

Lines changed: 49 additions & 25 deletions

File tree

devito/finite_differences/differentiable.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from devito.types.basic import AbstractFunction
2323

2424
__all__ = ['Differentiable', 'DiffDerivative', 'IndexDerivative', 'EvalDerivative',
25-
'Weights', 'Re', 'Im']
25+
'Weights', 'Real', 'Imag', 'Conj']
2626

2727

2828
class Differentiable(sympy.Expr, Evaluable):
@@ -645,7 +645,7 @@ def __str__(self):
645645

646646

647647
class ComplexPart(Differentiable, sympy.core.function.Application):
648-
"""Abstract class for `Re` or `Im` of an expression"""
648+
"""Abstract class for `Real`, `Imag`, or `Conj` of an expression"""
649649

650650
def __new__(cls, *args, **kwargs):
651651
if len(args) != 1:
@@ -667,16 +667,21 @@ def __str__(self):
667667
__repr__ = __str__
668668

669669

670-
class Re(ComplexPart):
670+
class Real(ComplexPart):
671671
"""Get the real part of an expression"""
672672
pass
673673

674674

675-
class Im(ComplexPart):
675+
class Imag(ComplexPart):
676676
"""Get the imaginary part of an expression"""
677677
pass
678678

679679

680+
class Conj(ComplexPart):
681+
"""Get the complex conjugate of an expression"""
682+
pass
683+
684+
680685
class IndexSum(sympy.Expr, Evaluable):
681686

682687
"""

devito/passes/iet/languages/C.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,14 @@ def _print_ListInitializer(self, expr):
6868
else:
6969
return li
7070

71-
def _print_Re(self, expr):
71+
def _print_Real(self, expr):
7272
return (f'{self.func_prefix(expr)}real{self.func_literal(expr).lower()}'
7373
f'({self._print(expr.args[0])})')
7474

75-
def _print_Im(self, expr):
75+
def _print_Imag(self, expr):
7676
return (f'{self.func_prefix(expr)}imag{self.func_literal(expr).lower()}'
7777
f'({self._print(expr.args[0])})')
78+
79+
def _print_Conj(self, expr):
80+
return (f'conj{self.func_literal(expr).lower()}'
81+
f'({self._print(expr.args[0])})')

devito/passes/iet/languages/CXX.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,12 +107,15 @@ class CXXPrinter(BasePrinter, CXX11CodePrinter):
107107
def _print_ImaginaryUnit(self, expr):
108108
return f'1i{self.prec_literal(expr).lower()}'
109109

110-
def _print_Re(self, expr):
110+
def _print_Real(self, expr):
111111
return f'{self._ns}real({self._print(expr.args[0])})'
112112

113-
def _print_Im(self, expr):
113+
def _print_Imag(self, expr):
114114
return f'{self._ns}imag({self._print(expr.args[0])})'
115115

116+
def _print_Conj(self, expr):
117+
return f'{self._ns}conj({self._print(expr.args[0])})'
118+
116119
def _print_Cast(self, expr):
117120
# The CXX recommended way to cast is to use static_cast
118121
tstr = self._print(expr._C_ctype)

tests/test_symbolics.py

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from sympy import Expr, Number, Symbol
88
from devito import (Constant, Dimension, Grid, Function, solve, TimeFunction, Eq, # noqa
99
Operator, SubDimension, norm, Le, Ge, Gt, Lt, Abs, sin, cos,
10-
Min, Max, Re, Im, SubDomain)
10+
Min, Max, Real, Imag, Conj, SubDomain)
1111
from devito.finite_differences.differentiable import SafeInv, Weights, Mul
1212
from devito.ir import Expression, FindNodes, ccode
1313
from devito.symbolics import (retrieve_functions, retrieve_indexed, evalrel, # noqa
@@ -931,15 +931,15 @@ def run_operator(self, eqs, language):
931931
def test_devito_print(self):
932932
f, _, _ = self.setup_basic(np.complex64)
933933

934-
assert str(Re(f)) == 'Re(f(x))'
935-
assert str(Im(f)) == 'Im(f(x))'
934+
assert str(Real(f)) == 'Real(f(x))'
935+
assert str(Imag(f)) == 'Imag(f(x))'
936936

937937
@pytest.mark.parametrize('language', ['C', 'CXX', 'CXXopenmp'])
938938
def test_printing(self, language):
939939
f, f_real, f_imag = self.setup_basic(np.complex64)
940940

941-
eq_re = Eq(f_real, Re(f))
942-
eq_im = Eq(f_imag, Im(f))
941+
eq_re = Eq(f_real, Real(f))
942+
eq_im = Eq(f_imag, Imag(f))
943943

944944
with switchconfig(language=language):
945945
op = Operator([eq_re, eq_im])
@@ -957,8 +957,8 @@ def test_printing(self, language):
957957
def test_trivial(self, language, dtype):
958958
f, f_real, f_imag = self.setup_basic(dtype)
959959

960-
eq_re = Eq(f_real, Re(f+1.))
961-
eq_im = Eq(f_imag, Im(f+1.))
960+
eq_re = Eq(f_real, Real(f+1.))
961+
eq_im = Eq(f_imag, Imag(f+1.))
962962

963963
self.run_operator([eq_re, eq_im], language)
964964

@@ -972,8 +972,8 @@ def test_trivial(self, language, dtype):
972972
def test_trivial_imag(self, language, dtype):
973973
f, f_real, f_imag = self.setup_basic(dtype)
974974

975-
eq_re = Eq(f_real, Re(f+1j))
976-
eq_im = Eq(f_imag, Im(f+1j))
975+
eq_re = Eq(f_real, Real(f+1j))
976+
eq_im = Eq(f_imag, Imag(f+1j))
977977

978978
self.run_operator([eq_re, eq_im], language)
979979

@@ -986,8 +986,8 @@ def test_trivial_imag(self, language, dtype):
986986
def test_deriv(self, language):
987987
f, f_real, f_imag = self.setup_basic(np.complex64)
988988

989-
eq_re = Eq(f_real, Re(f.dx))
990-
eq_im = Eq(f_imag, Im(f.dx))
989+
eq_re = Eq(f_real, Real(f.dx))
990+
eq_im = Eq(f_imag, Imag(f.dx))
991991

992992
self.run_operator([eq_re, eq_im], language)
993993

@@ -998,8 +998,8 @@ def test_deriv(self, language):
998998
def test_outer_deriv(self, language):
999999
f, f_real, f_imag = self.setup_basic(np.complex64)
10001000

1001-
eq_re = Eq(f_real, Re(f).dx)
1002-
eq_im = Eq(f_imag, Im(f).dx)
1001+
eq_re = Eq(f_real, Real(f).dx)
1002+
eq_im = Eq(f_imag, Imag(f).dx)
10031003

10041004
self.run_operator([eq_re, eq_im], language)
10051005

@@ -1022,10 +1022,10 @@ def test_mul(self, language):
10221022
fh_re = Function(name='fh_re', grid=grid)
10231023
fh_im = Function(name='fh_im', grid=grid)
10241024

1025-
eq_fg_re = Eq(fg_re, Re(f*g))
1026-
eq_fg_im = Eq(fg_im, Im(f*g))
1027-
eq_fh_re = Eq(fh_re, Re(f*h))
1028-
eq_fh_im = Eq(fh_im, Im(f*h))
1025+
eq_fg_re = Eq(fg_re, Real(f*g))
1026+
eq_fg_im = Eq(fg_im, Imag(f*g))
1027+
eq_fh_re = Eq(fh_re, Real(f*h))
1028+
eq_fh_im = Eq(fh_im, Imag(f*h))
10291029

10301030
self.run_operator([eq_fg_re, eq_fg_im, eq_fh_re, eq_fh_im], language)
10311031

@@ -1034,3 +1034,15 @@ def test_mul(self, language):
10341034

10351035
assert np.all(np.isclose(fh_re.data, -2.))
10361036
assert np.all(np.isclose(fh_im.data, 2.))
1037+
1038+
@pytest.mark.parametrize('language', ['C', 'CXX', 'CXXopenmp'])
1039+
def test_conj(self, language):
1040+
grid = Grid(shape=(5,))
1041+
f = Function(name='f', grid=grid, dtype=np.complex64)
1042+
g = Function(name='g', grid=grid, dtype=np.complex64)
1043+
1044+
f.data[:] = np.arange(5) + 1j*np.arange(5)[::-1]
1045+
1046+
self.run_operator([Eq(g, Conj(f))], language)
1047+
1048+
assert np.all(np.isclose(g.data, np.conj(f.data)))

0 commit comments

Comments
 (0)